whylabs/whylogs-python
python/whylogs/api/logger/segment_processing.py

import logging
import math
from functools import reduce
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

from whylogs.api.logger.result_set import SegmentedResultSet
from whylogs.api.logger.segment_cache import SegmentCache
from whylogs.core import DatasetSchema
from whylogs.core.dataset_profile import DatasetProfile
from whylogs.core.input_resolver import _pandas_or_dict
from whylogs.core.segment import Segment
from whylogs.core.segmentation_partition import (
    ColumnMapperFunction,
    SegmentationPartition,
    SegmentFilter,
)
from whylogs.core.stubs import pd

logger = logging.getLogger(__name__)

_MAX_SEGMENT_PARTITIONS = 10


def _process_segment(
    segmented_data: Any,
    segment_key: Segment,
    segments: Dict[Segment, Any],
    schema: DatasetSchema,
    segment_cache: Optional[SegmentCache] = None,
):
    profile = None
    if segment_cache:
        profile = segment_cache.get_or_create_matching_profile(segment_key)

    if profile is None:
        profile = DatasetProfile(schema)

    profile.track(segmented_data, execute_udfs=False)
    segments[segment_key] = profile

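# Illustrative sketch, not part of the library: tracking a row dict under an
# explicit Segment key registers a fresh profile for that key. The column name
# "region" and partition id "p0" are hypothetical.
def _example_process_segment() -> None:
    segments: Dict[Segment, Any] = {}
    key = Segment(("US",), "p0")
    _process_segment({"region": "US"}, key, segments, DatasetSchema())
    assert key in segments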

def _get_segment_from_group_key(group_key, partition_id, explicit_keys: Tuple[str, ...] = ()) -> Segment:
    if isinstance(group_key, str):
        segment_tuple_key: Tuple[str, ...] = (group_key,)
    elif isinstance(group_key, (list, Iterable, Iterator)):
        segment_tuple_key = tuple(str(k) for k in group_key)
    else:
        segment_tuple_key = (str(group_key),)

    return Segment(segment_tuple_key + explicit_keys, partition_id)

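# Illustrative sketch, not part of the library: scalar groupby keys become
# one-element Segment keys, iterable keys are stringified element-wise, and
# explicit key values are appended. The partition id "p0" is hypothetical.
def _example_group_key_to_segment() -> None:
    assert _get_segment_from_group_key("US", "p0").key == ("US",)
    assert _get_segment_from_group_key(("US", 1), "p0").key == ("US", "1")
    assert _get_segment_from_group_key("US", "p0", ("gold",)).key == ("US", "gold")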

def _is_nan(x):
    try:
        return math.isnan(x)
    except TypeError:
        return False

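# Illustrative sketch, not part of the library: math.isnan raises TypeError on
# non-numeric input, which _is_nan deliberately treats as "not NaN".
def _example_is_nan() -> None:
    assert _is_nan(float("nan"))
    assert not _is_nan(3.14)
    assert not _is_nan("US")  # TypeError inside math.isnan, swallowed above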

def _process_simple_partition(
    partition_id: str,
    schema: DatasetSchema,
    segments: Dict[Segment, Any],
    columns: List[str],
    pandas: Optional[pd.DataFrame] = None,
    row: Optional[Mapping[str, Any]] = None,
    segment_cache: Optional[SegmentCache] = None,
    segment_key_values: Optional[Dict[str, str]] = None,
):
    explicit_keys = (
        tuple(str(segment_key_values[k]) for k in sorted(segment_key_values.keys())) if segment_key_values else tuple()
    )
    if pandas is not None:
        # simple means we can segment on column values
        grouped_data = pandas.groupby(columns)
        for group in grouped_data.groups.keys():
            if isinstance(group, tuple) and any(_is_nan(x) for x in group):
                # A group key containing NaN cannot be fetched reliably via
                # get_group, so build a boolean row mask instead: match NaN
                # entries with isna() and everything else by equality.
                evaluations = []
                for val, col in zip(group, columns):
                    if _is_nan(val):
                        evaluations.append(pandas[col].isna())
                    else:
                        evaluations.append(pandas[col] == val)
                mask = reduce(lambda x, y: x & y, evaluations)
                pandas_segment = pandas[mask]
            else:
                pandas_segment = grouped_data.get_group(group)
            segment_key = _get_segment_from_group_key(group, partition_id, explicit_keys)
            _process_segment(pandas_segment, segment_key, segments, schema, segment_cache)
    elif row:
        # TODO: consider if we need to combine with the column names
        segment_key = Segment(tuple(str(row[element]) for element in columns) + explicit_keys, partition_id)
        _process_segment(row, segment_key, segments, schema, segment_cache)

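# Illustrative sketch, not part of the library: rebuilds the NaN-tolerant row
# mask used above, assuming pandas is installed. Column names "a" and "b" are
# hypothetical; NaN values match via isna() and everything else by equality.
def _example_nan_group_mask() -> None:
    frame = pd.DataFrame({"a": ["x", "x", "y"], "b": [1.0, float("nan"), 2.0]})
    group = ("x", float("nan"))
    evaluations = [
        frame[col].isna() if _is_nan(val) else frame[col] == val
        for val, col in zip(group, ["a", "b"])
    ]
    mask = reduce(lambda x, y: x & y, evaluations)
    assert list(frame[mask].index) == [1]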

def _filter_inputs(
    filter: SegmentFilter, pandas: Optional[pd.DataFrame] = None, row: Optional[Mapping[str, Any]] = None
) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]:
    assert (
        filter.filter_function or filter.query_string
    ), f"must define at least a filter function or query string when specifying a segment filter: {filter}"
    filtered_pandas = None
    filtered_row = None
    if pandas is not None:
        if filter.filter_function:
            filtered_pandas = pandas[filter.filter_function]
        elif filter.query_string:
            filtered_pandas = pandas.query(filter.query_string)
    elif row is not None:
        if filter.filter_function:
            filtered_row = filter.filter_function(row)
        elif filter.query_string:
            raise ValueError(
                "SegmentFilter query strings are not supported when logging rows; "
                "either omit the filter or implement filter.filter_function instead."
            )
    return (filtered_pandas, filtered_row)

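# Illustrative sketch, not part of the library: pandas accepts a callable as an
# indexer, so filter_function is passed straight to frame[...], while the query
# string goes through DataFrame.query. Assumes SegmentFilter accepts these two
# keyword arguments and pandas is installed; the "score" column is hypothetical.
def _example_filter_inputs() -> None:
    frame = pd.DataFrame({"score": [1, 5, 9]})
    by_function, _ = _filter_inputs(SegmentFilter(filter_function=lambda df: df["score"] > 4), pandas=frame)
    by_query, _ = _filter_inputs(SegmentFilter(query_string="score > 4"), pandas=frame)
    assert len(by_function) == len(by_query) == 2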

def _grouped_dataframe(partition: SegmentationPartition, pandas: pd.DataFrame):
    columns = partition.mapper.col_names if partition.mapper else None
    if not columns:
        raise ValueError(
            "No columns are defined for segmentation, and custom ColumnMapperFunction mapping "
            "is not yet supported; please segment on columns instead."
        )
    grouped_data = pandas.groupby(columns)
    return grouped_data


def _log_segment(
    partition: SegmentationPartition,
    schema: DatasetSchema,
    obj: Any = None,
    pandas: Optional[pd.DataFrame] = None,
    row: Optional[Mapping[str, Any]] = None,
    segment_cache: Optional[SegmentCache] = None,
    segment_key_values: Optional[Dict[str, str]] = None,
) -> Dict[Segment, Any]:
    segments: Dict[Segment, Any] = {}
    pandas, row = _pandas_or_dict(obj, pandas, row)
    if partition.filter:
        pandas, row = _filter_inputs(partition.filter, pandas, row)
    if partition.simple:
        columns = partition.mapper.col_names if partition.mapper else None
        if columns:
            _process_simple_partition(
                partition.id, schema, segments, columns, pandas, row, segment_cache, segment_key_values
            )
        else:
            logger.error(
                "Segmented DatasetSchema defines no segments; use an unsegmented DatasetSchema or specify columns to segment on."
            )
    else:
        raise NotImplementedError("custom mapped segments not yet implemented")
    return segments

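# Illustrative sketch, not part of the library: assuming a column-only mapper
# (no custom mapping function) yields a simple partition, as in
# _result_set_for_segment_key_values below, _log_segment produces one profile
# per distinct column value. The "region" column is hypothetical and pandas is
# assumed to be installed.
def _example_log_segment() -> None:
    partition = SegmentationPartition(name="region", mapper=ColumnMapperFunction(col_names=["region"]))
    frame = pd.DataFrame({"region": ["US", "EU", "US"], "value": [1, 2, 3]})
    segments = _log_segment(partition, DatasetSchema(), pandas=frame)
    assert len(segments) == 2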

def segment_processing(
    schema: DatasetSchema,
    obj: Any = None,
    pandas: Optional[pd.DataFrame] = None,
    row: Optional[Dict[str, Any]] = None,
    segment_cache: Optional[SegmentCache] = None,
    segment_key_values: Optional[Dict[str, str]] = None,
) -> SegmentedResultSet:
    number_of_partitions = len(schema.segments)
    logger.info(f"The specified schema defines segments with {number_of_partitions} partitions.")
    assert not (
        number_of_partitions > _MAX_SEGMENT_PARTITIONS
    ), f"Attempt to process {number_of_partitions} partitions is larger than the max of {_MAX_SEGMENT_PARTITIONS}, use a lower number of partitions"
    segmented_profiles = dict()
    segment_partitions = list()

    for partition_name, segment_partition in schema.segments.items():
        if segment_partition.mapper and segment_key_values:
            segment_partition.mapper.set_explicit_names(segment_key_values.keys())

        logger.info(f"Processing partition with name({partition_name})")
        logger.debug(f"{partition_name}: is simple ({segment_partition.simple}), id ({segment_partition.id})")
        if segment_partition.filter:
            # TODO segments filter
            logger.debug(f"{partition_name}: defines filter ({segment_partition.filter})")
        if segment_partition.mapper:
            logger.debug(
                f"{partition_name}: defines mapper on colums ({segment_partition.mapper.col_names}) and id ({segment_partition.mapper.id})"
            )
        partition_segments = _log_segment(
            partition=segment_partition,
            schema=schema,
            obj=obj,
            pandas=pandas,
            row=row,
            segment_cache=segment_cache,
            segment_key_values=segment_key_values,
        )
        segmented_profiles[segment_partition.id] = partition_segments
        segment_partitions.append(segment_partition)
        logger.debug(f"Done profiling for partition with name({partition_name})")
    return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions)

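# Illustrative sketch, not part of the library: end-to-end segmented profiling
# over a hypothetical "region" column via the public segment_on_column helper,
# assuming pandas is installed.
def _example_segment_processing() -> None:
    from whylogs.core.segmentation_partition import segment_on_column

    schema = DatasetSchema(segments=segment_on_column("region"))
    frame = pd.DataFrame({"region": ["US", "EU"], "value": [1, 2]})
    results = segment_processing(schema, pandas=frame)
    assert isinstance(results, SegmentedResultSet)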

def _result_set_for_segment_key_values(segment_key_values: Dict[str, Any], profile) -> SegmentedResultSet:
    segment_keys = segment_key_values.keys()
    segment_values = segment_key_values.values()
    partition_name = ",".join(segment_keys)
    partition = SegmentationPartition(name=partition_name, mapper=ColumnMapperFunction(col_names=list(segment_keys)))
    segment_key = Segment(key=tuple(segment_values), parent_id=partition.id)
    partition_segments = {segment_key: profile}
    segmented_profiles = {partition.id: partition_segments}
    segment_partitions = [partition]

    return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions)
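
# Illustrative sketch, not part of the library: wrapping an existing profile
# under explicit key/value pairs yields a one-segment result set whose
# partition is named by joining the keys. The keys here are hypothetical.
def _example_result_set_for_key_values() -> None:
    profile = DatasetProfile(DatasetSchema())
    results = _result_set_for_segment_key_values({"region": "US", "tier": "gold"}, profile)
    assert isinstance(results, SegmentedResultSet)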