rasa/core/channels/socketio.py
import logging
import socketio
import uuid
from sanic import Blueprint, response
from socketio import AsyncServer
from typing import Optional, Text, Any, List, Dict
from rasa.core.channels import InputChannel
from rasa.core.channels.channel import (
UserMessage,
OutputChannel)
# Logger named after this module; used by the socket.io handlers below.
logger = logging.getLogger(name=__name__)
class SocketBlueprint(Blueprint):
    """Sanic blueprint that mounts a socket.io server when it is registered."""

    def __init__(self, sio: AsyncServer, socketio_path, *args, **kwargs):
        """Remember the socket.io server and mount path, then build the blueprint.

        Args:
            sio: socket.io server attached to the app in ``register``.
            socketio_path: path handed to ``sio.attach``.
            *args, **kwargs: forwarded unchanged to ``Blueprint.__init__``.
        """
        self.sio = sio
        self.socketio_path = socketio_path
        super().__init__(*args, **kwargs)

    def register(self, app, options):
        """Attach the socket.io server to ``app``, then register normally."""
        sio_server, mount_path = self.sio, self.socketio_path
        sio_server.attach(app, mount_path)
        super().register(app, options)
class SocketIOOutput(OutputChannel):
    """Output channel that emits bot responses to one socket.io client."""

    @classmethod
    def name(cls):
        return "socketio"

    def __init__(self, sio: AsyncServer, sid: Text, bot_message_evt: Text,
                 namespace: Optional[Text] = None) -> None:
        """Bind the channel to a single connected socket.

        Args:
            sio: socket.io server used to emit events.
            sid: id of the socket every outgoing message is addressed to
                (used as the ``room`` of each emit).
            bot_message_evt: event name under which bot messages are emitted.
            namespace: socket.io namespace to emit on. ``None`` keeps the
                previous behaviour of emitting on the default namespace.
        """
        self.sio = sio
        self.sid = sid
        self.bot_message_evt = bot_message_evt
        self.namespace = namespace

    async def _send_message(self, socket_id: Text, response: Any) -> None:
        """Sends a message to the recipient using the bot event."""
        # python-socketio emits on the default namespace unless told
        # otherwise, so a client connected to a custom namespace would never
        # see the reply; emit on the namespace the input handlers listen on.
        await self.sio.emit(self.bot_message_evt, response, room=socket_id,
                            namespace=self.namespace)

    async def send_text_message(self, recipient_id: Text,
                                message: Text) -> None:
        """Send a text message through this channel.

        ``recipient_id`` is not used for routing; the message is emitted to
        the socket this channel was created for (``self.sid``).
        """
        await self._send_message(self.sid, {"text": message})

    async def send_image_url(self, recipient_id: Text, image_url: Text) -> None:
        """Send an image as an ``image`` attachment whose payload is its URL."""
        message = {
            "attachment": {
                "type": "image",
                "payload": {"src": image_url}
            }
        }
        await self._send_message(self.sid, message)

    async def send_text_with_buttons(self, recipient_id: Text, text: Text,
                                     buttons: List[Dict[Text, Any]],
                                     **kwargs: Any) -> None:
        """Send text plus buttons rendered as ``quick_replies``.

        Each button dict must provide ``title`` and ``payload`` keys.
        """
        message = {
            "text": text,
            "quick_replies": []
        }
        for button in buttons:
            message["quick_replies"].append({
                "content_type": "text",
                "title": button['title'],
                "payload": button['payload']
            })
        await self._send_message(self.sid, message)

    async def send_custom_message(self, recipient_id: Text,
                                  elements: List[Dict[Text, Any]]) -> None:
        """Send elements as generic-template attachments.

        Every element is emitted as its own message whose ``elements``
        payload is that element. Previously only ``elements[0]`` was sent:
        any further elements were silently dropped and an empty list raised
        ``IndexError``. A single element yields exactly the same message as
        before; an empty list now sends nothing.
        """
        for element in elements:
            message = {"attachment": {
                "type": "template",
                "payload": {
                    "template_type": "generic",
                    "elements": element
                }}}
            await self._send_message(self.sid, message)


