whylabs/whylogs-python

View on GitHub
python/whylogs/core/view/segmented_dataset_profile_view.py

Summary

Maintainability
A
1 hr
Test Coverage
import tempfile
from datetime import datetime
from logging import getLogger
from typing import IO, Any, Dict, Optional, Tuple

from whylogs.api.writer.writer import Writable
from whylogs.core.proto import (
    ChunkHeader,
    ChunkMessage,
    ChunkOffsets,
    DatasetProfileHeader,
    DatasetProperties,
    DatasetSegmentHeader,
    MetricComponentMessage,
)
from whylogs.core.proto import Segment as SegmentMessage
from whylogs.core.segment import Segment
from whylogs.core.segmentation_partition import SegmentationPartition
from whylogs.core.utils import write_delimited_protobuf
from whylogs.core.utils.timestamp_calculations import to_utc_milliseconds
from whylogs.core.view.dataset_profile_view import (
    WHYLOGS_MAGIC_HEADER_BYTES,
    DatasetProfileView,
)
from whylogs.migration.converters import (
    _generate_segment_tags_metadata,
    v1_to_dataset_profile_message_v0,
)

logger = getLogger(__name__)

_BUFFER_CHUNK = 1024


class SegmentedDatasetProfileView(Writable):
    _profile_view: DatasetProfileView
    _segment: Segment
    _partition: SegmentationPartition

    def __init__(
        self,
        *,
        profile_view: DatasetProfileView,
        segment: Segment,
        partition: SegmentationPartition,
    ):
        self._profile_view = profile_view
        self._segment = segment
        self._partition = partition

    @property
    def segment(self) -> Segment:
        return self._segment

    @property
    def partition(self) -> SegmentationPartition:
        return self._partition

    @property
    def profile_view(self) -> DatasetProfileView:
        return self._profile_view

    @property
    def dataset_timestamp(self) -> Optional[datetime]:
        return self.profile_view.dataset_timestamp

    @property
    def creation_timestamp(self) -> Optional[datetime]:
        return self.profile_view.creation_timestamp

    @property
    def model_performance_metrics(self) -> Any:
        return self.profile_view.model_performance_metrics

    @property
    def metadata(self) -> Dict[str, str]:
        return self.profile_view.metadata

    def get_default_path(self) -> str:
        return f"profile_{self._profile_view.creation_timestamp}_{self.get_segment_string()}.bin"

    def get_segment_string(self) -> str:
        return f"{self._segment.parent_id}_{'_'.join(self._segment.key)}"

    def _write_as_v0_message(self, path: Optional[str] = None, **kwargs: Any) -> Tuple[bool, str]:
        message_v0 = v1_to_dataset_profile_message_v0(self.profile_view, self.segment, self.partition)
        file_to_write = kwargs.get("file")
        path = file_to_write.name if file_to_write else path or self.get_default_path()
        if file_to_write:
            write_delimited_protobuf(file_to_write, message_v0)
        else:
            with open(path, "w+b") as out_f:
                write_delimited_protobuf(out_f, message_v0)
        return True, path

    def _copy_write(
        self,
        source_profile_file: IO[bytes],
        output_file: IO[bytes],
        total_len: int,
        dataset_segment_header: DatasetSegmentHeader,
        dataset_header: DatasetProfileHeader,
    ):
        output_file.write(WHYLOGS_MAGIC_HEADER_BYTES)
        write_delimited_protobuf(output_file, dataset_segment_header)
        logger.debug(
            f"Writing segmented profile file: whylogs file and segment headers wrote {output_file.tell()} bytes"
        )
        write_delimited_protobuf(output_file, dataset_header)
        logger.debug(
            f"Writing segmented profile file: and with dataset header wrote a total of {output_file.tell()} bytes before writing the chunks."
        )

        source_profile_file.seek(0)
        while source_profile_file.tell() < total_len:
            buffer = source_profile_file.read(_BUFFER_CHUNK)
            output_file.write(buffer)
        logger.debug(f"Writing segmented profile file: complete! total of {output_file.tell()} bytes written.")

    def _write_v1(self, path: Optional[str] = None, **kwargs: Any) -> Tuple[bool, str]:
        all_metric_component_names = set()
        file_to_write = kwargs.get("file")
        path = file_to_write.name if file_to_write else path or self.get_default_path()

        # capture the list of all metric component paths
        for col in self.profile_view._columns.values():
            all_metric_component_names.update(col.get_metric_component_paths())
        metric_name_list = list(all_metric_component_names)
        metric_name_list.sort()
        metric_name_indices: Dict[str, int] = {}
        metric_index_to_name: Dict[int, str] = {}
        for i in range(0, len(metric_name_list)):
            metric_name_indices[metric_name_list[i]] = i
            metric_index_to_name[i] = metric_name_list[i]

        column_chunk_offsets: Dict[str, ChunkOffsets] = {}

        with tempfile.TemporaryFile("w+b") as f:
            for col_name in sorted(self.profile_view._columns.keys()):
                column_chunk_offsets[col_name] = ChunkOffsets(offsets=[f.tell()])

                col = self.profile_view._columns[col_name]

                # for a given column, turn it into a ChunkMessage.
                indexed_component_messages: Dict[int, MetricComponentMessage] = {}
                metric_components = col.to_protobuf().metric_components
                for m_name, m_component in metric_components.items():
                    index = metric_name_indices.get(m_name)
                    if index is None:
                        raise ValueError(f"Missing metric from index mapping. Metric name: {m_name}")
                    indexed_component_messages[index] = m_component

                chunk_msg = ChunkMessage(metric_components=indexed_component_messages)
                chunk_header = ChunkHeader(type=ChunkHeader.ChunkType.COLUMN, length=chunk_msg.ByteSize())
                write_delimited_protobuf(f, chunk_header)
                f.write(chunk_msg.SerializeToString())

            total_len = f.tell()
            f.flush()

            # calculate segment tags based on columnar segments
            segment_message_tags, segment_tags, segment_metadata = _generate_segment_tags_metadata(
                self.segment, self.partition
            )

            if not segment_tags:
                raise NotImplementedError(
                    f"Serialization of segments requires segment tags but none calculated for partition: {self.partition} and segments: {self.segment}"
                )

            segment_message = SegmentMessage()
            segment_message.tags.extend(segment_tags)
            segments_message_field = [segment_message]

            # Make sure to promote the contained profile's metadata to the DatasetProperties
            if self.metadata:
                segment_metadata.update(self.metadata)

            properties = DatasetProperties(
                dataset_timestamp=to_utc_milliseconds(self.profile_view._dataset_timestamp),
                creation_timestamp=to_utc_milliseconds(self.profile_view._creation_timestamp),
                metadata=segment_metadata,
                tags=segment_message_tags,
            )

            logger.info(f"Constructed DatasetProperties for segmented profile file: {properties}")

            dataset_header = DatasetProfileHeader(
                column_offsets=column_chunk_offsets,
                properties=properties,
                length=total_len,
                indexed_metric_paths=metric_index_to_name,
            )

            # TODO: multi segment file format requires multiple offset calculations.
            #  Single segment file only supported initially.
            #  This creates the offsets dictionary with correct keys but 0 offset values, later update offsets
            #  based on file position offset in the temp file written, similarly to how Chunk
            segment_offsets: Dict[int, int] = {
                n: 0 for n in range(len(segments_message_field))
            }  # update this after writing

            # single file segments.
            dataset_segment_header = DatasetSegmentHeader(has_segments=True, offsets=segment_offsets)

            # TODO: calculate other segments offsets when we support multiple segments per file
            write_delimited_protobuf(f, dataset_segment_header)
            first_segment_offset = f.tell() - total_len
            dataset_segment_header.offsets[0] = first_segment_offset

            # only single segment files at first.
            dataset_segment_header.segments.extend(segments_message_field)

            if file_to_write:
                self._copy_write(f, file_to_write, total_len, dataset_segment_header, dataset_header)
            else:
                with open(path, "w+b") as out_f:
                    self._copy_write(f, out_f, total_len, dataset_segment_header, dataset_header)
        return True, path

    def write(self, path: Optional[str] = None, **kwargs: Any) -> Tuple[bool, str]:
        if kwargs.get("use_v0") or self.profile_view.model_performance_metrics:
            if self.profile_view.model_performance_metrics:
                logger.info("Converting segmented profile with performance metrics to v0 format before writing.")
            else:
                logger.info("writing segmented profile as v0 format.")
            return self._write_as_v0_message(path, **kwargs)
        else:
            return self._write_v1(path, **kwargs)