whylabs/whylogs-python

View on GitHub
python/whylogs/experimental/extras/embedding_metric.py

Summary

Maintainability
A
2 hrs
Test Coverage
import logging
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import List, Optional

from whylogs.core.metrics import StandardMetric
from whylogs.core.metrics.metrics import MetricConfig, OperationResult, register_metric
from whylogs.core.metrics.multimetric import MultiMetric
from whylogs.core.preprocessing import PreprocessedColumn
from whylogs.core.proto import MetricMessage
from whylogs.core.stubs import np, sklp
from whylogs.experimental.extras.matrix_component import MatrixComponent

# Module-level logger; used below to warn about ignored merges and rejected tensor shapes.
logger = logging.getLogger(__name__)


class DistanceFunction(Enum):
    """
    Selects the pairwise distance function EmbeddingMetric uses to compare logged
    vectors against its reference vectors (``sklp.euclidean_distances`` or
    ``sklp.cosine_distances`` — presumably sklearn's pairwise distances; confirm
    against ``whylogs.core.stubs``).
    """

    # NOTE(review): Enum does not turn descriptors (e.g. ordinary Python functions)
    # into members. If the sklp attributes are plain functions, `euclidean`/`cosine`
    # stay ordinary class attributes: `DistanceFunction.cosine` is the function itself
    # (which is why EmbeddingMetric can call `self.distance_fn(...)` directly, with a
    # `# type: ignore`), and `list(DistanceFunction)` is empty. Verify against the sklp
    # stub before relying on enum-member features such as `.name` / `.value`.
    euclidean = sklp.euclidean_distances
    cosine = sklp.cosine_distances


@dataclass(frozen=True)
class EmbeddingConfig(MetricConfig):
    """
    Configuration for EmbeddingMetric.

    The rows of ``references`` are the reference vectors. A shape of (1, 1) indicates
    there's no reference matrix and the metric will not be updatable. It should still
    be mergeable with compatible metrics (details of compatibility TBD). The rows
    must be in the same order as ``labels``. If labels are not provided, "0" ... "n-1"
    will be used, where n is the number of rows in the references matrix.
    """

    references: np.ndarray = field(default_factory=lambda: np.zeros((1, 1)))
    labels: Optional[List[str]] = None
    distance_fn: DistanceFunction = DistanceFunction.cosine
    serialize_references: bool = True  # should references be included in protobuf message?

    # TODO: limit reference size

    def __post_init__(self) -> None:
        """
        Validate the reference matrix and labels.

        Raises:
            ValueError: if ``references`` is not 2-dimensional, or if ``labels`` is
                non-empty and its length differs from the number of reference rows.
        """
        if len(self.references.shape) != 2:
            raise ValueError("Embedding reference matrix must be 2 dimensional")

        # None / empty labels skip the check; default labels are generated later.
        # Labels correspond to rows, so compare against shape[0] and report shape[0].
        if self.labels and len(self.labels) != self.references.shape[0]:
            raise ValueError(
                f"Number of labels ({len(self.labels)}) must match number of reference vectors ({self.references.shape[0]})"
            )


