whylabs/whylogs-python

View on GitHub
python/whylogs/core/metrics/condition_count_metric.py

Summary

Maintainability
B
5 hrs
Test Coverage
import logging
import re
from copy import copy
from dataclasses import dataclass, field
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Union

from whylogs.core.configs import SummaryConfig
from whylogs.core.metrics.metric_components import IntegralComponent, MetricComponent
from whylogs.core.metrics.metrics import (
    Metric,
    MetricConfig,
    OperationResult,
    register_metric,
)
from whylogs.core.preprocessing import PreprocessedColumn
from whylogs.core.proto import MetricMessage
from whylogs.core.relations import Relation as Rel

logger = logging.getLogger(__name__)

# For backward compatability
Relation = Rel  # type: ignore


# relation() is annoying, use Predicate instead
def relation(op: Relation, value: Union[str, int, float]) -> Callable[[Any], bool]:  # type: ignore
    if op == Relation.match:  # type: ignore
        return lambda x: re.compile(value).match(x)  # type: ignore
    if op == Relation.fullmatch:  # type: ignore
        return lambda x: re.compile(value).fullmatch(x)  # type: ignore
    if op == Relation.equal:  # type: ignore
        return lambda x: x == value  # type: ignore
    if op == Relation.less:  # type: ignore
        return lambda x: x < value  # type: ignore
    if op == Relation.leq:  # type: ignore
        return lambda x: x <= value  # type: ignore
    if op == Relation.greater:  # type: ignore
        return lambda x: x > value  # type: ignore
    if op == Relation.geq:  # type: ignore
        return lambda x: x >= value  # type: ignore
    if op == Relation.neq:  # type: ignore
        return lambda x: x != value  # type: ignore
    raise ValueError("Unknown ConditionCountMetric predicate")


def and_relations(left: Callable[[Any], bool], right: Callable[[Any], bool]) -> Callable[[Any], bool]:
    return lambda x: left(x) and right(x)


def or_relations(left: Callable[[Any], bool], right: Callable[[Any], bool]) -> Callable[[Any], bool]:
    return lambda x: left(x) or right(x)


def not_relation(relation: Callable[[Any], bool]) -> Callable[[Any], bool]:
    return lambda x: not relation(x)


@dataclass(frozen=True)
class Condition:
    relation: Callable[[Any], bool]
    throw_on_failure: bool = False
    log_on_failure: bool = False
    actions: List[Callable[[str, str, Any], None]] = field(default_factory=list)


@dataclass(frozen=True)
class ConditionCountConfig(MetricConfig):
    conditions: Dict[str, Condition] = field(default_factory=dict)
    exclude_from_serialization: bool = False


@dataclass(frozen=True)
class ConditionCountMetric(Metric):
    conditions: Dict[str, Condition]
    total: IntegralComponent
    matches: Dict[str, IntegralComponent] = field(default_factory=dict)
    hide_from_serialization: bool = False

    @property
    def exclude_from_serialization(self) -> bool:
        return self.hide_from_serialization

    @property
    def namespace(self) -> str:
        return "condition_count"

    def __post_init__(self) -> None:
        if "total" in self.conditions.keys():
            raise ValueError("Condition cannot be named 'total'")

        for cond_name in self.conditions.keys():
            if cond_name not in self.matches:
                self.matches[cond_name] = IntegralComponent(0)

    def merge(self, other: "ConditionCountMetric") -> "ConditionCountMetric":
        if set(self.matches.keys()) != set(other.matches.keys()):
            # log warning?
            matches = {cond_name: IntegralComponent(comp.value) for cond_name, comp in self.matches.items()}
            total = self.total.value
        else:
            matches = {
                cond_name: IntegralComponent(self.matches[cond_name].value + other.matches[cond_name].value)
                for cond_name in self.matches.keys()
            }
            total = self.total.value + other.total.value

        return ConditionCountMetric(
            copy(self.conditions),
            IntegralComponent(total),
            matches,
            hide_from_serialization=self.hide_from_serialization or other.hide_from_serialization,
        )

    def add_conditions(self, conditions: Dict[str, Condition]) -> None:
        if "total" in conditions.keys():
            raise ValueError("Condition cannot be named 'total'")
        for cond_name, cond in conditions.items():
            self.conditions[cond_name] = cond
            self.matches[cond_name] = IntegralComponent(0)

    def get_component_paths(self) -> List[str]:
        paths: List[str] = [
            "total",
        ] + list(self.conditions.keys())
        return paths

    def columnar_update(self, data: PreprocessedColumn) -> OperationResult:
        if data.len <= 0:
            return OperationResult.ok(0)

        count = 0
        failed_conditions: Set[str] = set()
        for datum in list(chain.from_iterable(data.raw_iterator())):
            count += 1
            for cond_name, condition in self.conditions.items():
                try:
                    if condition.relation(datum):
                        self.matches[cond_name].set(self.matches[cond_name].value + 1)
                    else:
                        failed_conditions.add(cond_name)
                        for action in condition.actions:
                            action(self.namespace, cond_name, datum)

                except Exception as e:  # noqa
                    logger.debug(e)
                    failed_conditions.add(cond_name)

        self.total.set(self.total.value + count)
        if failed_conditions:
            if condition.log_on_failure:
                logger.warning(f"Condition(s) {', '.join(failed_conditions)} failed")

            if condition.throw_on_failure:
                raise ValueError(f"Condition(s) {', '.join(failed_conditions)} failed")

        return OperationResult.ok(count)

    @classmethod
    def zero(cls, config: Optional[MetricConfig] = None) -> "ConditionCountMetric":
        config = config or ConditionCountConfig()
        if not isinstance(config, ConditionCountConfig):
            raise ValueError("ConditionCountMetric.zero() requires ConditionCountConfig argument")

        metric = ConditionCountMetric(
            conditions=copy(config.conditions),
            total=IntegralComponent(0),
            hide_from_serialization=config.exclude_from_serialization,
        )
        return metric

    def to_protobuf(self) -> MetricMessage:
        msg = {"total": self.total.to_protobuf()}
        for cond_name in self.conditions.keys():
            msg[cond_name] = self.matches[cond_name].to_protobuf()

        return MetricMessage(metric_components=msg)

    def to_summary_dict(self, cfg: Optional[SummaryConfig] = None) -> Dict[str, Any]:
        summary = {"total": self.total.value}
        for cond_name in self.matches.keys():
            summary[cond_name] = self.matches[cond_name].value

        return summary

    @classmethod
    def from_protobuf(cls, msg: MetricMessage) -> "ConditionCountMetric":
        cond_names: Set[str] = set(msg.metric_components.keys())
        cond_names.remove("total")

        conditions = {cond_name: lambda x: False for cond_name in cond_names}
        total = MetricComponent.from_protobuf(msg.metric_components["total"])
        matches = {
            cond_name: MetricComponent.from_protobuf(msg.metric_components[cond_name]) for cond_name in cond_names
        }
        return ConditionCountMetric(
            conditions,
            total,
            matches,
        )


# Register it so Multimetric and ProfileView can deserialize
register_metric(ConditionCountMetric)