IBM/pytorchpipe
ptp/components/statistics/precision_recall_statistics.py

# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Tomasz Kornuta"

import torch
import numpy as np
import math

from ptp.components.component import Component
from ptp.data_types.data_definition import DataDefinition


class PrecisionRecallStatistics(Component):
    """
    Component collecting precision, recall and F1 score statistics
    (along with the confusion matrix and per-class support).

    """

    def __init__(self, name, config):
        """
        Initializes object.

        :param name: Name of the component.
        :type name: str

        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
        :type config: :py:class:`ptp.configuration.ConfigInterface`

        """
        # Call constructors of parent classes.
        Component.__init__(self, name, PrecisionRecallStatistics, config)

        # Get stream key mappings.
        self.key_targets = self.stream_keys["targets"]
        self.key_predictions = self.stream_keys["predictions"]
        self.key_masks = self.stream_keys["masks"]

        # Get prediction distributions/indices flag.
        self.use_prediction_distributions = self.config["use_prediction_distributions"]

        # Get masking flag.
        self.use_masking = self.config["use_masking"]

        # Get statistics key mappings.
        self.key_precision = self.statistics_keys["precision"]
        self.key_recall = self.statistics_keys["recall"]
        self.key_f1score = self.statistics_keys["f1score"]

        
        # Get (or create) vocabulary.
        if self.config["use_word_mappings"]:
            # Get labels from word mappings.
            self.labels = []
            self.index_mappings = {}
            # Assume they are ordered, starting from 0.
            for i,(word,index) in enumerate(self.globals["word_mappings"].items()):
                self.labels.append(word)
                self.index_mappings[index] = i
            # Set number of classes by looking at labels.
            self.num_classes = len(self.labels)
        else:
            # Get the number of possible outputs.
            self.num_classes = self.globals["num_classes"]
            self.labels = list(range(self.num_classes))
            self.index_mappings = {i: i for i in range(self.num_classes)}
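
        # Example (derived from the logic above): word mappings {'yes': 0, 'no': 1}
        # yield labels ['yes', 'no'] and index_mappings {0: 0, 1: 1}; without word
        # mappings and num_classes = 2, the labels are simply [0, 1].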

        # Check display options.
        self.show_confusion_matrix = self.config["show_confusion_matrix"]
        self.show_class_scores = self.config["show_class_scores"]

    def input_data_definitions(self):
        """ 
        Function returns a dictionary with definitions of input data that are required by the component.

        :return: dictionary containing input data definitions (each of type :py:class:`ptp.data_types.data_definition.DataDefinition`).
        """
        input_defs = {
            self.key_targets: DataDefinition([-1], [torch.Tensor], "Batch of targets, each being a single index [BATCH_SIZE]")
            }

        if self.use_prediction_distributions:
            input_defs[self.key_predictions] = DataDefinition([-1, -1], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]")
        else: 
            input_defs[self.key_predictions] = DataDefinition([-1], [torch.Tensor], "Batch of predictions, represented as tensor with indices of predicted answers [BATCH_SIZE]")

        if self.use_masking:
            input_defs[self.key_masks] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]")
        return input_defs

    def output_data_definitions(self):
        """ 
        Function returns an empty dictionary with definitions of output data produced by the component.

        :return: Empty dictionary.
        """
        return {}

    def __call__(self, data_streams):
        """
        Calculates precision, recall and F1 score statistics and logs them.

        :param data_streams: DataStreams containing targets and predictions.
        :type data_streams: DataStreams

        """
        # Log statistics only at the logging interval.
        if self.app_state.episode % self.app_state.args.logging_interval == 0:

            # Calculate the confusion matrix together with precision, recall, F1 score and support statistics.
            confusion_matrix, precisions, recalls, f1scores, supports = self.calculate_statistics(data_streams)

            if self.show_confusion_matrix:
                self.logger.info("Confusion matrix:\n{}".format(confusion_matrix))

            # Calculate weighted averages.
            support_sum = sum(supports)
            if support_sum > 0:
                precision_avg = sum([pi*si for (pi,si) in zip(precisions,supports)]) / support_sum 
                recall_avg = sum([ri*si for (ri,si) in zip(recalls,supports)]) / support_sum
                f1score_avg = sum([fi*si for (fi,si) in zip(f1scores,supports)]) / support_sum
            else:
                precision_avg = 0
                recall_avg = 0
                f1score_avg = 0
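
            # Worked example (illustrative): with per-class precisions [0.6667, 1.0]
            # and supports [2, 2], the weighted precision is
            # (0.6667*2 + 1.0*2) / (2 + 2) ≈ 0.8333.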

            # Log class scores.
            if self.show_class_scores:
                log_str = "\n| Precision | Recall | F1Score | Support | Label\n"
                log_str+= "|-----------|--------|---------|---------|-------\n"
                for i in range(self.num_classes):
                    log_str += "|    {:05.4f} | {:05.4f} |  {:05.4f} |   {:5d} | {}\n".format(
                        precisions[i], recalls[i], f1scores[i], supports[i], self.labels[i])
                log_str+= "|-----------|--------|---------|---------|-------\n"
                log_str += "|    {:05.4f} | {:05.4f} |  {:05.4f} |   {:5d} | Weighted Avg\n".format(
                        precision_avg, recall_avg, f1score_avg, support_sum)
                self.logger.info(log_str)
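
        # Example of the logged table (illustrative values; cf. the worked
        # example in calculate_statistics() below):
        # | Precision | Recall | F1Score | Support | Label
        # |-----------|--------|---------|---------|-------
        # |    0.6667 | 1.0000 |  0.8000 |       2 | yes
        # |    1.0000 | 0.5000 |  0.6667 |       2 | no
        # |-----------|--------|---------|---------|-------
        # |    0.8333 | 0.7500 |  0.7333 |       4 | Weighted Avg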


    def calculate_statistics(self, data_streams):
        """
        Calculates confusion matrix, precision, recall, F1 score and support statistics.

        :param data_streams: DataStreams containing targets and predictions.
        :type data_streams: DataStreams

        :return: Calculated statistics.
        """
        targets = data_streams[self.key_targets].data.cpu().numpy()
        #print("Targets :", targets)

        if self.use_prediction_distributions:
            # Get indices of the max log-probability.
            preds = data_streams[self.key_predictions].max(1)[1].data.cpu().numpy()
        else: 
            preds = data_streams[self.key_predictions].data.cpu().numpy()
        #print("Predictions :", preds)

        if self.use_masking:
            # Get masks from inputs.
            masks = data_streams[self.key_masks].data.cpu().numpy()
        else:
            # Create vector full of ones.
            masks = np.ones(targets.shape[0])

        # Create the confusion matrix, using the scikit-learn convention:
        # row - target (actual) class, column - predicted class.
        confusion_matrix = np.zeros([self.num_classes, self.num_classes], dtype=int)
        for i, (target, pred) in enumerate(zip(targets, preds)):
            # Count the sample only if both indices are valid;
            # the mask zeroes out contributions of masked samples.
            if target in self.index_mappings and pred in self.index_mappings:
                confusion_matrix[self.index_mappings[target]][self.index_mappings[pred]] += int(masks[i])
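
        # Illustrative example: with targets [0, 0, 1, 1] and predictions
        # [0, 0, 0, 1] (two classes, all masks equal to 1), the matrix becomes
        #     [[2, 0],
        #      [1, 1]]
        # i.e. both class-0 samples are hit, while one class-1 sample is
        # mislabelled as class 0.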

        # Calculate true positives (TP), eqv. with hits:
        # the diagonal of the confusion matrix.
        tp = np.diagonal(confusion_matrix)

        # Calculate false positives (FP), eqv. with false alarms (Type I errors):
        # predictions incorrectly labelled as belonging to a given class.
        # Sum wrong predictions along the column.
        fp = np.sum(confusion_matrix, axis=0) - tp

        # Calculate false negatives (FN), eqv. with misses (Type II errors):
        # the target belonged to a given class, but it wasn't correctly labelled.
        # Sum wrong predictions along the row.
        fn = np.sum(confusion_matrix, axis=1) - tp

        # Precision is the fraction of events where we correctly declared i
        # out of all instances where the algorithm declared i.
        precisions = [float(tpi) / float(tpi+fpi) if (tpi+fpi) > 0 else 0.0 for (tpi,fpi) in zip(tp,fp)]

        # Recall is the fraction of events where we correctly declared i 
        # out of all of the cases where the true state of the world is i.
        recalls = [float(tpi) / float(tpi+fni) if (tpi+fni) > 0 else 0.0 for (tpi,fni) in zip(tp,fn)]

        # Calculate F1 scores.
        f1scores = [ 2 * pi * ri / float(pi+ri) if (pi+ri) > 0 else 0.0 for (pi,ri) in zip(precisions,recalls)]
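
        # Continuing the example above: tp = [2, 1], fp = [1, 0], fn = [0, 1],
        # so precisions = [2/3, 1.0], recalls = [1.0, 0.5] and
        # f1scores = [2*(2/3)*1.0/(2/3+1.0), 2*1.0*0.5/(1.0+0.5)] = [0.8, 0.6667].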

        # Get support.
        supports = np.sum(confusion_matrix, axis=1)


        return confusion_matrix, precisions, recalls, f1scores, supports
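
    # Note: calculate_statistics() can be cross-checked against scikit-learn
    # (if it is available in the environment), which computes the same
    # per-class values:
    #   from sklearn.metrics import precision_recall_fscore_support
    #   p, r, f, s = precision_recall_fscore_support(targets, preds, average=None)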


    def add_statistics(self, stat_col):
        """
        Adds 'precision', 'recall' and 'f1score' statistics to ``StatisticsCollector``.

        :param stat_col: ``StatisticsCollector``.

        """
        # Those will be displayed.
        stat_col.add_statistics(self.key_precision, '{:05.4f}')
        stat_col.add_statistics(self.key_recall, '{:05.4f}')
        stat_col.add_statistics(self.key_f1score, '{:05.4f}')
        # That one will be collected and used by aggregator.
        stat_col.add_statistics(self.key_f1score+'_support', None)


    def collect_statistics(self, stat_col, data_streams):
        """
        Collects statistics (support-weighted averages of precision, recall and F1 score) for a given episode.

        :param stat_col: ``StatisticsCollector``.

        :param data_streams: DataStreams containing targets and predictions.

        """
        # Calculate per-class precision, recall, F1 score and support statistics.
        _, precisions, recalls, f1scores, supports = self.calculate_statistics(data_streams)

        # Calculate weighted averages.
        precision_sum = sum([pi*si for (pi,si) in zip(precisions,supports)])
        recall_sum = sum([ri*si for (ri,si) in zip(recalls,supports)])
        f1score_sum = sum([fi*si for (fi,si) in zip(f1scores,supports)])
        support_sum = sum(supports)

        if support_sum > 0:
            precision_avg = precision_sum / support_sum 
            recall_avg = recall_sum / support_sum
            f1score_avg = f1score_sum / support_sum
        else:
            precision_avg = 0
            recall_avg = 0
            f1score_avg = 0

        # Export averages to statistics.
        stat_col[self.key_precision] = precision_avg
        stat_col[self.key_recall] = recall_avg
        stat_col[self.key_f1score] = f1score_avg

        # Export support to statistics.
        stat_col[self.key_f1score+'_support'] = support_sum
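
        # Note: the per-episode averages and supports collected above are later
        # consumed by aggregate_statistics(), which re-weights them by support
        # when aggregating over episodes.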



    def add_aggregators(self, stat_agg):
        """
        Adds aggregators of support-weighted means and standard deviations of precision, recall and F1 score.

        :param stat_agg: ``StatisticsAggregator``.

        """
        stat_agg.add_aggregator(self.key_precision, '{:05.4f}') 
        stat_agg.add_aggregator(self.key_precision+'_std', '{:05.4f}')
        stat_agg.add_aggregator(self.key_recall, '{:05.4f}') 
        stat_agg.add_aggregator(self.key_recall+'_std', '{:05.4f}')
        stat_agg.add_aggregator(self.key_f1score, '{:05.4f}') 
        stat_agg.add_aggregator(self.key_f1score+'_std', '{:05.4f}')


    def aggregate_statistics(self, stat_col, stat_agg):
        """
        Aggregates the statistics collected from all episodes, computing support-weighted means and standard deviations.

        :param stat_col: ``StatisticsCollector``

        :param stat_agg: ``StatisticsAggregator``

        """
        # Retrieve lists of per-episode weighted averages and supports.
        precisions = np.asarray(stat_col[self.key_precision])
        recalls = np.asarray(stat_col[self.key_recall])
        f1scores = np.asarray(stat_col[self.key_f1score])
        supports = stat_col[self.key_f1score+'_support']

        # Special case - no samples!
        if sum(supports) == 0:
            stat_agg[self.key_precision] = 0
            stat_agg[self.key_precision+'_std'] = 0
            stat_agg[self.key_recall] = 0
            stat_agg[self.key_recall+'_std'] = 0
            stat_agg[self.key_f1score] = 0
            stat_agg[self.key_f1score+'_std'] = 0

        else: 
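            # The aggregation below uses a support-weighted mean and standard deviation:
            #   mu = sum_i(s_i * x_i) / sum_i(s_i)
            #   sigma = sqrt( sum_i(s_i * (x_i - mu)^2) / sum_i(s_i) )
            # where x_i is a per-episode average and s_i its support.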
            # Else: calculate the support-weighted precision.
            precisions_avg = np.average(precisions, weights=supports)
            precisions_var = np.average((precisions-precisions_avg)**2, weights=supports)
            
            stat_agg[self.key_precision] = precisions_avg
            stat_agg[self.key_precision+'_std'] = math.sqrt(precisions_var)

            # Calculate weighted recall.
            recalls_avg = np.average(recalls, weights=supports)
            recalls_var = np.average((recalls-recalls_avg)**2, weights=supports)

            stat_agg[self.key_recall] = recalls_avg
            stat_agg[self.key_recall+'_std'] = math.sqrt(recalls_var)

            # Calculate weighted f1 score.
            f1scores_avg = np.average(f1scores, weights=supports)
            f1scores_var = np.average((f1scores-f1scores_avg)**2, weights=supports)

            stat_agg[self.key_f1score] = f1scores_avg
            stat_agg[self.key_f1score+'_std'] = math.sqrt(f1scores_var)
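

# -----------------------------------------------------------------------------
# Minimal sanity check (illustrative sketch only, not part of the component):
# recomputes the per-class statistics from calculate_statistics() with plain
# numpy on a toy batch, so the formulas can be verified in isolation.
if __name__ == '__main__':
    toy_targets = np.array([0, 0, 1, 1])
    toy_preds = np.array([0, 0, 0, 1])
    toy_num_classes = 2

    # Build the confusion matrix (rows: targets, columns: predictions).
    cm = np.zeros([toy_num_classes, toy_num_classes], dtype=int)
    for t, p in zip(toy_targets, toy_preds):
        cm[t][p] += 1

    # Per-class true positives, false positives and false negatives.
    tp = np.diagonal(cm)
    fp = np.sum(cm, axis=0) - tp
    fn = np.sum(cm, axis=1) - tp

    precisions = [tpi/(tpi+fpi) if tpi+fpi > 0 else 0.0 for tpi, fpi in zip(tp, fp)]
    recalls = [tpi/(tpi+fni) if tpi+fni > 0 else 0.0 for tpi, fni in zip(tp, fn)]

    # Expected output: precisions ≈ [0.667, 1.0], recalls ≈ [1.0, 0.5].
    print("precisions:", precisions, "recalls:", recalls)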