juntossomosmais/django-stomp

View on GitHub
django_stomp/services/producer.py

Summary

Maintainability
A
35 mins
Test Coverage
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Dict
from typing import Optional

from django.core.serializers.json import DjangoJSONEncoder
from request_id_django_log.request_id import current_request_id
from request_id_django_log.settings import NO_REQUEST_ID
from stomp import Connection
from stomp.connect import StompConnection11

from django_stomp.helpers import clean_dict_with_falsy_or_strange_values
from django_stomp.helpers import create_dlq_destination_from_another_destination
from django_stomp.helpers import retry
from django_stomp.helpers import set_ssl_connection
from django_stomp.helpers import slow_down
from django_stomp.settings import DEFAULT_STOMP_SSL_VERSION
from django_stomp.settings import STOMP_USE_SSL

# Logger used by the publisher code in this module; it is registered under the name "django_stomp".
logger = logging.getLogger("django_stomp")


class Publisher:
    """
    Class used to publish messages to brokers using the STOMP protocol. Some headers are removed
    if they are in send() method as they cause unexpected behavior/errors.

    Such headers are defined in the UNSAFE_OR_RESERVED_BROKER_HEADERS_FOR_REMOVAL class variable which is used
    for sanitizing the user-supplied headers.

    While a transaction opened by do_inside_transaction() is active, its id is stored in the temporary
    ``_tmp_transaction_id`` instance attribute. The attribute only exists for the duration of the transaction
    and its mere presence is what switches send() to the no-retry path.
    """

    UNSAFE_OR_RESERVED_BROKER_HEADERS_FOR_REMOVAL = [
        # RabbitMQ unsafe headers
        "message-id",
        "transaction",
        "redelivered",
        "subscription",
        # Unsafe in a way that only django-stomp/broker should create
        "destination",
        "content-length",
        "content-type",
    ]

    def __init__(self, connection: StompConnection11, connection_configuration: Dict) -> None:
        """
        Stores the connection and the keyword arguments that start() later passes to ``connection.connect()``.
        """
        self._connection_configuration = connection_configuration
        self.connection = connection
        self._default_content_type = "application/json;charset=utf-8"

    def is_open(self):
        """
        Returns what the underlying connection reports from ``is_connected()``.
        """
        return self.connection.is_connected()

    @slow_down
    def start(self):
        """
        Connects to the broker using the connection configuration given at construction time.
        """
        self.connection.connect(**self._connection_configuration)
        logger.debug("Connected")

    def close(self):
        """
        Disconnects from the broker, passing a freshly generated UUID string as the disconnect receipt.
        """
        disconnect_receipt = str(uuid.uuid4())
        self.connection.disconnect(receipt=disconnect_receipt)
        logger.debug("Disconnected")

    def start_if_not_open(self):
        """
        Calls start() only when is_open() reports that the connection is not open.
        """
        if not self.is_open():
            logger.info("It is not open. Starting...")
            self.start()

    def send(self, body: dict, queue: str, headers=None, persistent=True, attempt=10):
        """
        Builds the final message headers/body and sends to the broker with the STOMP protocol. Attempt retries
        are ignored if the publisher is currently being used in a transaction in order to avoid already closed
        transactions errors due to the STOMP protocol behavior.

        :param body: payload, serialized to JSON with DjangoJSONEncoder.
        :param queue: destination the message is sent to.
        :param headers: optional user headers; the given dict is not modified.
        :param persistent: when truthy, the "persistent" header is set to "true".
        :param attempt: number of attempts handed to the retry helper (unused inside a transaction).
        """
        final_headers = self._build_final_headers(queue, headers, persistent)
        send_data = self._build_send_data(queue, body, final_headers)

        if self._is_publisher_in_transaction():
            self._send_to_broker_without_retry_attempts(send_data)
        else:
            self._send_to_broker(send_data, how_many_attempts=attempt)

    def _build_final_headers(self, queue: str, headers: Optional[Dict], persistent: bool) -> Dict:
        """
        Builds the message final headers. Removes unsafe or broker-reserved headers.
        Standard headers values override headers values to reduce possible errors.

        A new dict is always returned; the ``headers`` argument itself is never mutated.
        """
        if headers is None:
            headers = {}

        standard_headers = {
            "correlation-id": self._get_correlation_id(headers),
            "tshoot-destination": queue,
            # RabbitMQ
            # These two parameters must be set on consumer side as well, otherwise you'll get precondition_failed
            "x-dead-letter-routing-key": create_dlq_destination_from_another_destination(queue),
            "x-dead-letter-exchange": "",
        }

        # safety: standard_headers must override headers values, so order MATTERS here!
        mixed_headers = {**headers, **standard_headers}

        if persistent:
            # Use the returned dict instead of relying on in-place mutation only.
            mixed_headers = self._add_persistent_messaging_header(mixed_headers)

        return self._remove_unsafe_or_reserved_for_broker_use_headers(mixed_headers)

    def _get_correlation_id(self, headers: Optional[Dict]) -> str:
        """
        Gets the correlation id for the message. If 'correlation-id' is in the headers, this value is used.
        Otherwise, the value of current_request_id() is returned or a new one is generated as a last resort.

        The generated fallback is returned as a string (not as a ``uuid.UUID`` instance) so the result
        matches the declared return type.
        """
        if headers and "correlation-id" in headers:
            return headers["correlation-id"]

        # Read the request id once so the value that is tested is the value that is returned.
        request_id = current_request_id()
        if request_id != NO_REQUEST_ID:
            return request_id

        return str(uuid.uuid4())

    def _remove_unsafe_or_reserved_for_broker_use_headers(self, headers: Dict) -> Dict:
        """
        Removes headers that are used internally by the brokers or that might cause unexpected behaviors (unsafe).

        Returns a new dict; header names are compared exactly as written (case-sensitive).
        """
        return {
            key: value
            for key, value in headers.items()
            if key not in self.UNSAFE_OR_RESERVED_BROKER_HEADERS_FOR_REMOVAL
        }

    def _build_send_data(self, queue: str, body: Dict, headers: Dict) -> Dict:
        """
        Builds the final data shape required to send messages using the STOMP protocol.

        The result is used as keyword arguments for ``connection.send()``.
        """
        send_data = {
            "destination": queue,
            "body": json.dumps(body, cls=DjangoJSONEncoder),
            "headers": headers,
            "content_type": self._default_content_type,
            # None when no transaction is active.
            # NOTE(review): presumably the helper below drops that None entry -- confirm in django_stomp.helpers.
            "transaction": getattr(self, "_tmp_transaction_id", None),
        }

        return clean_dict_with_falsy_or_strange_values(send_data)

    def _send_to_broker(self, send_data: Dict, how_many_attempts: int) -> None:
        """
        Sends the actual data to the broker using the STOMP protocol with some retry attempts if
        a connection problem occurs.

        Each attempt first makes sure the connection is open and then sends; the attempts themselves are
        driven by the ``retry`` helper, which receives ``how_many_attempts`` as its ``attempt`` argument.
        """

        def _internal_send_logic():
            self.start_if_not_open()
            self.connection.send(**send_data)

        retry(_internal_send_logic, attempt=how_many_attempts)

    def _send_to_broker_without_retry_attempts(self, send_data: Dict) -> None:
        """
        Sends the actual data to the broker using the STOMP protocol WITHOUT any retry attempts as reconnecting
        to the broker while a transaction was previously created will lead to 'bad transaction' errors because STOMP 1.1
        protocol closes any transactions if the producer had TCP connection problems or sends a DISCONNECT frame.

        Hence, when a producer sends a BEGIN frame, all subsequent SEND frames (messages) must always use the SAME
        connection that was used to start the transaction.

        -> STOMP 1.1 specification: https://stomp.github.io/stomp-specification-1.1.html#BEGIN
        """
        self.connection.send(**send_data)  # bare sending without retries

    def _is_publisher_in_transaction(self) -> bool:
        """
        Checks if the publisher is currently being used in a transaction.
        """
        return hasattr(self, "_tmp_transaction_id")  # attribute set by do_inside_transaction contextmanager

    @staticmethod
    def _add_persistent_messaging_header(headers: Dict) -> Dict:
        """
        Marks a message as persistent by setting the "persistent" header to "true".

        The given dict is updated in place and returned. An empty dict is updated as well, so a caller that
        ignores the return value still ends up with the header. ``None`` yields a brand new dict.
        """
        if headers is None:
            return {"persistent": "true"}

        headers["persistent"] = "true"
        return headers

    @contextmanager
    def auto_open_close_connection(self):
        """
        Context manager that opens the connection if needed, yields this publisher and, on exit,
        closes the connection when it is still open (also when the body raised).
        """
        try:
            self.start_if_not_open()
            yield self
        finally:
            if self.is_open():
                self.close()

    @contextmanager
    def do_inside_transaction(self):
        """
        Context manager that runs its body inside a broker transaction and yields this publisher.

        The transaction is committed when the body finishes normally. On any error (including errors raised
        while connecting, beginning or committing) the error is logged, the transaction is aborted if one was
        registered on this publisher, and the error is re-raised. The temporary transaction id attribute is
        always removed afterwards.
        """
        try:
            self.start_if_not_open()
            transaction_id = self.connection.begin()
            logger.debug("Created transaction ID: %s", transaction_id)
            # From here on send() attaches the transaction id and skips retries.
            self._tmp_transaction_id = transaction_id
            yield self
            self.connection.commit(transaction_id)
        except BaseException:
            logger.exception("Error inside transaction")
            # The attribute is missing when the failure happened before the transaction was registered.
            if hasattr(self, "_tmp_transaction_id"):
                self.connection.abort(self._tmp_transaction_id)
            raise
        finally:
            if hasattr(self, "_tmp_transaction_id"):
                del self._tmp_transaction_id


