whylabs/whylogs-python

View on GitHub
src/whylogs/core/metrics/confusion_matrix.py

Summary

Maintainability
A
1 hr
Test Coverage
from logging import getLogger
from typing import List, Union

import numpy as np

from whylogs.core.statistics import NumberTracker
from whylogs.proto import ScoreMatrixMessage
from whylogs.util.util_functions import encode_to_integers

# Kinds of classification targets named as supported; not referenced by the
# code visible in this module.
SUPPORTED_TYPES = ("binary", "multiclass")
# Hard limit: ConfusionMatrix.__init__ raises ValueError when given more labels than this.
MODEL_METRICS_MAX_LABELS = 256
# Soft limit: ConfusionMatrix.__init__ logs a warning when given more labels than this.
MODEL_METRICS_LABEL_SIZE_WARNING_THRESHOLD = 64

# Module logger, registered under the shared "whylogs" logger name.
_logger = getLogger("whylogs")


class ConfusionMatrix:
    """
    Confusion matrix that holds the labels and one score tracker per cell.

    Rows are indexed by the predicted label and columns by the target label
    (see ``add``). Each cell is a ``NumberTracker`` that tracks the scores of
    the (prediction, target) pairs that fell into it.

    Attributes:
        labels (List[str]): labels for the confusion_matrix axes, in the same
            order as the matrix rows/columns. Sorted when built through the
            constructor; kept in serialized order when built through
            ``from_protobuf``. None when no labels were supplied.
        prediction_field: name of the prediction field
        target_field: name of the target field
        score_field: name of the score field
        confusion_matrix (nd.array): Confusion Matrix kept as matrix of
            NumberTrackers, or None when no labels were supplied
    """

    def __init__(
        self,
        labels: List[str] = None,
        prediction_field: str = None,
        target_field: str = None,
        score_field: str = None,
    ):
        """
        Build an empty confusion matrix over ``labels``.

        Args:
            labels (List[str]): labels for the matrix axes; they are stored
                sorted. When None or empty, no matrix is allocated.
            prediction_field (str): name of the prediction field
            target_field (str): name of the target field
            score_field (str): name of the score field

        Raises:
            ValueError: when more than MODEL_METRICS_MAX_LABELS labels are given
        """
        self.prediction_field = prediction_field
        self.target_field = target_field
        self.score_field = score_field
        if labels:
            labels_size = len(labels)
            if labels_size > MODEL_METRICS_LABEL_SIZE_WARNING_THRESHOLD:
                _logger.warning(
                    f"The initialized confusion matrix has {labels_size} labels and the resulting"
                    " confusion matrix will be larger than is recommended with whylogs current"
                    " representation of the model metric for a confusion matrix of this size."
                )
            if labels_size > MODEL_METRICS_MAX_LABELS:
                # Every segment containing a placeholder needs its own f prefix;
                # adjacent literals are concatenated but not jointly formatted.
                raise ValueError(
                    f"The initialized confusion matrix has {labels_size} labels and the resulting"
                    " confusion matrix will be larger than is supported by whylogs current"
                    " representation of the model metric for a confusion matrix of this size,"
                    " selectively log the most important labels or configure the threshold of"
                    f" {MODEL_METRICS_MAX_LABELS} higher by setting MODEL_METRICS_MAX_LABELS."
                )

            self.labels = sorted(labels)
            num_labels = len(self.labels)
            self.confusion_matrix = np.empty([num_labels, num_labels], dtype=object)
            # Each cell gets its own tracker instance (no sharing between cells).
            for cell in np.ndindex(num_labels, num_labels):
                self.confusion_matrix[cell] = NumberTracker()
        else:
            self.labels = None
            self.confusion_matrix = None

    def add(
        self,
        predictions: List[Union[str, int, bool]],
        targets: List[Union[str, int, bool]],
        scores: List[float] = None,
    ):
        """
        Function adds predictions and targets to confusion matrix with scores.

        A non-list ``predictions`` or ``targets`` value is treated as a single
        observation. The cell at (prediction, target) tracks the matching score.

        Args:
            predictions (List[Union[str, int, bool]]): predicted labels
            targets (List[Union[str, int, bool]]): target labels
            scores (List[float]): one score per prediction; when None every
                pair is tracked with a score of 1.0. A single number is
                accepted for a single observation. Entries beyond the number
                of predictions are ignored.

        Raises:
            ValueError: in case targets and predictions differ in length, or
                fewer scores than predictions are given

        NOTE(review): labels are mapped to indexes by ``encode_to_integers``;
        how it reacts to labels missing from ``self.labels`` is not visible
        here — confirm against the helper.
        """
        if not isinstance(targets, list):
            targets = [targets]
        if not isinstance(predictions, list):
            predictions = [predictions]

        if len(targets) != len(predictions):
            raise ValueError("both targets and predictions need to have the same length")

        if scores is None:
            scores = [1.0] * len(targets)
        elif isinstance(scores, (int, float)):
            scores = [scores]

        # Fail before touching the matrix instead of part-way through the loop.
        if len(scores) < len(predictions):
            raise ValueError("scores need to have one entry for each prediction")

        targets_indx = encode_to_integers(targets, self.labels)
        prediction_indx = encode_to_integers(predictions, self.labels)

        for row, column, score in zip(prediction_indx, targets_indx, scores):
            self.confusion_matrix[row, column].track(score)

    def merge(self, other_cm):
        """
        Merge two separate confusion matrices which may or may not overlap in labels.

        When either side has no labels the other side is returned as is;
        otherwise a new matrix over the union of both label sets is built.
        Field names are taken from self, falling back to other_cm's.

        Args:
              other_cm (Optional[ConfusionMatrix]): confusion_matrix to merge with self
        Returns:
              ConfusionMatrix: merged confusion_matrix
        """
        # TODO: always return new objects
        if other_cm is None:
            return self
        if self.labels is None or self.labels == []:
            return other_cm
        if other_cm.labels is None or other_cm.labels == []:
            return self

        labels = list(set(self.labels).union(other_cm.labels))

        conf_matrix = ConfusionMatrix(labels)

        conf_matrix = _merge_CM(self, conf_matrix)
        conf_matrix = _merge_CM(other_cm, conf_matrix)

        # The freshly built matrix carries no field names; keep the ones the
        # inputs had so they survive serialization of the merged result.
        conf_matrix.prediction_field = self.prediction_field or other_cm.prediction_field
        conf_matrix.target_field = self.target_field or other_cm.target_field
        conf_matrix.score_field = self.score_field or other_cm.score_field

        return conf_matrix

    def to_protobuf(
        self,
    ):
        """
        Convert to protobuf

        Labels are written as strings and the matrix cells are flattened in
        row-major order. A matrix without labels yields a message with empty
        labels and scores.

        Returns:
            ScoreMatrixMessage: the serialized confusion matrix
        """
        labels = self.labels if self.labels is not None else []
        cells = np.ravel(self.confusion_matrix) if self.confusion_matrix is not None else []
        return ScoreMatrixMessage(
            labels=[str(label) for label in labels],
            prediction_field=self.prediction_field,
            target_field=self.target_field,
            score_field=self.score_field,
            # A falsy cell is serialized as a fresh, empty tracker.
            scores=[nt.to_protobuf() if nt else NumberTracker.to_protobuf(NumberTracker()) for nt in cells],
        )

    @classmethod
    def from_protobuf(
        cls,
        message: ScoreMatrixMessage,
    ):
        """
        Build a confusion matrix from its protobuf message.

        Args:
            message (ScoreMatrixMessage): serialized confusion matrix

        Returns:
            Optional[ConfusionMatrix]: None for an empty message, otherwise
            the deserialized matrix with labels in the serialized order
        """
        if message.ByteSize() == 0:
            return None
        labels = list(message.labels)
        num_labels = len(labels)
        matrix = np.array([NumberTracker.from_protobuf(score) for score in message.scores]).reshape((num_labels, num_labels)) if num_labels > 0 else None

        cm_instance = cls(
            labels=labels,
            prediction_field=message.prediction_field,
            target_field=message.target_field,
            score_field=message.score_field,
        )
        if num_labels > 0:
            # The constructor re-sorts the (string) labels, but the scores are
            # laid out in the serialized label order; keep that order so the
            # labels stay aligned with the matrix rows/columns.
            cm_instance.labels = labels
        cm_instance.confusion_matrix = matrix

        return cm_instance


