RasaHQ/rasa_core

View on GitHub
rasa/core/broker.py

Summary

Maintainability
A
2 hrs
Test Coverage
import json
import logging
from typing import Any, Dict, Optional, Text

from rasa.core.utils import EndpointConfig, class_from_module_path

logger = logging.getLogger(__name__)


def from_endpoint_config(broker_config: Optional[EndpointConfig]
                         ) -> Optional['EventChannel']:
    """Instantiate an event channel based on its configuration."""

    if broker_config is None:
        return None
    elif broker_config.type == 'pika' or broker_config.type is None:
        return PikaProducer.from_endpoint_config(broker_config)
    elif broker_config.type == 'file':
        return FileProducer.from_endpoint_config(broker_config)
    elif broker_config.type == 'kafka':
        return FileProducer.from_endpoint_config(broker_config)
    else:
        return load_event_channel_from_module_string(broker_config)


def load_event_channel_from_module_string(broker_config: EndpointConfig
                                          ) -> Optional['EventChannel']:
    """Instantiate an event channel based on its class name."""

    try:
        event_channel = class_from_module_path(broker_config.type)
        return event_channel.from_endpoint_config(broker_config)
    except (AttributeError, ImportError) as e:
        logger.warning("EventChannel type '{}' not found. "
                       "Not using any event channel. Error: {}"
                       .format(broker_config.type, e))
        return None


class EventChannel(object):
    @classmethod
    def from_endpoint_config(cls, broker_config: EndpointConfig
                             ) -> 'EventChannel':
        raise NotImplementedError("Event broker must implement the "
                                  "`from_endpoint_config` method.")

    def publish(self, event: Dict[Text, Any]) -> None:
        """Publishes a json-formatted Rasa Core event into an event queue."""

        raise NotImplementedError("Event broker must implement the `publish` "
                                  "method.")


class PikaProducer(EventChannel):
    def __init__(self, host, username, password,
                 queue='rasa_core_events',
                 loglevel=logging.INFO):
        import pika

        logging.getLogger('pika').setLevel(loglevel)

        self.queue = queue
        self.host = host
        self.credentials = pika.PlainCredentials(username, password)

    @classmethod
    def from_endpoint_config(cls, broker_config: Optional['EndpointConfig']
                             ) -> Optional['PikaProducer']:
        if broker_config is None:
            return None

        return cls(broker_config.url, **broker_config.kwargs)

    def publish(self, event):
        self._open_connection()
        self._publish(json.dumps(event))
        self._close()

    def _open_connection(self):
        import pika

        parameters = pika.ConnectionParameters(self.host,
                                               credentials=self.credentials,
                                               connection_attempts=20,
                                               retry_delay=5)
        self.connection = pika.BlockingConnection(parameters)
        self.channel = self.connection.channel()
        self.channel.queue_declare(self.queue, durable=True)

    def _publish(self, body):
        self.channel.basic_publish('', self.queue, body)
        logger.debug('Published pika events to queue {} at '
                     '{}:\n{}'.format(self.queue, self.host, body))

    def _close(self):
        self.connection.close()


class FileProducer(EventChannel):
    """Log events to a file in json format.

    There will be one event per line and each event is stored as json."""

    DEFAULT_LOG_FILE_NAME = "rasa_event.log"

    def __init__(self, path: Optional[Text] = None) -> None:
        self.path = path or self.DEFAULT_LOG_FILE_NAME
        self.event_logger = self._event_logger()

    @classmethod
    def from_endpoint_config(cls, broker_config: Optional['EndpointConfig']
                             ) -> Optional['FileProducer']:
        if broker_config is None:
            return None

        # noinspection PyArgumentList
        return cls(**broker_config.kwargs)

    def _event_logger(self):
        """Instantiate the file logger."""

        logger_file = self.path
        # noinspection PyTypeChecker
        query_logger = logging.getLogger('event-logger')
        query_logger.setLevel(logging.INFO)
        handler = logging.FileHandler(logger_file)
        handler.setFormatter(logging.Formatter('%(message)s'))
        query_logger.propagate = False
        query_logger.addHandler(handler)

        logger.info("Logging events to '{}'.".format(logger_file))

        return query_logger

    def publish(self, event: Dict) -> None:
        """Write event to file."""

        self.event_logger.info(json.dumps(event))
        self.event_logger.handlers[0].flush()


class KafkaProducer(EventChannel):
    def __init__(self, host,
                 sasl_username=None,
                 sasl_password=None,
                 ssl_cafile=None,
                 ssl_certfile=None,
                 ssl_keyfile=None,
                 ssl_check_hostname=False,
                 topic='rasa_core_events',
                 security_protocol='SASL_PLAINTEXT',
                 loglevel=logging.ERROR):

        self.host = host
        self.topic = topic
        self.security_protocol = security_protocol
        self.sasl_username = sasl_username
        self.sasl_password = sasl_password
        self.ssl_cafile = ssl_cafile
        self.ssl_certfile = ssl_certfile
        self.ssl_keyfile = ssl_keyfile
        self.ssl_check_hostname = ssl_check_hostname

        logging.getLogger('kafka').setLevel(loglevel)

    @classmethod
    def from_endpoint_config(cls, broker_config) -> Optional['KafkaProducer']:
        if broker_config is None:
            return None

        return cls(broker_config.url, **broker_config.kwargs)

    def publish(self, event):
        self._create_producer()
        self._publish(event)
        self._close()

    def _create_producer(self):
        import kafka

        if self.security_protocol == 'SASL_PLAINTEXT':
            self.producer = kafka.KafkaProducer(
                bootstrap_servers=[self.host],
                value_serializer=lambda v: json.dumps(v).encode('utf-8'),
                sasl_plain_username=self.sasl_username,
                sasl_plain_password=self.sasl_password,
                sasl_mechanism='PLAIN',
                security_protocol=self.security_protocol)
        elif self.security_protocol == 'SSL':
            self.producer = kafka.KafkaProducer(
                bootstrap_servers=[self.host],
                value_serializer=lambda v: json.dumps(v).encode('utf-8'),
                ssl_cafile=self.ssl_cafile,
                ssl_certfile=self.ssl_certfile,
                ssl_keyfile=self.ssl_keyfile,
                ssl_check_hostname=False,
                security_protocol=self.security_protocol)

    def _publish(self, event):
        self.producer.send(self.topic, event)

    def _close(self):
        self.producer.close()