def build_publisher(**connection_params) -> Publisher:
    """
    Builds a Publisher from loose connection parameters.

    Keys read from ``connection_params`` (all optional from this function's point of view):
    ``host``, ``port``, ``hostStandby``, ``portStandby``, ``vhost``, ``client_id``, ``username``, ``password``.

    The standby host is only added when both ``hostStandby`` and ``portStandby`` are truthy. When ``client_id``
    is missing or ``None`` a random UUID is used, so publishers never share the literal client id
    "None-publisher". SSL is configured on the connection when STOMP_USE_SSL is truthy.

    The returned publisher is not connected yet.
    """
    hosts = [(connection_params.get("host"), connection_params.get("port"))]
    vhost = connection_params.get("vhost")

    standby_host = connection_params.get("hostStandby")
    standby_port = connection_params.get("portStandby")
    if standby_host and standby_port:
        hosts.append((standby_host, standby_port))

    use_ssl = STOMP_USE_SSL
    logger.debug(f"Use SSL? {use_ssl}. Version: {DEFAULT_STOMP_SSL_VERSION}")

    # Generate the fallback id only when it is needed, and treat an explicit None like a missing key.
    client_id = connection_params.get("client_id")
    if client_id is None:
        client_id = uuid.uuid4()

    connection_configuration = {
        "username": connection_params.get("username"),
        "passcode": connection_params.get("password"),
        "wait": True,
        "headers": {"client-id": f"{client_id}-publisher"},
    }
    conn = Connection(hosts, vhost=vhost)

    if use_ssl:
        conn = set_ssl_connection(conn)

    return Publisher(conn, connection_configuration)


