# source: python/whylogs/api/logger/result_set.py
from abc import ABC, abstractmethod
from datetime import datetime
from logging import getLogger
from typing import Any, Dict, List, Optional, Tuple, Union
from whylogs.api.reader import Reader, Readers
from whylogs.api.writer import Writer, Writers
from whylogs.api.writer.writer import Writable
from whylogs.core import DatasetProfile, DatasetProfileView, Segment
from whylogs.core.metrics.metrics import Metric
from whylogs.core.model_performance_metrics import ModelPerformanceMetrics
from whylogs.core.segmentation_partition import SegmentationPartition
from whylogs.core.utils import ensure_timezone
from whylogs.core.view.dataset_profile_view import _MODEL_PERFORMANCE
from whylogs.core.view.segmented_dataset_profile_view import SegmentedDatasetProfileView
logger = getLogger(__name__)
def _merge_metrics(
lhs_metrics: Optional[Dict[str, Any]], rhs_metrics: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
if not rhs_metrics:
return lhs_metrics
if not lhs_metrics:
return rhs_metrics
lhs_keys = lhs_metrics.keys()
rhs_keys = rhs_metrics.keys()
merged_metrics: Dict[str, Any] = dict()
lhs_only = lhs_keys - rhs_keys
rhs_only = rhs_keys - lhs_keys
intersection_keys = lhs_keys & rhs_keys
for key in lhs_only:
merged_metrics[key] = lhs_metrics[key]
for key in rhs_only:
merged_metrics[key] = rhs_metrics[key]
for key in intersection_keys:
merged_metrics[key] = lhs_metrics[key].merge(rhs_metrics[key])
return merged_metrics
def _accumulate_properties(acc: Optional[Dict[str, Any]], props: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if not props:
return acc
if acc is None:
acc = dict()
intersection_keys = acc.keys() & props.keys()
for key in intersection_keys:
if acc[key] != props[key]:
logger.warning(f"Merging result properties collision, using {key}:{acc[key]} and dropping {props[key]}.")
acc.update(props)
return acc
def _merge_segments(
    lhs_segments: Dict[Segment, Union[DatasetProfile, DatasetProfileView]],
    rhs_segments: Dict[Segment, Union[DatasetProfile, DatasetProfileView]],
) -> Dict[Segment, DatasetProfileView]:
    """Merge two segment->profile mappings into a segment->view mapping.

    Segments present on only one side are carried over (normalized to views);
    segments present on both sides have their views merged.
    """

    def _as_view(value: Union[DatasetProfile, DatasetProfileView]) -> DatasetProfileView:
        # Normalize a mutable profile to its view; pass views through untouched.
        return value if isinstance(value, DatasetProfileView) else value.view()

    lhs_keys = lhs_segments.keys()
    rhs_keys = rhs_segments.keys()
    merged_segments: Dict[Segment, DatasetProfileView] = dict()
    for key in lhs_keys - rhs_keys:
        merged_segments[key] = _as_view(lhs_segments[key])
    for key in rhs_keys - lhs_keys:
        merged_segments[key] = _as_view(rhs_segments[key])
    for key in lhs_keys & rhs_keys:
        # BUG FIX: the previous code read BOTH merge operands from lhs_segments,
        # so data from rhs_segments was silently dropped for shared segments.
        merged_segments[key] = _as_view(lhs_segments[key]).merge(_as_view(rhs_segments[key]))
    return merged_segments
def _merge_partitioned_segments(
    lhs_segments: Dict[str, Dict[Segment, DatasetProfile]], rhs_segments: Dict[str, Dict[Segment, DatasetProfile]]
) -> Dict[str, Dict[Segment, DatasetProfile]]:
    """Combine two partition-id -> segment-map dictionaries.

    Partitions unique to one side are carried over; partitions present in both
    have their segment maps combined via `_merge_segments`.
    """
    shared = lhs_segments.keys() & rhs_segments.keys()
    merged: Dict[str, Dict[Segment, DatasetProfile]] = dict()
    for partition_id, segments in lhs_segments.items():
        if partition_id not in shared:
            merged[partition_id] = segments
    for partition_id, segments in rhs_segments.items():
        if partition_id not in shared:
            merged[partition_id] = segments
    for partition_id in shared:
        merged[partition_id] = _merge_segments(lhs_segments[partition_id], rhs_segments[partition_id])
    return merged
def _merge_partitions(
lhs_partitions: List[SegmentationPartition], rhs_partitions: List[SegmentationPartition]
) -> List[SegmentationPartition]:
return list(set(lhs_partitions).union(set(rhs_partitions)))
class ResultSetWriter:
    """
    Binds a ResultSet to a Writer so each writable in the set can be persisted.
    """

    def __init__(self, results: "ResultSet", writer: Writer):
        self._result_set = results
        self._writer = writer

    def option(self, **kwargs: Any) -> "ResultSetWriter":
        # Fluent pass-through to the underlying writer's configuration.
        self._writer.option(**kwargs)
        return self

    def write(self, **kwargs: Any) -> List:
        # multi-profile writer
        statuses: List[Tuple[bool, str]] = list()
        files = self._result_set.get_writables()
        if not files:
            logger.warning("Attempt to write a result set with no writables returned, nothing written!")
            return statuses
        reference_name = getattr(self._writer, "_reference_profile_name", None)
        if reference_name is not None and isinstance(self._result_set, SegmentedResultSet):
            # a segmented reference profile name needs to have access to the complete result set
            statuses.append(self._writer.write(file=self._result_set, **kwargs))
            return statuses
        logger.debug(f"About to write {len(files)} files:")
        # TODO: special handling of large number of files, handle throttling
        for writable in files:
            statuses.append(self._writer.write(file=writable, **kwargs))
        logger.debug(f"Completed writing {len(files)} files!")
        return statuses
class ResultSetReader:
    """Thin fluent wrapper around a Reader that yields ResultSets."""

    def __init__(self, reader: Reader) -> None:
        self._reader = reader

    def option(self, **kwargs: Any) -> "ResultSetReader":
        # Forward configuration to the wrapped reader, keeping the chain fluent.
        self._reader.option(**kwargs)
        return self

    def read(self, **kwargs: Any) -> "ResultSet":
        return self._reader.read(**kwargs)
class ResultSet(ABC):
    """
    A holder object for profiling results.
    A whylogs.log call can result in more than one profile. This wrapper class
    simplifies the navigation among these profiles.
    Note that currently we only hold one profile but we're planning to add other
    kinds of profiles such as segmented profiles here.
    """

    @staticmethod
    def read(multi_profile_file: str) -> "ResultSet":
        # TODO: parse multiple profile
        return ViewResultSet(view=DatasetProfileView.read(multi_profile_file))

    @staticmethod
    def reader(name: str = "local") -> "ResultSetReader":
        return ResultSetReader(reader=Readers.get(name))

    def writer(self, name: str = "local") -> "ResultSetWriter":
        return ResultSetWriter(results=self, writer=Writers.get(name))

    @abstractmethod
    def view(self) -> Optional[DatasetProfileView]:
        pass

    @abstractmethod
    def profile(self) -> Optional[DatasetProfile]:
        pass

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        # Subclasses may define _metadata; its absence means "no metadata".
        return getattr(self, "_metadata", None)

    def get_writables(self) -> Optional[List[Writable]]:
        return [self.view()]

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        ensure_timezone(dataset_timestamp)
        profile = self.profile()
        if profile is None:
            raise ValueError("Cannot set timestamp on a result set without a profile!")
        profile.set_dataset_timestamp(dataset_timestamp)

    @property
    def count(self) -> int:
        # A plain result set contributes at most one profile.
        return 1 if self.view() is not None else 0

    @property
    def performance_metrics(self) -> Optional[ModelPerformanceMetrics]:
        profile = self.profile()
        return profile.model_performance_metrics if profile else None

    def add_model_performance_metrics(self, metrics: ModelPerformanceMetrics) -> None:
        profile = self.profile()
        if not profile:
            raise ValueError("Cannot add performance metrics to a result set with no profile!")
        profile.add_model_performance_metrics(metrics)

    def add_metric(self, name: str, metric: Metric) -> None:
        profile = self.profile()
        if not profile:
            raise ValueError(f"Cannot add {name} metric {metric} to a result set with no profile!")
        profile.add_dataset_metric(name, metric)

    def merge(self, other: "ResultSet") -> "ResultSet":
        raise NotImplementedError("This result set did not define merge, see ProfileResultSet or SegmentedResulSet.")
class ViewResultSet(ResultSet):
    """Result set backed by an already-materialized DatasetProfileView."""

    def __init__(self, view: DatasetProfileView) -> None:
        self._view = view

    def profile(self) -> Optional[DatasetProfile]:
        # A view cannot be turned back into a mutable profile.
        raise ValueError("No profile available. Can only view")

    def view(self) -> Optional[DatasetProfileView]:
        return self._view

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        view = self.view()
        return view.metadata if view else None

    @staticmethod
    def zero() -> "ViewResultSet":
        return ViewResultSet(DatasetProfileView.zero())

    def merge(self, other: "ResultSet") -> "ViewResultSet":
        if other is None:
            return self
        if not isinstance(other, (ViewResultSet, ProfileResultSet)):
            logger.warning(f"Merging potentially incompatible ViewResultSet and {type(other)}")
        lhs_view = self._view or DatasetProfileView.zero()
        return ViewResultSet(lhs_view.merge(other.view()))

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        ensure_timezone(dataset_timestamp)
        view = self.view()
        if view is None:
            raise ValueError("Cannot set timestamp on a view result set without a view!")
        view.dataset_timestamp = dataset_timestamp
class ProfileResultSet(ResultSet):
    """Result set backed by a single mutable DatasetProfile."""

    def __init__(self, profile: DatasetProfile) -> None:
        self._profile = profile

    def profile(self) -> Optional[DatasetProfile]:
        return self._profile

    def view(self) -> Optional[DatasetProfileView]:
        return self._profile.view()

    @property
    def metadata(self) -> Optional[Dict[str, str]]:
        view = self.view()
        return view.metadata if view else None

    @staticmethod
    def zero() -> "ProfileResultSet":
        return ProfileResultSet(DatasetProfile())

    def merge(self, other: "ResultSet") -> ViewResultSet:
        if other is None:
            return self
        if not isinstance(other, (ProfileResultSet, ViewResultSet)):
            logger.error(f"Merging potentially incompatible ProfileResultSet and {type(other)}")
        lhs_view = self.view() or DatasetProfileView()
        return ViewResultSet(lhs_view.merge(other.view()))
class SegmentedResultSet(ResultSet):
    """
    A result set holding profiles for one or more segments.

    Profiles are stored per partition id, then keyed by Segment. The set may
    also carry dataset-level metrics (e.g. model performance metrics) and
    dataset properties accumulated during logging.
    """

    def __init__(
        self,
        segments: Dict[str, Dict[Segment, Union[DatasetProfile, DatasetProfileView]]],
        partitions: Optional[List[SegmentationPartition]] = None,
        metrics: Optional[Dict[str, Any]] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._segments = segments
        self._partitions = partitions
        # Falsy inputs normalized to fresh dicts so later code can mutate them safely.
        self._metrics = metrics or dict()
        self._dataset_properties = properties or dict()
        self._metadata: Dict[str, str] = dict()

    def profile(self, segment: Optional[Segment] = None) -> Optional[Union[DatasetProfile, DatasetProfileView]]:
        """Return the profile for `segment`, or the sole profile when exactly one exists.

        Returns None when the set is empty or the requested segment is absent.

        Raises:
            ValueError: when no segment is specified and the set cannot be
                resolved to a single profile.
        """
        if not self._segments:
            return None
        elif segment:
            paritition_segments = self._segments.get(segment.parent_id)
            return paritition_segments.get(segment) if paritition_segments else None
        # special case return a single segment if there is only one, even if not specified
        elif len(self._segments) == 1:
            for partition_id in self._segments:
                segments = self._segments.get(partition_id)
                number_of_segments = len(segments) if segments else 0
                if number_of_segments == 1:
                    single_dictionary: Dict[Segment, Union[DatasetProfile, DatasetProfileView]] = (
                        segments if segments else dict()
                    )
                    # Returns the single value; the loop exits on the first key.
                    for key in single_dictionary:
                        return single_dictionary[key]
        raise ValueError(
            f"A profile was requested from a segmented result set without specifying which segment to return: {self._segments}"
        )

    @property
    def dataset_properties(self) -> Optional[Dict[str, Any]]:
        # Dataset-level properties accumulated for this result set.
        return self._dataset_properties

    @property
    def dataset_metrics(self) -> Optional[Dict[str, Any]]:
        # Dataset-level metrics keyed by metric name.
        return self._metrics

    @property
    def partitions(self) -> Optional[List[SegmentationPartition]]:
        return self._partitions

    def set_dataset_timestamp(self, dataset_timestamp: datetime) -> None:
        """Apply `dataset_timestamp` to every segmented profile in the set."""
        # TODO: pull dataset_timestamp up into a result set scoped property
        segment_keys = self.segments()
        if not segment_keys:
            return
        for key in segment_keys:
            profile = self.profile(segment=key)
            # NOTE(review): assumes each stored value supports
            # set_dataset_timestamp — confirm for DatasetProfileView entries.
            if profile:
                profile.set_dataset_timestamp(dataset_timestamp)

    def segments(self, restrict_to_parition_id: Optional[str] = None) -> Optional[List[Segment]]:
        """List all segment keys, optionally restricted to one partition id.

        Returns None when the result set holds no segments at all.
        """
        result: Optional[List[Segment]] = None
        if not self._segments:
            return result
        result = list()
        if restrict_to_parition_id:
            segments = self._segments.get(restrict_to_parition_id)
            if segments:
                for segment in segments:
                    result.append(segment)
        else:
            for partition_id in self._segments:
                for segment in self._segments[partition_id]:
                    result.append(segment)
        return result

    @property
    def count(self) -> int:
        # Total number of segment profiles across all partitions.
        result = 0
        if self._segments:
            for segment_key in self._segments:
                profiles = self._segments[segment_key]
                result += len(profiles)
        return result

    def segments_in_partition(
        self, partition: SegmentationPartition
    ) -> Optional[Dict[Segment, Union[DatasetProfile, DatasetProfileView]]]:
        """Return the segment->profile mapping for `partition`, or None if absent."""
        return self._segments.get(partition.id)

    def view(self, segment: Optional[Segment] = None) -> Optional[DatasetProfileView]:
        """Return the view for `segment` (same resolution rules as profile())."""
        result = self.profile(segment)
        view = result.view() if isinstance(result, DatasetProfile) else result
        return view

    def get_model_performance_metrics_for_segment(self, segment: Segment) -> Optional[ModelPerformanceMetrics]:
        """Return the model performance metrics stored on `segment`'s profile, if any."""
        if segment.parent_id in self._segments:
            profile = self._segments[segment.parent_id][segment]
            if not profile:
                logger.warning(
                    f"No profile found for segment {segment} when requesting model performance metrics, returning None!"
                )
                return None
            if isinstance(profile, DatasetProfileView):
                view = profile
            else:
                # Fall back to duck-typing: anything with a view() can be inspected.
                if hasattr(profile, "view"):
                    view = profile.view()
                else:
                    logger.error(
                        f"Unexpected type: {type(profile)} -> {profile}, cannot check for model performance metrics."
                    )
                    return None
            return view.model_performance_metrics
        return None

    def get_writables(self) -> Optional[List[Writable]]:
        """Build SegmentedDatasetProfileView writables for the first partition's segments.

        Returns None when the result set has no segments. Only the first
        partition is used; extra partitions are logged as an error and skipped.

        Raises:
            ValueError: when segments exist but no partitions are recorded.
        """
        results: Optional[List[Writable]] = None
        if self._segments:
            results = []
            logger.info(f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet.")
            # TODO: handle more than one partition
            if not self.partitions:
                raise ValueError(
                    f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet but no partitions found: {self.partitions}."
                )
            if len(self.partitions) > 1:
                logger.error(
                    f"Building list of: {self.count} SegmentedDatasetProfileViews in SegmentedResultSet but found more than one partition: "
                    f"{self.partitions}. Using first partition only!!"
                )
            first_partition = self.partitions[0]
            segments = self.segments_in_partition(first_partition)
            if segments:
                for segment_key in segments:
                    profile = segments[segment_key]
                    metric = self.get_model_performance_metrics_for_segment(segment_key)
                    if metric:
                        # NOTE(review): mutates the stored profile in place; assumes it
                        # supports add_model_performance_metrics even when it is a view — confirm.
                        profile.add_model_performance_metrics(metric)
                        logger.debug(
                            f"Found model performance metrics: {metric}, adding to segmented profile: {segment_key}."
                        )
                    view = profile.view() if isinstance(profile, DatasetProfile) else profile
                    segmented_profile = SegmentedDatasetProfileView(
                        profile_view=view, segment=segment_key, partition=first_partition
                    )
                    # Propagate result-set-level metadata onto each writable view.
                    if self.metadata:
                        segmented_profile.metadata.update(self.metadata)
                    results.append(segmented_profile)
            else:
                logger.warning(
                    f"Found no segments in partition: {first_partition} even though we have: {self.count} segments overall"
                )
            logger.info(f"From list of: {self.count} SegmentedDatasetProfileViews using {len(results)}")
        else:
            logger.warning(
                f"Attempt to build segmented results for writing but there are no segments in this result set: {self._segments}. returning None."
            )
        return results

    def add_metrics_for_segment(self, metrics: ModelPerformanceMetrics, segment: Segment) -> None:
        """Attach model performance metrics to `segment`'s profile (no-op for unknown partitions)."""
        if segment.parent_id in self._segments:
            profile = self._segments[segment.parent_id][segment]
            profile.add_model_performance_metrics(metrics)

    @staticmethod
    def zero() -> "SegmentedResultSet":
        """Return an empty SegmentedResultSet."""
        return SegmentedResultSet(segments=dict())

    @property
    def model_performance_metric(self) -> Optional[ModelPerformanceMetrics]:
        # Dataset-level model performance metrics, stored under a reserved key.
        if self._metrics:
            return self._metrics.get(_MODEL_PERFORMANCE)
        return None

    def add_model_performance_metrics(self, metrics: ModelPerformanceMetrics) -> None:
        """Store dataset-level model performance metrics under the reserved key."""
        if self._metrics:
            self._metrics[_MODEL_PERFORMANCE] = metrics
        else:
            self._metrics = {_MODEL_PERFORMANCE: metrics}

    def add_metric(self, name: str, metric: Metric) -> None:
        """Store a dataset-level metric under `name`."""
        if not self._metrics:
            self._metrics = dict()
        self._metrics[name] = metric

    def merge(self, other: "ResultSet") -> "SegmentedResultSet":
        """Merge with another SegmentedResultSet, combining segments, metrics,
        partitions and dataset properties into a new result set.

        Raises:
            ValueError: if `other` is not a SegmentedResultSet (and not None).
        """
        if other is None:
            return self
        if isinstance(other, SegmentedResultSet):
            lhs_partitions: List[SegmentationPartition] = self.partitions or list()
            rhs_partitions: List[SegmentationPartition] = other.partitions or list()
            lhs_segments: Dict[str, Dict[Segment, DatasetProfile]] = self._segments or dict()
            rhs_segments: Dict[str, Dict[Segment, DatasetProfile]] = other._segments or dict()
            merged_segments = _merge_partitioned_segments(lhs_segments, rhs_segments)
            merged_metrics = _merge_metrics(self.dataset_metrics, other.dataset_metrics)
            merged_partitions = _merge_partitions(lhs_partitions, rhs_partitions)
            properties = _accumulate_properties(self._dataset_properties, other.dataset_properties)
            return SegmentedResultSet(merged_segments, merged_partitions, metrics=merged_metrics, properties=properties)
        else:
            raise ValueError(f"Cannot merge incompatible SegmentedResultSet and {type(other)}")