IBM/pytorchpipe

ptp/components/models/multi_modal_reasoning/factorized_bilinear_pooling.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Deepta Rajan"


import torch
import torch.nn.functional as F

from ptp.components.models.model import Model
from ptp.data_types.data_definition import DataDefinition

class FactorizedBilinearPooling(Model):
    """
    Element of one of the classical baselines for Visual Question Answering.
    The multi-modal data is fused via sum-pooling of the element-wise multiplied high-dimensional representations and returned (for subsequent classification, done in a separate component e.g. ffn).

    On the basis of: Zhou Yu, Jun Yu. "Beyond Bilinear: Generalized Multi-modal Factorized High-order Pooling for Visual Question Answering" (2015).
    Code: https://github.com/Cadene/block.bootstrap.pytorch/blob/master/block/models/networks/fusions/fusions.py
    """
    def __init__(self, name, config):
        """
        Initializes the model, creates the required layers.

        :param name: Name of the model (taken from the configuration file).

        :param config: Parameters read from configuration file.
        :type config: ``ptp.configuration.ConfigInterface``

        """
        super(FactorizedBilinearPooling, self).__init__(name, FactorizedBilinearPooling, config)

        # Get key mappings.
        self.key_image_encodings = self.stream_keys["image_encodings"]
        self.key_question_encodings = self.stream_keys["question_encodings"]
        self.key_outputs = self.stream_keys["outputs"]

        # Retrieve input/output sizes from globals.
        self.image_encoding_size = self.globals["image_encoding_size"]
        self.question_encoding_size = self.globals["question_encoding_size"]

        # Get size of latent space and pooling factor from config.
        self.latent_size = self.config["latent_size"]
        self.factor = self.config["pool_factor"]
        # Output feature size
        self.output_size = self.latent_size

        # Export to globals.
        self.globals["output_size"] = self.output_size

        # Map image and question encodings to a common high-dimensional space of size 'latent_size' * 'pool_factor'.
        self.image_encodings_ff = torch.nn.Linear(self.image_encoding_size, self.latent_size*self.factor)
        self.question_encodings_ff = torch.nn.Linear(self.question_encoding_size, self.latent_size*self.factor)

        # Create activation layer.
        self.activation = torch.nn.ReLU()

        # Retrieve dropout rate value from config - dropout is applied to the projected encodings and to their element-wise product.
        dropout_rate = self.config["dropout_rate"]

        # Create dropout layer.
        self.dropout = torch.nn.Dropout(dropout_rate)


    def input_data_definitions(self):
        """
        Function returns a dictionary with definitions of input data that are required by the component.

        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
        """
        return {
            self.key_image_encodings: DataDefinition([-1, self.image_encoding_size], [torch.Tensor], "Batch of encoded images [BATCH_SIZE x IMAGE_ENCODING_SIZE]"),
            self.key_question_encodings: DataDefinition([-1, self.question_encoding_size], [torch.Tensor], "Batch of encoded questions [BATCH_SIZE x QUESTION_ENCODING_SIZE]"),
            }


    def output_data_definitions(self):
        """
        Function returns a dictionary with definitions of output data produced by the component.

        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
        """
        return {
            self.key_outputs: DataDefinition([-1, self.output_size], [torch.Tensor], "Batch of outputs [BATCH_SIZE x OUTPUT_SIZE]")
            }

    def forward(self, data_streams):
        """
        Main forward pass of the model.

        :param data_streams: DataStreams({'image_encodings', 'question_encodings', **})
        :type data_streams: ``ptp.data_types.DataStreams``
        """

        # Unpack DataStreams.
        enc_img = data_streams[self.key_image_encodings] # [BATCH_SIZE x IMAGE_ENCODING_SIZE]
        enc_q = data_streams[self.key_question_encodings] # [BATCH_SIZE x QUESTION_ENCODING_SIZE]

        # Map image and question encodings to the high-dimensional space using feed-forward layers.
        latent_img = self.dropout(self.image_encodings_ff(enc_img)) # [BATCH_SIZE x LATENT_SIZE*FACTOR]
        latent_q = self.dropout(self.question_encodings_ff(enc_q)) # [BATCH_SIZE x LATENT_SIZE*FACTOR]

        # Element-wise multiplication of image and question encodings.
        enc_z = latent_img * latent_q # [BATCH_SIZE x LATENT_SIZE*FACTOR]

        # Dropout regularization.
        enc_z = self.dropout(enc_z)
        # Group the activations into LATENT_SIZE groups of FACTOR elements each.
        enc_z = enc_z.view(enc_z.size(0), self.latent_size, self.factor) # [BATCH_SIZE x LATENT_SIZE x FACTOR]
        # Sum pooling over the factor dimension.
        enc_z = enc_z.sum(2) # [BATCH_SIZE x LATENT_SIZE]
        # Power normalization (signed square root), followed by L2 normalization.
        enc_z = torch.sqrt(self.activation(enc_z)) - torch.sqrt(self.activation(-enc_z))
        outputs = F.normalize(enc_z, p=2, dim=1) # [BATCH_SIZE x LATENT_SIZE]

        # Publish the fused outputs to data streams.
        data_streams.publish({self.key_outputs: outputs})
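

if __name__ == '__main__':
    # A minimal, self-contained sketch (not part of the component) reproducing the fusion above
    # on plain tensors; the sizes below are illustrative only and not taken from any configuration.
    batch_size, image_size, question_size, latent_size, factor = 48, 2048, 100, 256, 2
    image_ff = torch.nn.Linear(image_size, latent_size * factor)
    question_ff = torch.nn.Linear(question_size, latent_size * factor)

    # Project both modalities into the common high-dimensional space and multiply element-wise.
    z = image_ff(torch.randn(batch_size, image_size)) * question_ff(torch.randn(batch_size, question_size))
    # Sum-pool over the factor dimension.
    z = z.view(batch_size, latent_size, factor).sum(2)
    # Signed square-root (power) normalization, followed by L2 normalization.
    z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
    z = F.normalize(z, p=2, dim=1)
    print(z.shape)  # Expected: torch.Size([48, 256])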