# navikt/dataverk
#
# View on GitHub
# src/dataverk/connectors/kafka.py
#
# Summary
#
# Maintainability
# A
# 0 mins
# Test Coverage
import math

import pandas as pd
import json
import struct
import requests
import avro
import avro.schema
import avro.io
import io
import time

from dataverk.utils import mapping_util
from kafka import KafkaConsumer
from collections.abc import Mapping, Sequence
from enum import Enum
from datetime import datetime
from dataverk.abc.base import DataverkBase
import streamz


class KafkaFetchMode(Enum):
    """ How a Kafka consumer chooses where to start reading

    FROM_BEGINNING: read without a consumer group.
    LAST_COMMITED_OFFSET: read with the configured consumer group.

    The member value keeps its historical spelling ("last_commited_offset"). The correctly spelled
    "last_committed_offset" is also accepted as a lookup value and maps to the same member.
    """
    FROM_BEGINNING = "from_beginning"
    LAST_COMMITED_OFFSET = "last_commited_offset"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum only when ``value`` matches no member value; accept the documented spelling too
        if value == "last_committed_offset":
            return cls.LAST_COMMITED_OFFSET
        # Returning None makes Enum raise its usual ValueError
        return None


class KafkaConnector(DataverkBase):
    """ Dataverk Kafka consumer reading Avro (via a schema registry) or JSON messages into pandas

    A read stops at the first of these:
    - the first message whose timestamp is at or after the moment this connector was created,
    - ``max_mesgs`` messages have been read,
    - iterating the underlying consumer yields no more messages.
    """

    # Avro payloads are prefixed by one leading byte followed by a 4-byte big-endian schema id
    _SCHEMA_HEADER_SIZE = 5
    # Leading ("magic") byte of the Confluent schema-registry wire format; JSON text never starts with it
    _SCHEMA_MAGIC_BYTE = 0
    # Seconds to wait for the schema registry; requests waits indefinitely when no timeout is given
    _SCHEMA_REGISTRY_TIMEOUT_S = 30

    def __init__(self, consumer: KafkaConsumer, settings: Mapping, topics: Sequence, fetch_mode: str):
        """ Dataverk Kafka consumer class

        :param consumer: KafkaConsumer to read messages from (closed by get_pandas_df)
        :param settings: Mapping object containing project settings; ``kafka.schema_registry`` is read here
        :param topics: Sequence of topics to subscribe to (used for logging)
        :param fetch_mode: str describing fetch mode (from_beginning, last_commited_offset)
        """
        super().__init__()
        self._topics = topics
        self._fetch_mode = fetch_mode
        # schema id -> registry response, so each schema id is fetched at most once
        self._schema_cache = {}
        self._consumer = consumer
        self.log.info(f"KafkaConsumer created with fetch mode set to '{fetch_mode}'")
        # Messages timestamped at/after this instant end the read (see _is_requested_messages_read)
        self._read_until_timestamp = self._get_current_timestamp_in_ms()
        self._schema_registry_url = mapping_util.safe_get_nested(settings,
                                                                 keys=("kafka", "schema_registry"),
                                                                 default="http://localhost:8081")

    def get_pandas_df(self, strategy=None, fields=None, max_mesgs=math.inf):
        """ Read kafka topics, commit offsets and return the result as a pandas dataframe

        :param strategy: optional reducer ``strategy(state, mesg) -> state``. When given, all messages are
            folded into one state starting from ``{}``, and ``fields`` is ignored.
        :param fields: optional sequence of message fields to keep (used only when ``strategy`` is None).
            Each row then holds the values in ``fields`` order, and the columns are positional (0..n-1).
        :param max_mesgs: maximum number of messages to read
        :return: pd.DataFrame containing kafka messages read. NB! Commits offset and closes the consumer
        """
        try:
            if strategy is None:
                records = self._read_kafka_raw(max_mesgs, fields)
            else:
                records = self._read_kafka_accumulated(max_mesgs, strategy)
            try:
                df = pd.DataFrame.from_records(records)
            except ValueError:
                # presumably a mapping of scalars (e.g. an accumulated dict), which pandas only accepts with an index
                df = pd.DataFrame.from_records(records, index=[0])
            self._commit_offsets()
        finally:
            # Close even when reading/decoding fails so the consumer is not leaked
            self._consumer.close()

        return df

    def get_message_fields(self):
        """ Read a single kafka message from the topic(s) and return its fields

        Uses the same Avro/JSON decoding as the other read methods; an undecodable JSON value has no fields.

        :return: list: message fields (empty when no message was available)
        """
        for message in self._consumer:
            return list(self._parse_kafka_message(message).keys())
        return []

    def _read_kafka_raw(self, max_mesgs, fields):
        """ Read messages one by one, keeping the decoded message or only the requested ``fields`` of it

        :return: list with one entry per message read
        """
        start_time = time.time()

        self.log.info(f"Reading kafka topic {self._topics}. Fetch mode {self._fetch_mode}")

        data = []

        for message in self._consumer:
            mesg = self._parse_kafka_message(message)
            data.append(self._extract_requested_fields(mesg, fields))
            if self._is_requested_messages_read(message, max_mesgs, len(data)):
                break

        self.log.info(f"{len(data)} messages read from kafka topic(s) {self._topics} "
                      f"in {time.time() - start_time} sec. Fetch mode {self._fetch_mode}")

        return data

    def _read_kafka_accumulated(self, max_mesgs, strategy):
        """ Fold decoded messages into a single state via ``state = strategy(state, mesg)``

        :return: the final state returned by ``strategy`` (``{}`` when no messages were read)
        """
        # Return the state produced by the last strategy call, not the initial dict. A strategy that
        # builds a new object instead of mutating its input would otherwise always yield {}.
        state = {}
        mesg_count = 0

        for message in self._consumer:
            mesg = self._parse_kafka_message(message)
            mesg_count += 1
            state = strategy(state, mesg)
            if self._is_requested_messages_read(message, max_mesgs, mesg_count):
                break
        return state

    def _parse_kafka_message(self, message):
        """ Decode a message value as Avro when it carries a schema-registry header, otherwise as JSON

        :return: the decoded message; ``{}`` when the value is None or cannot be decoded as UTF-8 JSON
        """
        if self._has_schema_registry_header(message):
            try:
                schema = self._get_schema_from_registry(message=message).json()["schema"]
            except (AttributeError, KeyError, TypeError, ValueError):
                # Unknown schema id or unexpected registry reply: fall back to JSON decoding
                pass
            else:
                return self._decode_avro_message(schema=schema, message=message)
        return self._parse_json_message(message)

    @classmethod
    def _has_schema_registry_header(cls, message):
        """ True when ``message.value`` is bytes that starts with the schema-registry magic byte and
        is long enough to hold a schema id """
        value = message.value
        return (isinstance(value, (bytes, bytearray))
                and len(value) >= cls._SCHEMA_HEADER_SIZE
                and value[0] == cls._SCHEMA_MAGIC_BYTE)

    @staticmethod
    def _parse_json_message(message):
        """ Decode ``message.value`` as UTF-8 JSON, returning ``{}`` for None or undecodable values """
        if message.value is None:
            return {}
        try:
            return json.loads(message.value.decode('utf8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return {}

    def _get_schema_from_registry(self, message):
        """ Fetch the registry response for the schema id in bytes 1-4 of ``message.value``

        Responses, including error responses, are cached per schema id.

        :return: the ``requests`` response for ``<schema_registry>/schemas/ids/<id>``
        """
        schema_id = struct.unpack(">L", message.value[1:self._SCHEMA_HEADER_SIZE])[0]
        if schema_id not in self._schema_cache:
            self._schema_cache[schema_id] = requests.get(
                self._schema_registry_url + '/schemas/ids/' + str(schema_id),
                timeout=self._SCHEMA_REGISTRY_TIMEOUT_S)
        return self._schema_cache[schema_id]

    @staticmethod
    def _decode_avro_message(schema, message):
        """ Decode the Avro binary payload that follows the schema-registry header

        :param schema: Avro schema as a JSON string (parsed on every call)
        :return: the decoded datum
        """
        schema = avro.schema.Parse(schema)
        bytes_reader = io.BytesIO(message.value[KafkaConnector._SCHEMA_HEADER_SIZE:])
        decoder = avro.io.BinaryDecoder(bytes_reader)
        reader = avro.io.DatumReader(schema)
        return reader.read(decoder)

    @staticmethod
    def _get_current_timestamp_in_ms():
        return int(datetime.now().timestamp() * 1000)

    def _is_requested_messages_read(self, message, max_mesgs, mesgs_read):
        """ True once a message at/after the connector's creation time was read, or ``max_mesgs`` is reached

        The triggering message has already been collected by the caller.
        NOTE(review): assumes message.timestamp is in milliseconds like _read_until_timestamp.
        """
        return message.timestamp >= self._read_until_timestamp or mesgs_read >= max_mesgs

    def _commit_offsets(self):
        """ Commits the offsets to kafka when the KafkaConsumer object is configured with group_id
        """
        if self._consumer.config["group_id"] is not None:
            self._consumer.commit()

    def _extract_requested_fields(self, mesg, fields):
        """ Return ``mesg`` unchanged, or a list of its ``fields`` values (None for missing ones) """
        if fields is not None:
            return [self._extract_field(mesg, field) for field in fields]
        else:
            return mesg

    @staticmethod
    def _extract_field(mesg, field):
        """ Return ``mesg[field]``, or None when it is missing or ``mesg`` is not indexable by ``field`` """
        try:
            return mesg[field]
        except (KeyError, IndexError, TypeError):
            # TypeError/IndexError: JSON values need not be objects (lists, strings, numbers)
            return None


def get_kafka_consumer(settings: Mapping, topics: Sequence, fetch_mode: str) -> KafkaConsumer:
    """ Factory method returning a KafkaConsumer object with desired configuration

    All connection settings are read from the ``kafka`` section of ``settings``. Missing entries fall back
    to the defaults passed below. Auto-commit is always disabled.

    :param settings: Mapping object containing project settings
    :param topics: Sequence of topics to subscribe
    :param fetch_mode: str describing fetch mode ('from_beginning' or 'last_commited_offset')
    :return: KafkaConsumer object with desired configuration
    :raises ValueError: if fetch_mode is not a valid KafkaFetchMode value
    """
    def kafka_setting(name, default=None):
        # Every consumer option lives under the "kafka" key of the project settings
        return mapping_util.safe_get_nested(settings, keys=("kafka", name), default=default)

    try:
        mode = KafkaFetchMode(fetch_mode)
    except ValueError:
        # KafkaFetchMode(...) itself raises on unknown values; re-raise with the list of accepted values
        valid_modes = " and ".join(f"'{valid_mode.value}'" for valid_mode in KafkaFetchMode)
        raise ValueError(f"{fetch_mode} is not a valid KafkaFetchMode. Valid fetch_modes are: {valid_modes}") from None

    # FROM_BEGINNING reads without a consumer group; LAST_COMMITED_OFFSET uses the configured one (if any)
    group_id = kafka_setting("group_id") if mode is KafkaFetchMode.LAST_COMMITED_OFFSET else None

    return KafkaConsumer(*topics,
                         group_id=group_id,
                         bootstrap_servers=kafka_setting("brokers", "localhost:9092"),
                         security_protocol=kafka_setting("security_protocol", "PLAINTEXT"),
                         sasl_mechanism=kafka_setting("sasl_mechanism"),
                         sasl_plain_username=kafka_setting("sasl_plain_username"),
                         sasl_plain_password=kafka_setting("sasl_plain_password"),
                         ssl_cafile=kafka_setting("ssl_cafile"),
                         auto_offset_reset=kafka_setting("auto_offset_reset", 'earliest'),
                         enable_auto_commit=False,
                         consumer_timeout_ms=kafka_setting("consumer_timeout_ms", 1000),
                         heartbeat_interval_ms=kafka_setting("heartbeat_interval_ms", 3000),
                         session_timeout_ms=kafka_setting("session_timeout_ms", 10000),
                         max_poll_interval_ms=kafka_setting("max_poll_interval_ms", 300000))