whylabs/whylogs-python

View on GitHub
python/whylogs/core/view/column_profile_view.py

Summary

Maintainability
A
3 hrs
Test Coverage
import logging
from typing import Any, Dict, List, Optional, TypeVar

from whylogs.core.configs import SummaryConfig
from whylogs.core.errors import DeserializationError, UnsupportedError
from whylogs.core.metrics import StandardMetric
from whylogs.core.metrics.metrics import _METRIC_DESERIALIZER_REGISTRY, Metric
from whylogs.core.proto import ColumnMessage, MetricComponentMessage, MetricMessage

logger = logging.getLogger(__name__)

METRIC = TypeVar("METRIC", bound=Metric)


class ColumnProfileView(object):
    def __init__(
        self,
        metrics: Dict[str, METRIC],
        success_count: int = 0,
        failure_count: int = 0,
    ):
        self._metrics: Dict[str, METRIC] = metrics.copy()
        self._success_count = success_count
        self._failure_count = failure_count

    def merge(self, other: "ColumnProfileView") -> "ColumnProfileView":
        all_names = set(self._metrics.keys()).union(other._metrics.keys())
        merged_metrics: Dict[str, Metric] = {}
        for n in all_names:
            lhs = self._metrics.get(n)
            rhs = other._metrics.get(n)

            res = lhs
            if lhs is None:
                res = rhs
            elif rhs is not None:
                res = lhs + rhs
            assert res is not None
            merged_metrics[n] = res
        return ColumnProfileView(
            metrics=merged_metrics,
            success_count=self._success_count + other._success_count,
            failure_count=self._failure_count + other._failure_count,
        )

    def __add__(self, other: "ColumnProfileView") -> "ColumnProfileView":
        return self.merge(other)

    def __getstate__(self) -> bytes:
        return self.serialize()

    def __setstate__(self, serialized_profile: bytes):
        profile = ColumnProfileView.deserialize(serialized_profile)
        self.__dict__.update(profile.__dict__)

    def serialize(self) -> bytes:
        return self.to_protobuf().SerializeToString()

    @classmethod
    def deserialize(cls, serialized_profile: bytes) -> "ColumnProfileView":
        column_message = ColumnMessage.FromString(bytes(serialized_profile))
        return ColumnProfileView.from_protobuf(column_message)

    def get_metric(self, m_name: str) -> Optional[METRIC]:
        return self._metrics.get(m_name)

    def to_protobuf(self) -> ColumnMessage:
        res: Dict[str, MetricComponentMessage] = {}
        for m_name, m in self._metrics.items():
            for mc_name, mc in m.to_protobuf().metric_components.items():
                if not m.exclude_from_serialization:
                    res[f"{m.namespace}/{mc_name}"] = mc
        return ColumnMessage(metric_components=res)

    def get_metric_component_paths(self) -> List[str]:
        res: List[str] = []
        for m_name, m in self._metrics.items():
            for mc_name in m.get_component_paths():
                res.append(f"{m.namespace}/{mc_name}")
        return res

    def get_metric_names(self) -> List[str]:
        return [name for name in self._metrics]

    def get_metrics(self) -> List[Metric]:
        return list(self._metrics.values())

    def to_summary_dict(
        self, *, column_metric: Optional[str] = None, cfg: Optional[SummaryConfig] = None
    ) -> Dict[str, Any]:
        cfg = cfg or SummaryConfig()
        res = {}
        for metric_name, metric in self._metrics.items():
            if column_metric is not None and metric_name != column_metric:
                continue
            try:
                m_sum = metric.to_summary_dict(cfg=cfg)
                for k, v in m_sum.items():
                    res[f"{metric_name}/{k}"] = v
            except NotImplementedError:
                logger.warning(f"No summary implemented for {metric_name}")

        if column_metric is not None and len(res) == 0:
            raise UnsupportedError(f"No metric available for requested column metric: {column_metric}")

        return res

    @classmethod
    def zero(cls, msg: ColumnMessage) -> "ColumnProfileView":
        return ColumnProfileView(metrics={})

    @classmethod
    def from_protobuf(cls, msg: ColumnMessage) -> "ColumnProfileView":
        # importing to trigger registration of non-standard metrics
        import whylogs.experimental.core.metrics.udf_metric  # noqa

        # These require numpy & PIL, but we assume users will install
        # whylogs with the optional dependencies and import the metric
        # modules if they want to use them. So it should be safe to
        # ignore the import failures.
        try:
            import whylogs.experimental.extras.embedding_metric  # noqa
        except ModuleNotFoundError:
            pass
        try:
            import whylogs.experimental.extras.nlp_metric  # noqa
        except ModuleNotFoundError:
            pass
        try:
            import whylogs.extras.image_metric  # noqa
        except ModuleNotFoundError:
            pass

        result_metrics: Dict[str, Metric] = {}
        metric_messages: Dict[str, Dict[str, MetricComponentMessage]] = {}
        for full_path, c_msg in msg.metric_components.items():
            metric_name = full_path.split("/")[0]
            metric_components = metric_messages.get(metric_name)

            if metric_components is None:
                metric_components = {}
            metric_messages[metric_name] = metric_components

            c_key = full_path[len(metric_name) + 1 :]
            metric_components[c_key] = c_msg
        for m_name, metric_components in metric_messages.items():
            m_enum = StandardMetric.__members__.get(m_name)
            if m_enum is None:
                if m_name in _METRIC_DESERIALIZER_REGISTRY:
                    metric_class = _METRIC_DESERIALIZER_REGISTRY[m_name]
                else:
                    raise UnsupportedError(f"Unsupported metric: {m_name}")
            else:
                metric_class = m_enum.value

            m_msg = MetricMessage(metric_components=metric_components)
            try:
                deserialized_metric = metric_class.from_protobuf(m_msg)
                result_metrics[m_name] = deserialized_metric
            except Exception as error:  # noqa
                raise DeserializationError(
                    f"Failed to deserialize metric: {m_name}:{error}; possibly missing dependencies"
                )

        return ColumnProfileView(metrics=result_metrics)

    @classmethod
    def from_bytes(cls, data: bytes) -> "ColumnProfileView":
        msg = ColumnMessage()
        if isinstance(data, bytearray):
            data = bytes(data)
        msg.ParseFromString(data)
        return ColumnProfileView.from_protobuf(msg)