class SocketIOInput(InputChannel):
    """A socket.io input channel."""

    @classmethod
    def name(cls):
        return "socketio"

    @classmethod
    def from_credentials(cls, credentials):
        """Build the channel from a credentials dict, using defaults for gaps.

        Recognised keys: ``user_message_evt``, ``bot_message_evt``,
        ``namespace``, ``session_persistence`` and ``socketio_path``.
        ``None`` credentials are treated as an empty dict.
        """
        credentials = credentials or {}
        return cls(credentials.get("user_message_evt", "user_uttered"),
                   credentials.get("bot_message_evt", "bot_uttered"),
                   credentials.get("namespace"),
                   credentials.get("session_persistence", False),
                   credentials.get("socketio_path", "/socket.io"),
                   )

    def __init__(self,
                 user_message_evt: Text = "user_uttered",
                 bot_message_evt: Text = "bot_uttered",
                 namespace: Optional[Text] = None,
                 session_persistence: bool = False,
                 socketio_path: Optional[Text] = '/socket.io'
                 ):
        """Configure the channel.

        Args:
            user_message_evt: event name the client uses to send messages.
            bot_message_evt: event name used to emit bot responses.
            namespace: socket.io namespace for all handlers and emits
                (``None`` means the default namespace).
            session_persistence: if true, the sender id is taken from the
                ``session_id`` field of each message instead of the socket id.
            socketio_path: path at which the socket.io server is attached.
        """
        self.bot_message_evt = bot_message_evt
        self.session_persistence = session_persistence
        self.user_message_evt = user_message_evt
        self.namespace = namespace
        self.socketio_path = socketio_path

    def blueprint(self, on_new_message):
        """Create the Sanic blueprint and register the socket.io handlers.

        Args:
            on_new_message: coroutine function awaited with a ``UserMessage``
                for every incoming user message.

        Returns:
            A ``SocketBlueprint`` that attaches the socket.io server to the
            app at ``self.socketio_path`` when registered.
        """
        sio = AsyncServer(async_mode='sanic')
        socketio_webhook = SocketBlueprint(sio, self.socketio_path,
                                           'socketio_webhook', __name__)

        @socketio_webhook.route("/", methods=['GET'])
        async def health(request):
            return response.json({"status": "ok"})

        @sio.on('connect', namespace=self.namespace)
        async def connect(sid, environ):
            logger.debug("User {} connected to socketIO endpoint.".format(sid))

        @sio.on('disconnect', namespace=self.namespace)
        async def disconnect(sid):
            logger.debug("User {} disconnected from socketIO endpoint."
                         "".format(sid))

        @sio.on('session_request', namespace=self.namespace)
        async def session_request(sid, data):
            # Reuse the client-supplied session id, or mint a new random one.
            if data is None:
                data = {}
            if 'session_id' not in data or data['session_id'] is None:
                data['session_id'] = uuid.uuid4().hex
            # Confirm on the namespace the request arrived on; without
            # ``namespace`` python-socketio would emit on the default one.
            await sio.emit("session_confirm", data['session_id'], room=sid,
                           namespace=self.namespace)
            logger.debug("User {} connected to socketIO endpoint."
                         "".format(sid))

        @sio.on(self.user_message_evt, namespace=self.namespace)
        async def handle_message(sid, data):
            # Replies must go out on the same namespace as this handler.
            output_channel = SocketIOOutput(sio, sid, self.bot_message_evt,
                                            namespace=self.namespace)

            if self.session_persistence:
                # With persistence the conversation is keyed by the
                # client-provided session id, not the per-connection sid.
                if not data.get("session_id"):
                    logger.warning("A message without a valid sender_id "
                                   "was received. This message will be "
                                   "ignored. Make sure to set a proper "
                                   "session id using the "
                                   "`session_request` socketIO event.")
                    return
                sender_id = data['session_id']
            else:
                sender_id = sid

            message = UserMessage(data['message'], output_channel, sender_id,
                                  input_channel=self.name())
            await on_new_message(message)

        return socketio_webhook