Unbabel/OpenKiwi
kiwi/modules/word_level_output.py

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import torch
from torch import nn


class WordLevelOutput(nn.Module):
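    """Word-level output layer.

    Projects per-position features to tag logits with a single linear
    layer, optionally dropping the first and/or last position of the
    sequence (e.g. special boundary tokens) from the output.
    """
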
    def __init__(
        self,
        input_size,
        output_size,
        pad_idx,
        class_weights=None,
        remove_first=False,
        remove_last=False,
    ):
        super().__init__()

        self.pad_idx = pad_idx

        # Identity checks on purpose: an integer 0 must not be read as False
        self.start_pos = None if remove_first is False or remove_first is None else 1
        self.stop_pos = None if remove_last is False or remove_last is None else -1

        self.linear = nn.Linear(input_size, output_size)

        # Summed cross-entropy over non-pad positions; stored for callers
        # to apply (it is not used inside forward)
        self.loss_fn = nn.CrossEntropyLoss(
            reduction='sum', ignore_index=pad_idx, weight=class_weights
        )

        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, features_tensor, batch_inputs=None):
        logits = self.linear(features_tensor)
        # Optionally drop the first/last position; both slice bounds are
        # None when nothing should be removed
        logits = logits[:, self.start_pos : self.stop_pos]
        return logits


class GapTagsOutput(WordLevelOutput):
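    """Output layer for gap tags.

    Predicts one tag per gap between consecutive positions; each gap is
    represented by concatenating the features of the two surrounding
    positions, hence the doubled input size passed to the parent class.
    Padding at the sequence boundaries controls whether there is a gap
    before the first and after the last word.
    """
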
    def __init__(
        self,
        input_size,
        output_size,
        pad_idx,
        class_weights=None,
        remove_first=False,
        remove_last=False,
    ):
        super().__init__(
            input_size=2 * input_size,
            output_size=output_size,
            pad_idx=pad_idx,
            class_weights=class_weights,
            remove_first=False,
            remove_last=False,
        )
        # Removal is realized by *not* padding: the parent is built with
        # remove_first/remove_last disabled, and a boundary only gets a gap
        # position if the features are padded on that side
        self.add_pad_start = 1 if remove_first is False or remove_first is None else 0
        self.add_pad_stop = 1 if remove_last is False or remove_last is None else 0

    def forward(self, features_tensor, batch_inputs=None):
        h_gaps = features_tensor
        if self.add_pad_start or self.add_pad_stop:
            # Pad dim=1 (the sequence dimension). F.pad lists pads starting
            # from the last dimension, so emit [0, 0] for each dimension
            # after dim=1, then the start/stop pads for dim=1
            h_gaps = nn.functional.pad(
                h_gaps,
                pad=[0, 0] * (len(h_gaps.shape) - 2)
                + [self.add_pad_start, self.add_pad_stop],
                value=0,
            )
        # Represent the gap between positions i and i+1 by concatenating
        # the features at i and i+1
        h_gaps = torch.cat((h_gaps[:, :-1], h_gaps[:, 1:]), dim=-1)
        logits = super().forward(h_gaps, batch_inputs)
        return logits
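

# A minimal usage sketch (not part of the original module). It assumes
# features shaped (batch, seq_len, hidden_size), which matches how the
# forward passes above slice and pad dim=1; the tag count and pad index
# below are made up for illustration.
if __name__ == '__main__':
    batch, seq_len, hidden, n_tags, pad_idx = 2, 5, 8, 3, 0
    features = torch.randn(batch, seq_len, hidden)

    word_head = WordLevelOutput(hidden, n_tags, pad_idx)
    assert word_head(features).shape == (batch, seq_len, n_tags)

    # With padding on both boundaries there is one gap per word boundary,
    # including before the first and after the last word: seq_len + 1 gaps
    gap_head = GapTagsOutput(hidden, n_tags, pad_idx)
    assert gap_head(features).shape == (batch, seq_len + 1, n_tags)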