IBM/pytorchpipe

View on GitHub
ptp/components/masking/join_masked_predictions.py

Summary

Maintainability
A
2 hrs
Test Coverage
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Tomasz Kornuta"

import torch
import numpy as np

from ptp.components.component import Component
from ptp.data_types.data_definition import DataDefinition


class JoinMaskedPredictions(Component):
    """
    Class responsible joining several prediction streams using the associated masks.
    Additionally, it returns the associated string indices.

    .. warning:
        As performed operations are not differentiable, the returned  'output_indices' cannot be used for e.g. calculation of loss!!

    """

    def __init__(self, name, config):
        """
        Initializes the object. Loads keys, word mappings and vocabularies.

        :param name: Name of the component read from the configuration file
        :type name: str

        :param config: Dictionary of parameters (read from the configuration ``.yaml`` file).
        :type config: :py:class:`ptp.configuration.ConfigInterface`

        """
        # Call constructors of parent classes.
        Component.__init__(self, name, JoinMaskedPredictions, config)

        # Get input key mappings.
        # Load list of prediction streams names (keys).
        self.input_prediction_stream_keys = self.config["input_prediction_streams"]
        if type(self.input_prediction_stream_keys) == str:
            self.input_prediction_stream_keys = self.input_prediction_stream_keys.replace(" ", "").split(",")
        #assert(self.input_prediction_stream_keys != ""), "ooo"

        # Load list of mask streams names (keys).
        self.input_mask_stream_keys = self.config["input_mask_streams"]
        if type(self.input_mask_stream_keys) == str:
            self.input_mask_stream_keys = self.input_mask_stream_keys.replace(" ", "").split(",")

        # Load list of word mappings names (keys).
        input_word_mappings_keys = self.config["input_word_mappings"]
        if type(input_word_mappings_keys) == str:
            input_word_mappings_keys = input_word_mappings_keys.replace(" ", "").split(",")

        # Retrieve input word mappings from globals.
        self.input_ix_to_word = []
        for wmk in input_word_mappings_keys:
            # Get word mappings.
            word_to_ix = self.globals[wmk]
            # Create inverse transformation.
            ix_to_word = {value: key for (key, value) in word_to_ix.items()}
            self.input_ix_to_word.append(ix_to_word)


        # Get output key mappings.
        self.key_output_indices = self.stream_keys["output_indices"]
        self.key_output_strings = self.stream_keys["output_strings"]

        # Retrieve output word mappings from globals.
        self.output_word_to_ix = self.globals["output_word_mappings"]


    def input_data_definitions(self):
        """ 
        Function returns a dictionary with definitions of input data that are required by the component.

        :return: dictionary containing input data definitions (each of type :py:class:`ptp.data_types.DataDefinition`).
        """
        input_defs = {}
        # Add input prediction streams.
        for i, ipsk in enumerate(self.input_prediction_stream_keys):
            # Use input prediction stream key along with the length of the associated word mappings (i.e. size of the vocabulary = NUM_CLASSES)
            input_defs[ipsk] = DataDefinition([-1, len(self.input_ix_to_word[i])], [torch.Tensor], "Batch of predictions, represented as tensor with probability distribution over classes [BATCH_SIZE x NUM_CLASSES]")
        # Add mask streams.
        for imsk in self.input_mask_stream_keys:
            # Every mask has the same definition, but different stream key.
            input_defs[imsk] = DataDefinition([-1], [torch.Tensor], "Batch of masks [BATCH_SIZE]")

        return input_defs

    def output_data_definitions(self):
        """ 
        Function returns a empty dictionary with definitions of output data produced the component.

        :return: Empty dictionary.
        """
        return {
            self.key_output_indices: DataDefinition([-1], [torch.Tensor], "Batch of merged (output) indices [BATCH_SIZE]"),
            self.key_output_strings: DataDefinition([-1, 1], [list, str], "Batch of merged strings, corresponging to indices when using the provided word mappings [BATCH_SIZE] x [STRING]")
            }


    def __call__(self, data_streams):
        """
        Encodes "inputs" in the format of a single tensor.
        Stores reshaped tensor in "outputs" field of in data_streams.

        :param data_streams: :py:class:`ptp.utils.DataStreams` object containing (among others):

            - "inputs": expected input field containing tensor [BATCH_SIZE x ...]

            - "outputs": added output field containing tensor [BATCH_SIZE x ...] 
        """
        # Get inputs masks.
        masks = []
        for imsk in self.input_mask_stream_keys:
            # Get mask from stream.
            mask = data_streams[imsk]
            masks.append(mask.data.cpu().numpy())
        
        # Sum all masks and make sure that they are complementary.
        masks_sum = np.sum(masks, axis=0)
        batch_size = masks_sum.shape[0]
        sum_ones = sum(filter(lambda x: x == 1, masks_sum))
        if sum_ones != batch_size:
            self.logger.error("Masks received from the {} streams are not complementary!".format(self.input_mask_stream_keys))
            exit(-1)

        # Create mapping indicating from which input prediction/mask/dictionary stream we will take data associated with given "sample".
        weights = np.array(range(len(masks)))
        masks = np.array(masks).transpose()
        mapping = np.dot(masks, weights)
        #print("Mapping = \n",mapping)

        # "Translate". 
        output_answers = []
        output_indices = []
        output_predictions_lst = []
        # Iterate through samples.
        for sample in range(batch_size):
            # Get the right dictionary.
            ix_to_word = self.input_ix_to_word[mapping[sample]]
            #print(ix_to_word)

            # Get the right sample from the right prediction stream.
            sample_prediction = data_streams[self.input_prediction_stream_keys[mapping[sample]]][sample]
            #print(sample_prediction)
            output_predictions_lst.append(sample_prediction)

            # Get the index of max log-probabilities.
            index = sample_prediction.max(0)[1].data.cpu().item()
            #print(index)
            
            # Get the right word.
            word = ix_to_word[index]
            output_answers.append(word)

            # Get original index using output dictionary.
            output_indices.append(self.output_word_to_ix[word])

        #print(output_predictions_lst)
        #targets = data_streams["targets"].data.cpu().numpy()
        #print("targets = \n",targets.tolist())
        #print("joined answers = \n",output_indices)

        # Change to tensor.
        output_indices_tensor = torch.tensor(output_indices)

        # Extend the dict by returned output streams.
        data_streams.publish({
            self.key_output_indices: output_indices_tensor,
            self.key_output_strings: output_answers
            })