IBM/pytorchpipe
ptp/components/tasks/text_to_class/language_identification.py

# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Tomasz Kornuta"

from ptp.components.tasks.task import Task
from ptp.data_types.data_definition import DataDefinition


class LanguageIdentification(Task):
    """
    Language identification (classification) task.
    """

    def __init__(self, name, class_type, config):
        """
        Initializes task object. Calls base constructor.

        :param name: Name of the component.

        :param class_type: Class type of the component.

        :param config: Dictionary of parameters (read from configuration ``.yaml`` file).
        """
        # Call constructors of parent classes.
        Task.__init__(self, name, class_type, config)

        # Set key mappings.
        self.key_inputs = self.stream_keys["inputs"]
        self.key_targets = self.stream_keys["targets"]

        # Set empty inputs and targets.
        self.inputs = []
        self.targets = []


    def output_data_definitions(self):
        """ 
        Function returns a dictionary with definitions of output data produced by the component.

        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
        """
        return {
            self.key_indices: DataDefinition([-1, 1], [list, int], "Batch of sample indices [BATCH_SIZE] x [1]"),
            self.key_inputs: DataDefinition([-1, 1], [list, str], "Batch of sentences, each being a single string (many words) [BATCH_SIZE] x [SENTENCE]"),
            self.key_targets: DataDefinition([-1, 1], [list, str], "Batch of targets, each being a single label (word) [BATCH_SIZE] x [WORD]")
            }


    def __len__(self):
        """
        Returns the "size" of the "task" (total number of samples).

        :return: The size of the task.
        """
        return len(self.inputs)


    def __getitem__(self, index):
        """
        Getter method to access the dataset and return a sample.

        :param index: index of the sample to return.
        :type index: int

        :return: ``DataStreams({'indices', 'inputs', 'targets'})``

        """
        # Return data_streams.
        data_streams = self.create_data_streams(index)
        data_streams[self.key_inputs] = self.inputs[index]
        data_streams[self.key_targets] = self.targets[index]
        return data_streams
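

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a hypothetical derived
# task showing how ``inputs`` and ``targets`` might be populated. The class
# name and the toy samples are invented for illustration; actual tasks in ptp
# load their data as specified in the configuration ``.yaml`` file.

class ToyLanguageIdentification(LanguageIdentification):
    """
    Minimal, hypothetical language identification task with hard-coded samples.
    """

    def __init__(self, name, config):
        # Call the base constructor, which sets up the stream key mappings.
        LanguageIdentification.__init__(self, name, ToyLanguageIdentification, config)

        # Hard-coded toy data: each input is a sentence, each target a language label.
        self.inputs = ["hello world", "bonjour le monde", "hola mundo"]
        self.targets = ["en", "fr", "es"]

        # Individual samples are then retrieved through ``__getitem__`` of the
        # parent class, which wraps each (sentence, label) pair into DataStreams.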