def _merge_CM(old_conf_matrix: ConfusionMatrix, new_conf_matrix: ConfusionMatrix):
    """
    Merges the cells of one confusion matrix into another with distinct or overlapping labels

    Every label of old_conf_matrix must be present in new_conf_matrix. The
    returned matrix shares (and updates in place) the cell array of
    new_conf_matrix, and keeps new_conf_matrix's label order and field names.

    Args:
        old_conf_matrix (ConfusionMatrix): matrix whose cells are merged in
        new_conf_matrix (ConfusionMatrix): Will be overridden

    Returns:
        ConfusionMatrix: matrix over new_conf_matrix's labels holding the merged cells
    """
    # Position of each old label on the new matrix axes.
    new_indxes = encode_to_integers(old_conf_matrix.labels, new_conf_matrix.labels)
    # Position of each old label on its own axes (presumably the identity
    # mapping for unique labels; kept so lookups behave exactly as before).
    old_indxes = encode_to_integers(old_conf_matrix.labels, old_conf_matrix.labels)

    # Built without labels on purpose: the cell array is reused from
    # new_conf_matrix, so allocating a fresh grid of trackers would be wasted
    # work, and re-sorting the labels could misalign them with that array.
    res_conf_matrix = ConfusionMatrix(
        prediction_field=new_conf_matrix.prediction_field,
        target_field=new_conf_matrix.target_field,
        score_field=new_conf_matrix.score_field,
    )
    res_conf_matrix.labels = list(new_conf_matrix.labels)
    res_conf_matrix.confusion_matrix = new_conf_matrix.confusion_matrix

    cells = res_conf_matrix.confusion_matrix
    for old_row_idx, each_row_indx in enumerate(new_indxes):
        for old_column_idx, each_column_inx in enumerate(new_indxes):
            cells[each_row_indx, each_column_inx] = cells[each_row_indx, each_column_inx].merge(
                old_conf_matrix.confusion_matrix[old_indxes[old_row_idx], old_indxes[old_column_idx]]
            )

    return res_conf_matrix