RasaHQ/rasa_core

View on GitHub
rasa/core/run.py

Summary

Maintainability
A
3 hrs
Test Coverage
import asyncio
from functools import partial

import argparse
import logging
from sanic import Sanic
from sanic_cors import CORS
from typing import List, Optional, Text

import rasa.core.cli.arguments
import rasa.utils

import rasa.core
from rasa.core import constants, utils, cli
from rasa.core.channels import (BUILTIN_CHANNELS, InputChannel, console)
from rasa.core.interpreter import NaturalLanguageInterpreter
from rasa.core.tracker_store import TrackerStore
from rasa.core.utils import AvailableEndpoints, read_yaml_file

logger = logging.getLogger()  # get the root logger


def create_argument_parser():
    """Parse all the command line arguments for the run script."""

    parser = argparse.ArgumentParser(
        description='starts the bot')
    parser.add_argument(
        '-d', '--core',
        required=True,
        type=str,
        help="core model to run")
    parser.add_argument(
        '-u', '--nlu',
        type=str,
        help="nlu model to run")

    cli.arguments.add_logging_option_arguments(parser)
    cli.run.add_run_arguments(parser)
    return parser


def create_http_input_channels(
        channel: Optional[Text],
        credentials_file: Optional[Text]
) -> List['InputChannel']:
    """Instantiate the chosen input channel."""

    if credentials_file:
        all_credentials = read_yaml_file(credentials_file)
    else:
        all_credentials = {}

    if channel:
        return [_create_single_channel(channel, all_credentials.get(channel))]
    else:
        return [_create_single_channel(c, k)
                for c, k in all_credentials.items()]


def _create_single_channel(channel, credentials):
    from rasa.core.channels import BUILTIN_CHANNELS

    if channel in BUILTIN_CHANNELS:
        return BUILTIN_CHANNELS[channel].from_credentials(credentials)
    else:
        # try to load channel based on class name
        try:
            input_channel_class = utils.class_from_module_path(channel)
            return input_channel_class.from_credentials(credentials)
        except (AttributeError, ImportError):
            raise Exception(
                "Failed to find input channel class for '{}'. Unknown "
                "input channel. Check your credentials configuration to "
                "make sure the mentioned channel is not misspelled. "
                "If you are creating your own channel, make sure it "
                "is a proper name of a class in a module.".format(channel))


def configure_app(input_channels=None,
                  cors=None,
                  auth_token=None,
                  enable_api=True,
                  jwt_secret=None,
                  jwt_method=None,
                  route="/webhooks/",
                  port=None):
    """Run the agent."""
    from rasa.core import server

    if enable_api:
        app = server.create_app(cors_origins=cors,
                                auth_token=auth_token,
                                jwt_secret=jwt_secret,
                                jwt_method=jwt_method)
    else:
        app = Sanic(__name__)
        CORS(app,
             resources={r"/*": {"origins": cors or ""}},
             automatic_options=True)

    if input_channels:
        rasa.core.channels.channel.register(input_channels,
                                            app,
                                            route=route)
    else:
        input_channels = []

    if logger.isEnabledFor(logging.DEBUG):
        utils.list_routes(app)

    # configure async loop logging
    async def configure_logging():
        if logger.isEnabledFor(logging.DEBUG):
            utils.enable_async_loop_debugging(asyncio.get_event_loop())

    app.add_task(configure_logging)

    if "cmdline" in {c.name() for c in input_channels}:
        async def run_cmdline_io(running_app: Sanic):
            """Small wrapper to shut down the server once cmd io is done."""
            await asyncio.sleep(1)  # allow server to start
            await console.record_messages(
                server_url=constants.DEFAULT_SERVER_FORMAT.format(port))

            logger.info("Killing Sanic server now.")
            running_app.stop()  # kill the sanic serverx

        app.add_task(run_cmdline_io)

    return app


def serve_application(core_model=None,
                      nlu_model=None,
                      channel=None,
                      port=constants.DEFAULT_SERVER_PORT,
                      credentials_file=None,
                      cors=None,
                      auth_token=None,
                      enable_api=True,
                      jwt_secret=None,
                      jwt_method=None,
                      endpoints=None
                      ):
    if not channel and not credentials_file:
        channel = "cmdline"

    input_channels = create_http_input_channels(channel, credentials_file)

    app = configure_app(input_channels, cors, auth_token, enable_api,
                        jwt_secret, jwt_method, port=port)

    logger.info("Starting Rasa Core server on "
                "{}".format(constants.DEFAULT_SERVER_FORMAT.format(port)))

    app.register_listener(
        partial(load_agent_on_start, core_model, endpoints, nlu_model),
        'before_server_start')
    app.run(host='0.0.0.0', port=port,
            access_log=logger.isEnabledFor(logging.DEBUG))


# noinspection PyUnusedLocal
async def load_agent_on_start(core_model, endpoints, nlu_model, app, loop):
    """Load an agent.

    Used to be scheduled on server start
    (hence the `app` and `loop` arguments)."""
    from rasa.core import broker
    from rasa.core.agent import Agent

    _interpreter = NaturalLanguageInterpreter.create(nlu_model,
                                                     endpoints.nlu)
    _broker = broker.from_endpoint_config(endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(
        None, endpoints.tracker_store, _broker)

    if endpoints and endpoints.model:
        from rasa.core import agent

        app.agent = Agent(interpreter=_interpreter,
                          generator=endpoints.nlg,
                          tracker_store=_tracker_store,
                          action_endpoint=endpoints.action)

        await agent.load_from_server(app.agent,
                                     model_server=endpoints.model)
    else:
        app.agent = Agent.load(core_model,
                               interpreter=_interpreter,
                               generator=endpoints.nlg,
                               tracker_store=_tracker_store,
                               action_endpoint=endpoints.action)

    return app.agent


if __name__ == '__main__':
    # Running as standalone python application
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('engineio').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)
    logging.getLogger('socketio').setLevel(logging.ERROR)

    rasa.utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel,
                                 cmdline_args.log_file)

    _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)

    serve_application(cmdline_args.core,
                      cmdline_args.nlu,
                      cmdline_args.connector,
                      cmdline_args.port,
                      cmdline_args.credentials,
                      cmdline_args.cors,
                      cmdline_args.auth_token,
                      cmdline_args.enable_api,
                      cmdline_args.jwt_secret,
                      cmdline_args.jwt_method,
                      _endpoints)