RasaHQ/rasa_core

View on GitHub
rasa/core/channels/console.py

Summary

Maintainability
A
1 hr
Test Coverage
# this builtin is needed so we can overwrite in test
import json
import logging

import aiohttp
from async_generator import async_generator, yield_
from prompt_toolkit.styles import Style

import questionary
import rasa.cli.utils
from rasa.cli import utils as cliutils
from rasa.core import utils
from rasa.core.channels import UserMessage
from rasa.core.channels.channel import (
    RestInput, button_to_string, element_to_string)
from rasa.core.constants import DEFAULT_SERVER_URL
from rasa.core.interpreter import INTENT_MESSAGE_PREFIX

logger = logging.getLogger(__name__)


def print_bot_output(message, color=rasa.cli.utils.bcolors.OKBLUE):
    if "text" in message:
        rasa.cli.utils.print_color(message.get("text"),
                                   color)

    if "image" in message:
        rasa.cli.utils.print_color("Image: " + message.get("image"),
                                   color)

    if "attachment" in message:
        rasa.cli.utils.print_color("Attachment: " + message.get("attachment"),
                                   color)

    if "buttons" in message:
        rasa.cli.utils.print_color("Buttons:", color)
        for idx, button in enumerate(message.get("buttons")):
            rasa.cli.utils.print_color(button_to_string(button, idx), color)

    if "elements" in message:
        for idx, element in enumerate(message.get("elements")):
            element_str = "Elements:\n" + element_to_string(element, idx)
            rasa.cli.utils.print_color(element_str, color)

    if "quick_replies" in message:
        for idx, element in enumerate(message.get("quick_replies")):
            element_str = "Quick Replies:\n" + button_to_string(element, idx)
            rasa.cli.utils.print_color(element_str, color)


def get_cmd_input():
    response = questionary.text("",
                                qmark="Your input ->",
                                style=Style([('qmark', '#b373d6'),
                                             ('', '#b373d6')])).ask()
    if response is not None:
        return response.strip()
    else:
        return None


async def send_message_receive_block(server_url,
                                     auth_token,
                                     sender_id,
                                     message):
    payload = {
        "sender": sender_id,
        "message": message
    }

    url = "{}/webhooks/rest/webhook?token={}".format(server_url, auth_token)
    async with aiohttp.ClientSession() as session:
        async with session.post(url,
                                json=payload,
                                raise_for_status=True) as resp:
            return await resp.json()


@async_generator  # needed for python 3.5 compatibility
async def send_message_receive_stream(server_url,
                                      auth_token,
                                      sender_id,
                                      message):
    payload = {
        "sender": sender_id,
        "message": message
    }

    url = "{}/webhooks/rest/webhook?stream=true&token={}".format(
        server_url, auth_token)

    # TODO: check if this properly receives UTF-8 data
    async with aiohttp.ClientSession() as session:
        async with session.post(url,
                                json=payload,
                                raise_for_status=True) as resp:

            async for line in resp.content:
                if line:
                    await yield_(json.loads(line.decode("utf-8")))


async def record_messages(server_url=DEFAULT_SERVER_URL,
                          auth_token=None,
                          sender_id=UserMessage.DEFAULT_SENDER_ID,
                          max_message_limit=None,
                          use_response_stream=True):
    """Read messages from the command line and print bot responses."""

    auth_token = auth_token if auth_token else ""

    exit_text = INTENT_MESSAGE_PREFIX + 'stop'

    cliutils.print_success("Bot loaded. Type a message and press enter "
                           "(use '{}' to exit): ".format(exit_text))

    num_messages = 0
    while not utils.is_limit_reached(num_messages, max_message_limit):
        text = get_cmd_input()
        if text == exit_text or text is None:
            break

        if use_response_stream:
            bot_responses = send_message_receive_stream(server_url,
                                                        auth_token,
                                                        sender_id, text)
            async for response in bot_responses:
                print_bot_output(response)
        else:
            bot_responses = await send_message_receive_block(server_url,
                                                             auth_token,
                                                             sender_id, text)
            for response in bot_responses:
                print_bot_output(response)

        num_messages += 1
    return num_messages


class CmdlineInput(RestInput):

    @classmethod
    def name(cls):
        return "cmdline"

    def url_prefix(self):
        return RestInput.name()