whylabs/whylogs-python

View on GitHub
python/whylogs/api/logger/result_set.py

Summary

Maintainability
C
1 day
Test Coverage
from abc import ABC, abstractmethod
from datetime import datetime
from logging import getLogger
from typing import Any, Dict, List, Optional, Tuple, Union

from whylogs.api.reader import Reader, Readers
from whylogs.api.writer import Writer, Writers
from whylogs.api.writer.writer import Writable
from whylogs.core import DatasetProfile, DatasetProfileView, Segment
from whylogs.core.metrics.metrics import Metric
from whylogs.core.model_performance_metrics import ModelPerformanceMetrics
from whylogs.core.segmentation_partition import SegmentationPartition
from whylogs.core.utils import ensure_timezone
from whylogs.core.view.dataset_profile_view import _MODEL_PERFORMANCE
from whylogs.core.view.segmented_dataset_profile_view import SegmentedDatasetProfileView

logger = getLogger(__name__)


def _merge_metrics(
    lhs_metrics: Optional[Dict[str, Any]], rhs_metrics: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    if not rhs_metrics:
        return lhs_metrics
    if not lhs_metrics:
        return rhs_metrics
    lhs_keys = lhs_metrics.keys()
    rhs_keys = rhs_metrics.keys()
    merged_metrics: Dict[str, Any] = dict()
    lhs_only = lhs_keys - rhs_keys
    rhs_only = rhs_keys - lhs_keys
    intersection_keys = lhs_keys & rhs_keys
    for key in lhs_only:
        merged_metrics[key] = lhs_metrics[key]

    for key in rhs_only:
        merged_metrics[key] = rhs_metrics[key]

    for key in intersection_keys:
        merged_metrics[key] = lhs_metrics[key].merge(rhs_metrics[key])

    return merged_metrics


def _accumulate_properties(acc: Optional[Dict[str, Any]], props: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if not props:
        return acc
    if acc is None:
        acc = dict()
    intersection_keys = acc.keys() & props.keys()
    for key in intersection_keys:
        if acc[key] != props[key]:
            logger.warning(f"Merging result properties collision, using {key}:{acc[key]} and dropping {props[key]}.")
    acc.update(props)
    return acc


def _merge_segments(
    lhs_segments: Dict[Segment, Union[DatasetProfile, DatasetProfileView]],
    rhs_segments: Dict[Segment, Union[DatasetProfile, DatasetProfileView]],
) -> Dict[Segment, DatasetProfileView]:
    """Merge two segment->profile mappings into a segment->view mapping.

    Segments on only one side are converted to views and carried over;
    segments present on both sides have their views merged. Profiles are
    normalized to DatasetProfileView via ``view()`` before merging.
    """

    def _as_view(value: Union[DatasetProfile, DatasetProfileView]) -> DatasetProfileView:
        # Normalize: a mutable DatasetProfile is snapshotted to its view.
        return value if isinstance(value, DatasetProfileView) else value.view()

    merged_segments: Dict[Segment, DatasetProfileView] = dict()
    for key, left_value in lhs_segments.items():
        if key in rhs_segments:
            # BUG FIX: the right-hand value was previously read from
            # lhs_segments, so shared keys merged the lhs profile with itself
            # and silently dropped the rhs profile.
            merged_segments[key] = _as_view(left_value).merge(_as_view(rhs_segments[key]))
        else:
            merged_segments[key] = _as_view(left_value)

    for key, right_value in rhs_segments.items():
        if key not in lhs_segments:
            merged_segments[key] = _as_view(right_value)

    return merged_segments


def _merge_partitioned_segments(
    lhs_segments: Dict[str, Dict[Segment, DatasetProfile]], rhs_segments: Dict[str, Dict[Segment, DatasetProfile]]
) -> Dict[str, Dict[Segment, DatasetProfile]]:
    """Merge two partition-id -> (segment -> profile) mappings.

    Partitions that appear on only one side are carried over unchanged;
    partitions present on both sides have their segment dictionaries merged
    via ``_merge_segments``.
    """
    # Start from the lhs entries; shared partition ids are overwritten below.
    merged: Dict[str, Dict[Segment, DatasetProfile]] = dict(lhs_segments)
    for partition_id, segments in rhs_segments.items():
        if partition_id in lhs_segments:
            merged[partition_id] = _merge_segments(lhs_segments[partition_id], segments)
        else:
            merged[partition_id] = segments
    return merged


def _merge_partitions(
    lhs_partitions: List[SegmentationPartition], rhs_partitions: List[SegmentationPartition]
) -> List[SegmentationPartition]:
    """Return the deduplicated union of both partition lists (order not guaranteed)."""
    return list({*lhs_partitions, *rhs_partitions})


class ResultSetWriter:
    """
    Result of a logging call.
    A result set might contain one or multiple profiles or profile views.
    """

    def __init__(self, results: "ResultSet", writer: Writer):
        self._result_set = results
        self._writer = writer

    def option(self, **kwargs: Any) -> "ResultSetWriter":
        # Fluent configuration: forward the options to the wrapped writer.
        self._writer.option(**kwargs)
        return self

    def write(self, **kwargs: Any) -> List:
        # multi-profile writer
        statuses: List[Tuple[bool, str]] = []
        writables = self._result_set.get_writables()
        if not writables:
            logger.warning("Attempt to write a result set with no writables returned, nothing written!")
            return statuses
        reference_name = getattr(self._writer, "_reference_profile_name", None)
        if reference_name is not None and isinstance(self._result_set, SegmentedResultSet):
            # a segmented reference profile name needs to have access to the complete result set
            statuses.append(self._writer.write(file=self._result_set, **kwargs))
            return statuses
        logger.debug(f"About to write {len(writables)} files:")
        # TODO: special handling of large number of files, handle throttling
        for writable in writables:
            statuses.append(self._writer.write(file=writable, **kwargs))
        logger.debug(f"Completed writing {len(writables)} files!")

        return statuses


class ResultSetReader:
    """Thin fluent wrapper that reads a ResultSet through a configured Reader."""

    def __init__(self, reader: Reader) -> None:
        self._reader = reader

    def option(self, **kwargs: Any) -> "ResultSetReader":
        # Fluent configuration: forward the options to the wrapped reader.
        self._reader.option(**kwargs)
        return self

    def read(self, **kwargs: Any) -> "ResultSet":
        # Delegate the actual deserialization to the underlying reader.
        return self._reader.read(**kwargs)


class ResultSet(ABC):
    """
    A holder object for profiling results.

    A whylogs.log call can result in more than one profile. This wrapper class
    simplifies the navigation among these profiles.

    Note that currently we only hold one profile but we're planning to add other
    kinds of profiles such as segmented profiles here.
    """

    @staticmethod
    def read(multi_profile_file: str) -> "ResultSet":
        # TODO: parse multiple profile
        return ViewResultSet(view=DatasetProfileView.read(multi_profile_file))

    @staticmethod
    def reader(name: str = "local") -> "ResultSetReader":
        # Look up a registered reader by name and wrap it.
        return ResultSetReader(reader=Readers.get(name))

    def writer(self, name: str = "local") -> "ResultSetWriter":
        # Look up a registered writer by name and bind it to this result set.
        return ResultSetWriter(results=self, writer=Writers.get(name))

    @abstractmethod
    def view(self) -> Optional[DatasetProfileView]:
        pass

    @abstractmethod
    def profile(self) -> Optional[DatasetProfile]:
        pass

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        # Subclasses that track metadata expose it through a _metadata attribute.
        return getattr(self, "_metadata", None)

    def get_writables(self) -> Optional[List[Writable]]:
        return [self.view()]

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        ensure_timezone(dataset_timestamp)
        profile = self.profile()
        if profile is None:
            raise ValueError("Cannot set timestamp on a result set without a profile!")
        profile.set_dataset_timestamp(dataset_timestamp)

    @property
    def count(self) -> int:
        # A plain result set holds at most one profile view.
        return 1 if self.view() is not None else 0

    @property
    def performance_metrics(self) -> Optional[ModelPerformanceMetrics]:
        profile = self.profile()
        return profile.model_performance_metrics if profile else None

    def add_model_performance_metrics(self, metrics: ModelPerformanceMetrics) -> None:
        profile = self.profile()
        if not profile:
            raise ValueError("Cannot add performance metrics to a result set with no profile!")
        profile.add_model_performance_metrics(metrics)

    def add_metric(self, name: str, metric: Metric) -> None:
        profile = self.profile()
        if not profile:
            raise ValueError(f"Cannot add {name} metric {metric} to a result set with no profile!")
        profile.add_dataset_metric(name, metric)

    def merge(self, other: "ResultSet") -> "ResultSet":
        raise NotImplementedError("This result set did not define merge, see ProfileResultSet or SegmentedResulSet.")


class ViewResultSet(ResultSet):
    """Result set backed by an already-materialized DatasetProfileView."""

    def __init__(self, view: DatasetProfileView) -> None:
        self._view = view

    def profile(self) -> Optional[DatasetProfile]:
        # A view is a snapshot; there is no backing mutable profile to return.
        raise ValueError("No profile available. Can only view")

    def view(self) -> Optional[DatasetProfileView]:
        return self._view

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        view = self.view()
        return view.metadata if view else None

    @staticmethod
    def zero() -> "ViewResultSet":
        # An empty view acts as the identity element for merge.
        return ViewResultSet(DatasetProfileView.zero())

    def merge(self, other: "ResultSet") -> "ViewResultSet":
        if other is None:
            return self

        if not isinstance(other, (ViewResultSet, ProfileResultSet)):
            logger.warning(f"Merging potentially incompatible ViewResultSet and {type(other)}")
        lhs_view = self._view or DatasetProfileView.zero()
        return ViewResultSet(lhs_view.merge(other.view()))

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        ensure_timezone(dataset_timestamp)
        view = self.view()
        if view is None:
            raise ValueError("Cannot set timestamp on a view result set without a view!")
        view.dataset_timestamp = dataset_timestamp


class ProfileResultSet(ResultSet):
    """Result set backed by a single mutable DatasetProfile."""

    def __init__(self, profile: DatasetProfile) -> None:
        self._profile = profile

    def profile(self) -> Optional[DatasetProfile]:
        return self._profile

    def view(self) -> Optional[DatasetProfileView]:
        # Snapshot the mutable profile into an immutable view.
        return self._profile.view()

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        view = self.view()
        return view.metadata if view else None

    @staticmethod
    def zero() -> "ProfileResultSet":
        # An empty profile acts as the identity element for merge.
        return ProfileResultSet(DatasetProfile())

    def merge(self, other: "ResultSet") -> ViewResultSet:
        if other is None:
            return self

        if not isinstance(other, (ProfileResultSet, ViewResultSet)):
            logger.error(f"Merging potentially incompatible ProfileResultSet and {type(other)}")
        lhs_view = self.view() or DatasetProfileView()
        return ViewResultSet(lhs_view.merge(other.view()))


class SegmentedResultSet(ResultSet):
    """Result set holding one profile per segment, grouped by partition id.

    ``segments`` is a two-level mapping: partition id -> (Segment ->
    DatasetProfile or DatasetProfileView). Dataset-level metrics (e.g. model
    performance) and dataset properties are stored alongside the per-segment
    profiles rather than inside any single profile.
    """

    def __init__(
        self,
        segments: Dict[str, Dict[Segment, Union[DatasetProfile, DatasetProfileView]]],
        partitions: Optional[List[SegmentationPartition]] = None,
        metrics: Optional[Dict[str, Any]] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._segments = segments
        self._partitions = partitions
        # Dataset-level metrics keyed by name; see add_metric() and
        # add_model_performance_metrics().
        self._metrics = metrics or dict()
        self._dataset_properties = properties or dict()
        # Copied onto each SegmentedDatasetProfileView produced by get_writables().
        self._metadata: Dict[str, str] = dict()

    def profile(self, segment: Optional[Segment] = None) -> Optional[Union[DatasetProfile, DatasetProfileView]]:
        """Return the profile for ``segment``.

        If ``segment`` is omitted and this result set contains exactly one
        partition with exactly one segment, that single profile is returned.
        Otherwise a ValueError is raised (caller must say which segment).
        Returns None when there are no segments at all, or when the requested
        segment's partition is unknown.
        """
        if not self._segments:
            return None
        elif segment:
            paritition_segments = self._segments.get(segment.parent_id)
            return paritition_segments.get(segment) if paritition_segments else None
        # special case return a single segment if there is only one, even if not specified
        elif len(self._segments) == 1:
            for partition_id in self._segments:
                segments = self._segments.get(partition_id)
                number_of_segments = len(segments) if segments else 0
                if number_of_segments == 1:
                    single_dictionary: Dict[Segment, Union[DatasetProfile, DatasetProfileView]] = (
                        segments if segments else dict()
                    )
                    # Return the sole profile in the sole partition.
                    for key in single_dictionary:
                        return single_dictionary[key]

        raise ValueError(
            f"A profile was requested from a segmented result set without specifying which segment to return: {self._segments}"
        )

    @property
    def dataset_properties(self) -> Optional[Dict[str, Any]]:
        """Dataset-level properties shared across all segments."""
        return self._dataset_properties

    @property
    def dataset_metrics(self) -> Optional[Dict[str, Any]]:
        """Dataset-level metrics shared across all segments."""
        return self._metrics

    @property
    def partitions(self) -> Optional[List[SegmentationPartition]]:
        """The segmentation partitions this result set was built from."""
        return self._partitions

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        """Set the dataset timestamp on every segment's profile."""
        # TODO: pull dataset_timestamp up into a result set scoped property
        segment_keys = self.segments()
        if not segment_keys:
            return
        for key in segment_keys:
            profile = self.profile(segment=key)
            if profile:
                profile.set_dataset_timestamp(dataset_timestamp)

    def segments(self, restrict_to_parition_id: Optional[str] = None) -> Optional[List[Segment]]:
        """Return the segment keys, optionally limited to one partition id.

        Returns None when there are no segments at all; otherwise a (possibly
        empty) list.
        """
        result: Optional[List[Segment]] = None
        if not self._segments:
            return result
        result = list()
        if restrict_to_parition_id:
            segments = self._segments.get(restrict_to_parition_id)
            if segments:
                for segment in segments:
                    result.append(segment)
        else:
            # No restriction: collect segment keys from every partition.
            for partition_id in self._segments:
                for segment in self._segments[partition_id]:
                    result.append(segment)
        return result

    @property
    def count(self) -> int:
        """Total number of segmented profiles across all partitions."""
        result = 0
        if self._segments:
            for segment_key in self._segments:
                profiles = self._segments[segment_key]
                result += len(profiles)
        return result

    def segments_in_partition(
        self, partition: SegmentationPartition
    ) -> Optional[Dict[Segment, Union[DatasetProfile, DatasetProfileView]]]:
        """Return the segment -> profile mapping for ``partition``, or None."""
        return self._segments.get(partition.id)

    def view(self, segment: Optional[Segment] = None) -> Optional[DatasetProfileView]:
        """Return the profile view for ``segment`` (see profile() for lookup rules)."""
        result = self.profile(segment)
        # Normalize: snapshot a mutable DatasetProfile into its view form.
        view = result.view() if isinstance(result, DatasetProfile) else result
        return view

    def get_model_performance_metrics_for_segment(self, segment: Segment) -> Optional[ModelPerformanceMetrics]:
        """Return the model performance metrics stored on ``segment``'s profile, if any."""
        if segment.parent_id in self._segments:
            profile = self._segments[segment.parent_id][segment]
            if not profile:
                logger.warning(
                    f"No profile found for segment {segment} when requesting model performance metrics, returning None!"
                )
                return None

            if isinstance(profile, DatasetProfileView):
                view = profile
            else:
                # Anything else must be convertible to a view.
                if hasattr(profile, "view"):
                    view = profile.view()
                else:
                    logger.error(
                        f"Unexpected type: {type(profile)} -> {profile}, cannot check for model performance metrics."
                    )
                    return None
            return view.model_performance_metrics
        return None

    def get_writables(self) -> Optional[List[Writable]]:
        """Build the list of SegmentedDatasetProfileViews to be written.

        Only the first partition is used; additional partitions are logged and
        ignored. Returns None when the result set has no segments.
        """
        results: Optional[List[Writable]] = None
        if self._segments:
            results = []
            logger.info(f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet.")
            # TODO: handle more than one partition
            if not self.partitions:
                raise ValueError(
                    f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet but no partitions found: {self.partitions}."
                )
            if len(self.partitions) > 1:
                logger.error(
                    f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet but found more than one partition: "
                    f"{self.partitions}. Using first partition only!!"
                )
            first_partition = self.partitions[0]
            segments = self.segments_in_partition(first_partition)
            if segments:
                for segment_key in segments:
                    profile = segments[segment_key]
                    # Attach any per-segment model performance metrics before wrapping.
                    metric = self.get_model_performance_metrics_for_segment(segment_key)
                    if metric:
                        profile.add_model_performance_metrics(metric)
                        logger.debug(
                            f"Found model performance metrics: {metric}, adding to segmented profile: {segment_key}."
                        )
                    view = profile.view() if isinstance(profile, DatasetProfile) else profile
                    segmented_profile = SegmentedDatasetProfileView(
                        profile_view=view, segment=segment_key, partition=first_partition
                    )
                    # Propagate result-set level metadata onto each writable view.
                    if self.metadata:
                        segmented_profile.metadata.update(self.metadata)
                    results.append(segmented_profile)
            else:
                logger.warning(
                    f"Found no segments in partition: {first_partition} even though we have: {self.count} segments overall"
                )
            logger.info(f"From list of: {self.count} SegmentedDatasetProfileViews using {len(results)}")
        else:
            logger.warning(
                f"Attempt to build segmented results for writing but there are no segments in this result set: {self._segments}. returning None."
            )
        return results

    def add_metrics_for_segment(self, metrics: ModelPerformanceMetrics, segment: Segment) -> None:
        """Attach model performance metrics to one segment's profile.

        NOTE: silently does nothing when the segment's partition id is unknown.
        """
        if segment.parent_id in self._segments:
            profile = self._segments[segment.parent_id][segment]
            profile.add_model_performance_metrics(metrics)

    @staticmethod
    def zero() -> "SegmentedResultSet":
        """Return an empty SegmentedResultSet (identity element for merge)."""
        return SegmentedResultSet(segments=dict())

    @property
    def model_performance_metric(self) -> Optional[ModelPerformanceMetrics]:
        """The dataset-level model performance metric, if one was recorded."""
        if self._metrics:
            return self._metrics.get(_MODEL_PERFORMANCE)
        return None

    def add_model_performance_metrics(self, metrics: ModelPerformanceMetrics) -> None:
        """Store ``metrics`` as the dataset-level model performance metric."""
        if self._metrics:
            self._metrics[_MODEL_PERFORMANCE] = metrics
        else:
            self._metrics = {_MODEL_PERFORMANCE: metrics}

    def add_metric(self, name: str, metric: Metric) -> None:
        """Store a dataset-level metric under ``name`` (overwrites any existing one)."""
        if not self._metrics:
            self._metrics = dict()
        self._metrics[name] = metric

    def merge(self, other: "ResultSet") -> "SegmentedResultSet":
        """Merge with another SegmentedResultSet into a new result set.

        Raises ValueError when ``other`` is any other kind of ResultSet.
        """
        if other is None:
            return self

        if isinstance(other, SegmentedResultSet):
            lhs_partitions: List[SegmentationPartition] = self.partitions or list()
            rhs_partitions: List[SegmentationPartition] = other.partitions or list()

            lhs_segments: Dict[str, Dict[Segment, DatasetProfile]] = self._segments or dict()
            rhs_segments: Dict[str, Dict[Segment, DatasetProfile]] = other._segments or dict()

            merged_segments = _merge_partitioned_segments(lhs_segments, rhs_segments)
            merged_metrics = _merge_metrics(self.dataset_metrics, other.dataset_metrics)
            merged_partitions = _merge_partitions(lhs_partitions, rhs_partitions)
            properties = _accumulate_properties(self._dataset_properties, other.dataset_properties)

            return SegmentedResultSet(merged_segments, merged_partitions, metrics=merged_metrics, properties=properties)
        else:
            raise ValueError(f"Cannot merge incompatible SegmentedResultSet and {type(other)}")