@contextmanager
def auto_open_close_connection(publisher: Publisher):
    """
    Context manager that makes sure ``publisher`` is connected while the body runs and closes it afterwards.

    The connection is started only when it is not already open (same behaviour as
    ``Publisher.auto_open_close_connection``), so an already connected publisher is not connected a second
    time. On exit, also when the body raised, the publisher is closed if it is still open.

    Nothing is yielded.
    """
    try:
        publisher.start_if_not_open()
        yield
    finally:
        if publisher.is_open():
            publisher.close()


@contextmanager
def do_inside_transaction(publisher: Publisher):
    """
    Context manager that runs its body inside a broker transaction opened on ``publisher``.

    The connection is started if needed, a transaction is begun and its id is stored on the publisher as
    ``_tmp_transaction_id`` while the body runs. A normal exit commits the transaction. Any error is logged,
    the registered transaction (if any) is aborted and the error is re-raised. The temporary attribute is
    always removed at the end. Nothing is yielded.
    """
    try:
        publisher.start_if_not_open()
        tx_id = publisher.connection.begin()
        logger.debug("Created transaction ID: %s", tx_id)
        publisher._tmp_transaction_id = tx_id
        yield
        publisher.connection.commit(tx_id)
    except BaseException as error:
        logger.exception("Error inside transaction")
        # No attribute means the failure happened before the transaction was registered.
        transaction_registered = hasattr(publisher, "_tmp_transaction_id")
        if transaction_registered:
            publisher.connection.abort(publisher._tmp_transaction_id)
        raise error
    finally:
        if hasattr(publisher, "_tmp_transaction_id"):
            del publisher._tmp_transaction_id