rasa/core/brokers/kafka.py

Summary

Maintainability: B (4 hrs)
Test Coverage: F (52%)

import asyncio
import json
import logging
import os
import threading
import time
from asyncio import AbstractEventLoop
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Union

import structlog

from rasa.core.brokers.broker import EventBroker
from rasa.core.exceptions import KafkaProducerInitializationError
from rasa.shared.utils.io import DEFAULT_ENCODING
from rasa.utils.endpoints import EndpointConfig
import rasa.shared.core.events
import rasa.shared.utils.common

if TYPE_CHECKING:
    from confluent_kafka import KafkaError, Producer, Message

logger = logging.getLogger(__name__)
structlogger = structlog.get_logger()


class KafkaEventBroker(EventBroker):
    """Kafka event broker."""

    def __init__(
        self,
        url: Union[Text, List[Text], None],
        topic: Text = "rasa_core_events",
        client_id: Optional[Text] = None,
        partition_by_sender: bool = False,
        sasl_username: Optional[Text] = None,
        sasl_password: Optional[Text] = None,
        sasl_mechanism: Optional[Text] = "PLAIN",
        ssl_cafile: Optional[Text] = None,
        ssl_certfile: Optional[Text] = None,
        ssl_keyfile: Optional[Text] = None,
        ssl_check_hostname: bool = False,
        security_protocol: Text = "SASL_PLAINTEXT",
        **kwargs: Any,
    ) -> None:
        """Kafka event broker.

        Args:
            url: 'url[:port]' string (or list of 'url[:port]'
                strings) that the producer should contact to bootstrap initial
                cluster metadata. This does not have to be the full node list.
                It just needs to have at least one broker that will respond to a
                Metadata API Request.
            topic: Topic to which events are published.
            client_id: A name for this client. This string is passed in each request
                to servers and can be used to identify specific server-side log entries
                that correspond to this client. Also submitted to `GroupCoordinator` for
                logging with respect to producer group administration.
            partition_by_sender: Flag to configure whether messages are partitioned
                by `sender_id` or not.
            sasl_username: Username for plain authentication.
            sasl_password: Password for plain authentication.
            sasl_mechanism: Authentication mechanism when security_protocol is
                configured for SASL_PLAINTEXT or SASL_SSL.
                Valid values are: PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256,
                SCRAM-SHA-512. Default: `PLAIN`
            ssl_cafile: Optional filename of CA file to use in certificate
                verification.
            ssl_certfile: Optional filename of file in PEM format containing
                the client certificate, as well as any CA certificates needed to
                establish the certificate's authenticity.
            ssl_keyfile: Optional filename containing the client private key.
            ssl_check_hostname: Flag to configure whether the SSL handshake
                should verify that the certificate matches the broker's hostname.
            security_protocol: Protocol used to communicate with brokers.
                Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL.
        """
        self.producer: Optional["Producer"] = None
        self.url = url
        self.topic = topic
        self.client_id = client_id
        self.partition_by_sender = partition_by_sender
        self.security_protocol = security_protocol.upper()
        self.sasl_username = sasl_username
        self.sasl_password = sasl_password
        self.sasl_mechanism = sasl_mechanism
        self.ssl_cafile = ssl_cafile
        self.ssl_certfile = ssl_certfile
        self.ssl_keyfile = ssl_keyfile
        self.queue_size = kwargs.get("queue_size")
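        # Mapped to librdkafka's `ssl.endpoint.identification.algorithm`, which
        # expects the value "https" to enable broker hostname verification.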
        self.ssl_check_hostname = "https" if ssl_check_hostname else None

        # Async producer implementation followed from confluent-kafka asyncio example:
        # https://github.com/confluentinc/confluent-kafka-python/blob/master/examples/asyncio_example.py#L88  # noqa: E501
        self._loop = asyncio.get_event_loop()
        self._cancelled = False
        self._poll_thread = threading.Thread(target=self._poll_loop)
        self._poll_thread.start()

    @classmethod
    async def from_endpoint_config(
        cls,
        broker_config: EndpointConfig,
        event_loop: Optional[AbstractEventLoop] = None,
    ) -> Optional["KafkaEventBroker"]:
        """Creates broker. See the parent class for more information."""
        if broker_config is None:
            return None

        return cls(broker_config.url, **broker_config.kwargs)
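
    # Example (hypothetical values): this broker is typically created from the
    # `event_broker` section of `endpoints.yml`, which Rasa passes in as an
    # `EndpointConfig`. Roughly:
    #
    #     config = EndpointConfig(url="localhost:9092", topic="rasa_core_events")
    #     broker = await KafkaEventBroker.from_endpoint_config(config)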

    def publish(
        self,
        event: Dict[Text, Any],
        retries: int = 60,
        retry_delay_in_seconds: float = 5,
    ) -> None:
        """Publishes events."""
        from confluent_kafka import KafkaException

        if retries == 1:
            retries = 2

        if self.producer is None:
            self.producer = self._create_producer()
            try:
                self._check_kafka_connection()
                logger.debug("Connection to kafka successful.")
            except KafkaException:
                logger.debug("Failed to connect kafka.")
                return
        while retries:
            try:
                self._publish(event)
                return
            except BufferError as e:
                logger.error(
                    f"Could not publish message to Kafka URL '{self.url}'. "
                    f"Failed with error: {e}"
                )
                self.producer.poll(1)
                retries -= 1
            except Exception as e:
                logger.error(
                    f"Could not publish message to Kafka URL '{self.url}'. "
                    f"Failed with error: {e}"
                )
                try:
                    self._check_kafka_connection()
                except KafkaException:
                    logger.debug("Connection to kafka lost, reconnecting...")
                    self.producer = self._create_producer()
                    try:
                        self._check_kafka_connection()
                        logger.debug("Reconnection to kafka successful")
                        self._publish(event)
                        return
                    except KafkaException:
                        pass
                retries -= 1
                time.sleep(retry_delay_in_seconds)

        logger.error("Failed to publish Kafka event.")

    def _check_kafka_connection(self) -> None:
        """Verifies connection with Kafka.

        Raises:
            KafkaException: if Kafka is disconnected.
        """
        if self.producer is not None:
            self.producer.list_topics(timeout=5)

    def _get_kafka_config(self) -> Dict[Text, Any]:
        config = {
            "client.id": self.client_id,
            "bootstrap.servers": self.url,
            "error_cb": kafka_error_callback,
        }
        if self.queue_size:
            config["queue.buffering.max.messages"] = self.queue_size

        if self.security_protocol == "PLAINTEXT":
            authentication_params: Dict[Text, Any] = {
                "security.protocol": self.security_protocol,
            }
        elif self.security_protocol == "SASL_PLAINTEXT":
            authentication_params = {
                "sasl.username": self.sasl_username,
                "sasl.password": self.sasl_password,
                "sasl.mechanism": self.sasl_mechanism,
                "security.protocol": self.security_protocol,
            }
        elif self.security_protocol == "SSL":
            authentication_params = {
                "ssl.ca.location": self.ssl_cafile,
                "ssl.certificate.location": self.ssl_certfile,
                "ssl.key.location": self.ssl_keyfile,
                "security.protocol": self.security_protocol,
            }
        elif self.security_protocol == "SASL_SSL":
            authentication_params = {
                "sasl.username": self.sasl_username,
                "sasl.password": self.sasl_password,
                "ssl.ca.location": self.ssl_cafile,
                "ssl.certificate.location": self.ssl_certfile,
                "ssl.key.location": self.ssl_keyfile,
                "ssl.endpoint.identification.algorithm": self.ssl_check_hostname,
                "security.protocol": self.security_protocol,
                "sasl.mechanism": self.sasl_mechanism,
            }
        else:
            raise ValueError(
                f"Cannot initialise `KafkaEventBroker`: "
                f"Invalid `security_protocol` ('{self.security_protocol}')."
            )

        return {**config, **authentication_params}
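
    # For illustration (hypothetical values): a SASL_PLAINTEXT setup would yield
    # roughly the following librdkafka configuration:
    #
    #     {
    #         "client.id": "rasa-bot",
    #         "bootstrap.servers": "localhost:9092",
    #         "error_cb": kafka_error_callback,
    #         "sasl.username": "user",
    #         "sasl.password": "secret",
    #         "sasl.mechanism": "PLAIN",
    #         "security.protocol": "SASL_PLAINTEXT",
    #     }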

    def _create_producer(self) -> "Producer":
        import confluent_kafka

        try:
            return confluent_kafka.Producer(self._get_kafka_config())
        except confluent_kafka.KafkaException as e:
            raise KafkaProducerInitializationError(
                f"Cannot initialise `KafkaEventBroker`: {e}"
            )

    def _publish(self, event: Dict[Text, Any]) -> None:
        if self.partition_by_sender:
            partition_key = bytes(event.get("sender_id"), encoding=DEFAULT_ENCODING)
        else:
            partition_key = None

        headers = []
        if self.rasa_environment:
            headers = [
                (
                    "RASA_ENVIRONMENT",
                    bytes(self.rasa_environment, encoding=DEFAULT_ENCODING),
                )
            ]

        reduced_event = rasa.shared.core.events.remove_parse_data(event)
        structlogger.debug(
            "kafka.publish.event",
            event_info="Logging a reduced version of the Kafka event",
            topic=self.topic,
            rasa_event=reduced_event,
            partition_key=partition_key,
            headers=headers,
        )

        serialized_event = json.dumps(event).encode(DEFAULT_ENCODING)

        if self.producer is not None:
            self.producer.produce(
                self.topic,
                value=serialized_event,
                key=partition_key,
                headers=headers,
                on_delivery=delivery_report,
            )

    def _close(self) -> None:
        self._cancelled = True
        self._poll_thread.join()

    @rasa.shared.utils.common.lazy_property
    def rasa_environment(self) -> Optional[Text]:
        """Get value of the `RASA_ENVIRONMENT` environment variable."""
        return os.environ.get("RASA_ENVIRONMENT", "RASA_ENVIRONMENT_NOT_SET")

    def _poll_loop(self) -> None:
        """Polls the producer for events.

        Required to trigger the `on_delivery` callback passed to the produce method.
        """
        while not self._cancelled:
            if self.producer is not None:
                self.producer.poll(0.1)


def kafka_error_callback(err: "KafkaError") -> None:
    """Callback for Kafka errors.

    Any exception raised from this callback will be re-raised from the
    triggering flush() call.
    """
    from confluent_kafka import KafkaException, KafkaError

    # handle authentication / connection related issues, likely pointing
    # to a configuration error
    if (
        err.code() == KafkaError._ALL_BROKERS_DOWN
        or err.code() == KafkaError._AUTHENTICATION
        or err.code() == KafkaError._MAX_POLL_EXCEEDED
    ):
        raise KafkaException(err)
    else:
        logger.warning("A KafkaError has been raised.", exc_info=True)


def delivery_report(err: Exception, msg: "Message") -> None:
    """Reports the failure or success of a message delivery.

    Args:
        err (KafkaError): The error that occurred, or None on success.
        msg (Message): The message that was produced or failed.
    """
    if err is not None:
        logger.error(f"Delivery failed for User record {msg.key()}: {err}")
        return

    logger.info(
        f"User record {msg.key()} successfully produced to "
        f"{msg.topic()} [{msg.partition()}] at offset {msg.offset()}."
    )