
View on GitHub


0 mins
Test Coverage
from typing import Any, Callable, Dict, Optional, TypeVar

import whylogs_sketching as ds  # type: ignore

from whylogs.core.metrics.decorators import (
    DecoratedFunction,
    FuncType,
    _decorate_func,
    _func_wrapper,
)
from whylogs.core.proto import MetricComponentMessage

# Generic type variables for this module.
# NOTE(review): neither M nor NUM is referenced in the visible part of this
# module -- confirm they are still needed.
M = TypeVar("M")
NUM = TypeVar("NUM", float, int)  # constrained: binds to exactly float or int


class _Deserializer(DecoratedFunction):
    """Named wrapper around a function that decodes a ``MetricComponentMessage``.

    Instances are called with the message as the keyword-only ``msg`` argument
    and forward it positionally to the wrapped function.
    """

    def __init__(self, *, func: FuncType, name: str):  # noqa
        self._name = name
        self._func = func

    # NOTE(review): if DecoratedFunction declares ``name`` as a property, this
    # override should carry @property as well -- confirm against the base class.
    def name(self) -> Optional[str]:
        """Return the name this deserializer was created with."""
        return self._name

    def __call__(self, *, msg: MetricComponentMessage) -> Any:
        """Decode ``msg`` using the wrapped function and return its result."""
        return self._func(msg)

    @classmethod
    def build(cls, func: FuncType, name: str) -> "_Deserializer":  # noqa
        """Construct a deserializer wrapping ``func`` under ``name``.

        This is a classmethod (its first parameter is ``cls``) so it works when
        called on the class itself, e.g. ``_Deserializer.build(f, "x")``.
        Without the decorator, ``f`` would be bound to ``cls`` and the call
        would fail.
        """
        return cls(func=func, name=name)

# Module-level deserializer tables, built in and filled during import.
# `_builtin_deserializer` registers into the type-keyed table (keyed by the
# decorated function's return annotation). Every DeserializerRegistry takes a
# shallow copy of both tables when it is constructed.
_TYPED_DESERIALIZERS: Dict[type, _Deserializer] = dict()
_ID_DESERIALIZERS: Dict[int, _Deserializer] = dict()  # keyed by numeric type id

def _builtin_deserializer(*, name: str) -> Callable[[Callable], _Deserializer]:
    """Decorator factory that registers a builtin deserializer.

    Note that since these are built in, they MUST be unique and stable over time.

    Args:
        name: a human readable string. This must be unique for validation.

    Returns:
        A decorator. It validates that the decorated function takes exactly one
        annotated ``MetricComponentMessage`` argument and has a return
        annotation. It then wraps the function via ``_func_wrapper`` into
        ``_TYPED_DESERIALIZERS``, keyed by that return annotation.

    Raises:
        ValueError: (from the returned decorator) if the function has no return
            annotation, does not have exactly one annotated argument, or that
            argument is not annotated as ``MetricComponentMessage``.
    """

    def decorated(func: FuncType) -> _Deserializer:
        annotations: Dict[str, type] = func.__annotations__.copy()
        # The return annotation is the registry key, so it is mandatory. Report
        # its absence as ValueError, like the other signature checks, rather
        # than leaking a bare KeyError from dict.pop.
        if "return" not in annotations:
            raise ValueError("Invalid function type: missing return type annotation")
        r_type = annotations.pop("return")

        arg_len = len(annotations)
        if arg_len != 1:
            raise ValueError(f"Expected 1 argument, got: {arg_len}")

        if next(iter(annotations.values())) != MetricComponentMessage:
            raise ValueError("Invalid function type: input argument not MetricComponentMessage")

        # NOTE(review): ``name`` is not forwarded here; the registered name is
        # derived from the return type instead. Confirm this is intended.
        return _func_wrapper(
            func=func, key=r_type, name=f"builtin.{r_type}", wrapper_dict=_TYPED_DESERIALIZERS, clazz=_Deserializer
        )

    return decorated  # type: ignore

@_builtin_deserializer(name="int")
def _int(msg: MetricComponentMessage) -> int:
    """Decode the integer ``n`` field of a component message.

    The decorator registers this in ``_TYPED_DESERIALIZERS`` under its return
    type (``int``), so it is available to every DeserializerRegistry.

    Raises:
        ValueError: if ``msg.WhichOneof("value")`` is anything other than ``"n"``.
    """
    field = msg.WhichOneof("value")
    if field != "n":
        raise ValueError(f"Unsupported field: {field}")
    return msg.n

@_builtin_deserializer(name="float")
def _float(msg: MetricComponentMessage) -> float:
    """Decode the floating-point ``d`` field of a component message.

    The decorator registers this in ``_TYPED_DESERIALIZERS`` under its return
    type (``float``).

    Raises:
        ValueError: if ``msg.WhichOneof("value")`` is anything other than ``"d"``.
    """
    field = msg.WhichOneof("value")
    if field != "d":
        raise ValueError(f"Unsupported field: {field}")
    return msg.d