@dataclass
class EmbeddingMetric(MultiMetric):
    """
    Summarizes logged embedding vectors relative to a set of labeled reference vectors.

    Row i of ``references`` is the reference vector for ``labels[i]``. After
    construction the labels are sorted and the reference rows are permuted to match.
    A (1, 1) reference matrix marks a metric without a usable reference matrix;
    ``merge`` treats such a metric as unconfigured.
    """

    references: MatrixComponent
    labels: List[str]
    distance_fn: DistanceFunction
    serialize_references: bool  # include the reference matrix in to_protobuf() output?

    def __post_init__(self) -> None:
        """
        Build the submetrics and put labels / reference rows into sorted order.

        Each label has a {label}_distance submetric that tracks the distribution of
        distances from the label's reference vector to every logged vector.

        There is also a closest submetric that tracks how often a label's reference
        vector is the closest to the logged vectors.
        """
        submetrics = {
            f"{label}_distance": {
                "distribution": StandardMetric.distribution.zero(),
                "counts": StandardMetric.counts.zero(),
                "types": StandardMetric.types.zero(),
                "cardinality": StandardMetric.cardinality.zero(),
            }
            for label in self.labels
        }
        submetrics["closest"] = {
            "frequent_items": StandardMetric.frequent_items.zero(),
            "counts": StandardMetric.counts.zero(),
            "types": StandardMetric.types.zero(),
            "cardinality": StandardMetric.cardinality.zero(),
        }
        super().__init__(submetrics)

        # Sort labels and permute the reference rows to match, so deserialization (which
        # recovers labels by sorting submetric names) lines rows up with labels. A single
        # Python sort drives both, so label order and row order cannot disagree.
        labels = list(self.labels)
        order = sorted(range(len(labels)), key=lambda i: labels[i])
        self.labels = [labels[i] for i in order]
        # Only permute when there is one row per label: the (1, 1) "no references"
        # placeholder may carry any number of labels, and indexing it by label position
        # would raise IndexError.
        if self.references.value.shape[0] == len(order):
            self.references = MatrixComponent(self.references.value[order, :])

    @property
    def namespace(self) -> str:
        return "embedding"

    def merge(self, other: "EmbeddingMetric") -> "EmbeddingMetric":
        """
        Combine this metric with ``other``.

        If exactly one side has the (1, 1) placeholder reference matrix, a warning is
        logged and the other side is returned as-is (not a copy).

        Raises:
            ValueError: if the reference matrices differ in shape or content, or the
                labels differ.
        """
        my_shape = self.references.value.shape
        other_shape = other.references.value.shape
        if my_shape != other_shape:
            if other_shape == (1, 1):
                # TODO: handle merging with other.serialize_references==False better
                # The (1, 1) shape indicates the other metric was created without a reference matrix.
                # It can't have meaningful data in it, so just return myself
                logger.warning("Attempt to merge with unconfigured EmbeddingMetric; ignored")
                return self
            if my_shape == (1, 1):
                # Mirror image of the case above: I am the unconfigured side.
                logger.warning("Attempt to merge with unconfigured EmbeddingMetric; ignored")
                return other
            raise ValueError("Attempt to merge incompatible EmbeddingMetrics")

        # TODO: maybe also compare distance_fn? warn if they differ?
        if self.labels != other.labels or not np.array_equal(self.references.value, other.references.value):
            raise ValueError("Attempt to merge incompatible EmbeddingMetrics")

        result = EmbeddingMetric(self.references, self.labels, self.distance_fn, self.serialize_references)
        result.submetrics = self.merge_submetrics(other)
        return result

    def to_protobuf(self) -> MetricMessage:
        """
        Flatten every submetric component into one message.

        Keys have the form ``"{submetric}:{namespace}/{component}"``. The reference
        matrix is added under ``"references"`` only when ``serialize_references`` is set.
        """
        components = {}
        for sub_name, metrics in self.submetrics.items():
            for namespace, metric in metrics.items():
                for comp_name, comp_msg in metric.to_protobuf().metric_components.items():
                    components[f"{sub_name}:{namespace}/{comp_name}"] = comp_msg
        if self.serialize_references:
            components["references"] = self.references.to_protobuf()

        return MetricMessage(metric_components=components)

    def _update_submetrics(self, submetric: str, data: PreprocessedColumn) -> None:
        """Feed ``data`` to every metric registered under the ``submetric`` name."""
        for metric in self.submetrics[submetric].values():
            metric.columnar_update(data)

    def columnar_update(self, data: PreprocessedColumn) -> OperationResult:
        """
        Update the distance and closest submetrics from the tensors in ``data``.

        Each tensor must be a vector of length d, or a 2-D matrix with d columns, where
        d is the width of the reference matrix. Each row is compared against every
        reference vector with ``distance_fn``. Tensors of any other shape are
        skipped with a warning.

        Returns:
            OperationResult(failures, successes), counting tensors (not rows).
        """
        reference_dim = self.references.value.shape[1]  # number of columns in reference matrix
        successes = 0
        failures = 0
        # Explicit None check rather than `or []`: presumably because the pandas container
        # may be array-like, whose truthiness is ambiguous. TODO confirm.
        pandas_tensors = data.pandas.tensors if data.pandas.tensors is not None else []
        for matrix in chain(data.list.tensors or [], pandas_tensors):  # TODO: stack these
            if len(matrix.shape) == 1:
                # A single vector is treated as a one-row matrix.
                matrix = matrix.reshape((1, matrix.shape[0]))
            if len(matrix.shape) != 2 or matrix.shape[1] != reference_dim:
                logger.warning(
                    "EmbeddingMetric requires 1 x %s vectors or N x %s matrices; got tensor with shape %s",
                    reference_dim,
                    reference_dim,
                    matrix.shape,
                )
                failures += 1
                continue

            # ref_dists[r, j] = distance from logged row r to reference row j
            ref_dists = self.distance_fn(matrix, self.references.value)  # type: ignore
            for i in range(ref_dists.shape[1]):
                self._update_submetrics(f"{self.labels[i]}_distance", PreprocessedColumn.apply(ref_dists[:, i]))

            closest = [self.labels[i] for i in np.argmin(ref_dists, axis=1)]
            self._update_submetrics("closest", PreprocessedColumn.apply(np.asarray(closest)))
            successes += 1

        return OperationResult(failures, successes)

    @classmethod
    def from_protobuf(cls, msg: MetricMessage) -> "EmbeddingMetric":
        """
        Rebuild an EmbeddingMetric from a message produced by ``to_protobuf``.

        ``msg`` is not modified. If it has no ``"references"`` component, the (1, 1)
        placeholder matrix is used and ``serialize_references`` is False. Labels are
        recovered from the ``{label}_distance`` submetric names. ``distance_fn`` is not
        serialized; cosine is filled in, so the result is not meant to be updated.
        """
        if "references" in msg.metric_components:
            references = MatrixComponent.from_protobuf(msg.metric_components["references"])
            serialize_references = True
            # "references" is not a submetric component; filter it out into a new
            # message instead of popping it from the caller's message.
            msg = MetricMessage(
                metric_components={
                    name: comp for name, comp in msg.metric_components.items() if name != "references"
                }
            )
        else:
            # (1, 1) placeholder: I don't have a usable reference matrix. It must be a
            # MatrixComponent, because __post_init__ and merge() read `.value`.
            references = MatrixComponent(np.zeros((1, 1)))
            serialize_references = False

        submetrics = EmbeddingMetric.submetrics_from_protobuf(msg)

        # figure out what my labels were from the {label}_distance submetric names;
        # the serialized reference rows are already in this sorted order
        suffix = "_distance"
        labels: List[str] = sorted(name[: -len(suffix)] for name in submetrics if name.endswith(suffix))

        result = EmbeddingMetric(
            references=references,
            labels=labels,
            distance_fn=DistanceFunction.cosine,  # not updatable after deserialization
            serialize_references=serialize_references,
        )
        result.submetrics = submetrics
        return result

    @classmethod
    def zero(cls, cfg: Optional[EmbeddingConfig] = None) -> "EmbeddingMetric":
        """
        Create an empty EmbeddingMetric from ``cfg`` (default: ``EmbeddingConfig()``).

        When ``cfg.labels`` is empty or None, the labels default to "0" ... "n-1" for
        the n reference rows.

        Raises:
            ValueError: if ``cfg`` is not an EmbeddingConfig.
        """
        if cfg is None:
            cfg = EmbeddingConfig()
        if not isinstance(cfg, EmbeddingConfig):
            raise ValueError("EmbeddingMetric.zero() requires EmbeddingConfig argument")

        return EmbeddingMetric(
            references=MatrixComponent(cfg.references),
            labels=cfg.labels or [str(i) for i in range(cfg.references.shape[0])],
            distance_fn=cfg.distance_fn,
            serialize_references=cfg.serialize_references,
        )


# Register EmbeddingMetric at import time so MultiMetric and ProfileView can deserialize it.
register_metric(EmbeddingMetric)