@_builtin_deserializer(name="kll")
def _kll(msg: MetricComponentMessage) -> ds.kll_doubles_sketch:
    """Decode the ``kll`` field of a component message into a KLL doubles sketch.

    The decorator registers this in ``_TYPED_DESERIALIZERS`` under
    ``ds.kll_doubles_sketch``.

    Raises:
        ValueError: if ``msg.WhichOneof("value")`` is anything other than ``"kll"``.
    """
    field = msg.WhichOneof("value")
    if field != "kll":
        raise ValueError(f"Unsupported field: {field}")
    return ds.kll_doubles_sketch.deserialize(msg.kll.sketch)

@_builtin_deserializer(name="hll")
def _hll(msg: MetricComponentMessage) -> ds.hll_sketch:
    """Decode the ``hll`` field of a component message into an HLL sketch.

    The decorator registers this in ``_TYPED_DESERIALIZERS`` under
    ``ds.hll_sketch``.

    Raises:
        ValueError: if ``msg.WhichOneof("value")`` is anything other than ``"hll"``.
    """
    field = msg.WhichOneof("value")
    if field != "hll":
        raise ValueError(f"Unsupported field: {field}")
    return ds.hll_sketch.deserialize(msg.hll.sketch)

@_builtin_deserializer(name="fs")
def _fs(msg: MetricComponentMessage) -> ds.frequent_strings_sketch:
    """Decode the ``frequent_items`` field into a frequent-strings sketch.

    The decorator registers this in ``_TYPED_DESERIALIZERS`` under
    ``ds.frequent_strings_sketch``.

    Raises:
        ValueError: if ``msg.WhichOneof("value")`` is anything other than
            ``"frequent_items"``.
    """
    field = msg.WhichOneof("value")
    if field != "frequent_items":
        raise ValueError(f"Unsupported field: {field}")
    return ds.frequent_strings_sketch.deserialize(msg.frequent_items.sketch)

class DeserializerRegistry:
    """Lookup table mapping metric types and numeric type ids to deserializers.

    On construction the registry takes shallow copies of the module-level
    ``_TYPED_DESERIALIZERS`` and ``_ID_DESERIALIZERS`` tables. Entries added to
    one registry's tables therefore do not leak into the module-level ones.
    """

    def __init__(self) -> None:
        self._typed_deserializer = dict(_TYPED_DESERIALIZERS)
        self._id_deserializer = dict(_ID_DESERIALIZERS)

    def get(self, *, mtype: Optional[type] = None, type_id: int = 0) -> Optional[_Deserializer]:
        """Find a deserializer by numeric id and/or metric type.

        A positive ``type_id`` is tried first. If that yields nothing, the
        lookup falls back to ``mtype`` when one was given.

        Returns:
            The matching deserializer, or ``None`` if neither key matches.

        Raises:
            ValueError: if ``mtype`` is None and ``type_id`` is not positive.
        """
        if mtype is None and type_id <= 0:
            raise ValueError("Either mtype or a positive ID must be specified")

        by_id = self._id_deserializer.get(type_id) if type_id > 0 else None
        if by_id is not None or mtype is None:
            return by_id
        return self._typed_deserializer.get(mtype)


# Process-wide default registry. It is created on first use, so it copies the
# builtin deserializer tables as they stand after module import.
_STANDARD_REGISTRY: Optional[DeserializerRegistry] = None


def _get_or_create_registry() -> DeserializerRegistry:
    """Return the shared default registry, creating it on the first call.

    ``global`` is required here: without it, the assignment below makes
    ``_STANDARD_REGISTRY`` a local name. The preceding ``is None`` check would
    then raise UnboundLocalError.
    """
    global _STANDARD_REGISTRY
    if _STANDARD_REGISTRY is None:
        _STANDARD_REGISTRY = DeserializerRegistry()
    return _STANDARD_REGISTRY


def get_deserializer(
    *, mtype: Optional[type] = None, type_id: int = 0, registry: Optional[DeserializerRegistry] = None
) -> Optional[_Deserializer]:
    """Look up a deserializer by metric type and/or numeric type id.

    Args:
        mtype: the metric component type to look up.
        type_id: a numeric id; a positive value is tried before ``mtype``.
        registry: registry to search. Defaults to the one returned by
            ``_get_or_create_registry()``.

    Returns:
        The matching deserializer, or ``None`` if none is found.

    Raises:
        ValueError: propagated from ``DeserializerRegistry.get`` when ``mtype``
            is None and ``type_id`` is not positive.
    """
    target = _get_or_create_registry() if registry is None else registry
    return target.get(mtype=mtype, type_id=type_id)

def deserializer(*, type_id: int, registry: Optional[DeserializerRegistry] = None):  # type: ignore
    if type_id < _MAX_BUILT_IN_ID:
        raise ValueError("Custom aggregator ID must be equal or greater than 100")
    if registry is None:
        registry = _get_or_create_registry()

    return _decorate_func(
        key=type_id, name=f"custom.{type_id}", wrapper_dict=registry._id_deserializer, clazz=_Deserializer