Unbabel/OpenKiwi

kiwi/utils/tensors.py

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from kiwi.data.batch import BatchedSentence
from kiwi.data.vocabulary import Vocabulary


def pad_zeros_around_timesteps(batched_tensor: torch.Tensor) -> torch.Tensor:
    """Pad a batched tensor with one zeroed timestep at the start and end of dim 1."""
    input_size = batched_tensor.size()
    left_pad = batched_tensor.new_zeros(input_size[0], 1, *input_size[2:])
    right_pad = batched_tensor.new_zeros(input_size[0], 1, *input_size[2:])
    return torch.cat((left_pad, batched_tensor, right_pad), dim=1)
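
# Usage sketch (illustrative, not part of the original module): a (bs, seq, dim)
# tensor gains one zero-filled timestep on each side of the sequence dimension.
#
#     >>> x = torch.ones(2, 3, 4)
#     >>> pad_zeros_around_timesteps(x).shape
#     torch.Size([2, 5, 4])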


def convolve_tensor(sequences, window_size, pad_value=0):
    """Convolve a sequence and apply padding.

    Arguments:
        sequences: nD tensor
        window_size: filter length
        pad_value: int value used as padding

    Return:
         (n+1)D tensor, where the last dimension has size window_size.
    """
    pad = (window_size // 2,) * 2
    pad = (0, 0) * (len(sequences.shape) - 2) + pad  # We only want to pad dim 1
    t = F.pad(sequences, pad=pad, value=pad_value)
    t = t.unfold(1, window_size, 1)

    return t
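
# Usage sketch (illustrative, not part of the original module): each position is
# expanded to a window over its neighbours, with `pad_value` filling the borders.
#
#     >>> seqs = torch.arange(6).reshape(2, 3)
#     >>> convolve_tensor(seqs, window_size=3, pad_value=-1).tolist()
#     [[[-1, 0, 1], [0, 1, 2], [1, 2, -1]], [[-1, 3, 4], [3, 4, 5], [4, 5, -1]]]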


def apply_packed_sequence(rnn, padded_sequences, lengths):
    """Run a forward pass of padded_sequences through an rnn using packed sequence.

    Arguments:
       rnn: the RNN to run the forward pass with; assumed to return a
            (hidden, cell) state tuple, e.g. an LSTM.
       padded_sequences (FloatTensor b x seq x dim): a batch of padded sequences.
       lengths (LongTensor batch): the length of each sequence in the batch.

    Return:
       a tuple (outputs, (hidden, cell)) with the RNN outputs and final states,
       restored to the original batch order.
    """
    # Sort Batch by sequence length
    total_length = padded_sequences.size(1)  # Get the max sequence length
    lengths_sorted, permutation = torch.sort(lengths, descending=True)
    padded_sequences_sorted = padded_sequences[permutation]

    # Use Packed Sequence
    padded_sequences_packed = pack(
        padded_sequences_sorted, lengths_sorted, batch_first=True
    )
    outputs_packed, (hidden, cell) = rnn(padded_sequences_packed)
    outputs_sorted, _ = unpack(
        outputs_packed, batch_first=True, total_length=total_length
    )

    # Restore original order
    _, permutation_rev = torch.sort(permutation, descending=False)
    outputs = outputs_sorted[permutation_rev]
    hidden, cell = hidden[:, permutation_rev], cell[:, permutation_rev]
    return outputs, (hidden, cell)
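
# Usage sketch (illustrative, not part of the original module); an LSTM is used
# here because the function unpacks a (hidden, cell) state from the RNN.
#
#     >>> rnn = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
#     >>> x = torch.randn(3, 5, 4)                 # (bs, seq, dim)
#     >>> lengths = torch.tensor([5, 2, 4])
#     >>> out, (h, c) = apply_packed_sequence(rnn, x, lengths)
#     >>> out.shape, h.shape
#     (torch.Size([3, 5, 8]), torch.Size([1, 3, 8]))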


def replace_token(target: torch.LongTensor, old: int, new: int):
    """Replace old tokens with new.

    Arguments:
        target: tensor of token ids.
        old: the token to be replaced by new.
        new: the token used to replace old.

    """
    return target.masked_fill(target == old, new)
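
# Usage sketch (illustrative, not part of the original module):
#
#     >>> replace_token(torch.tensor([1, 2, 3, 2]), old=2, new=0)
#     tensor([1, 0, 3, 0])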


def make_classes_loss_weights(vocab: Vocabulary, label_weights: Dict[str, float]):
    """Create a loss weight vector for nn.CrossEntropyLoss.

    Arguments:
        vocab: vocabulary for classes.
        label_weights: weight for specific classes (str); classes in vocab and not in
                       this dict will get a weight of 1.

    Return:
       weights (FloatTensor): weight Tensor of shape `nb_classes`.
    """
    nb_classes = vocab.net_length()
    class_weights = torch.ones(nb_classes)
    for class_label, weight in label_weights.items():
        class_idx = vocab.stoi[class_label]
        class_weights[class_idx] = weight
    return class_weights
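
# Usage sketch (illustrative, not part of the original module); assumes a
# hypothetical tag vocabulary whose `stoi` maps the labels 'OK' and 'BAD' to indices.
#
#     >>> weights = make_classes_loss_weights(vocab, {'BAD': 3.0})
#     >>> loss_fn = torch.nn.CrossEntropyLoss(weight=weights)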


def sequence_mask(lengths: torch.LongTensor, max_len: Optional[int] = None):
    """Create a boolean mask from sequence lengths.

    Arguments:
        lengths: lengths with shape (bs,)
        max_len: max sequence length; if None it will be set to lengths.max()
    """
    if max_len is None:
        max_len = lengths.max()
    # Broadcast a position index of shape (1, max_len) against lengths (bs, 1)
    mask = torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]
    return mask
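
# Usage sketch (illustrative, not part of the original module):
#
#     >>> sequence_mask(torch.tensor([1, 3])).tolist()
#     [[True, False, False], [True, True, True]]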


def unmask(tensor, mask):
    """Unmask a tensor and convert it back to a list of lists."""
    lengths = mask.int().sum(dim=-1).tolist()
    return [x[: lengths[i]].tolist() for i, x in enumerate(tensor)]
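
# Usage sketch (illustrative, not part of the original module): positions beyond
# each sequence's length are dropped and the result is a plain list of lists.
#
#     >>> t = torch.tensor([[1, 2, 0], [3, 4, 5]])
#     >>> unmask(t, sequence_mask(torch.tensor([2, 3])))
#     [[1, 2], [3, 4, 5]]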


def unsqueeze_as(tensor, as_tensor, dim=-1):
    """Expand new dimensions based on a template tensor along `dim` axis."""
    x = tensor
    while x.dim() < as_tensor.dim():
        x = x.unsqueeze(dim)
    return x
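
# Usage sketch (illustrative, not part of the original module): trailing singleton
# dimensions are added until the tensor has as many dimensions as the template.
#
#     >>> unsqueeze_as(torch.ones(2, 3), torch.ones(2, 3, 4, 5)).shape
#     torch.Size([2, 3, 1, 1])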


def make_mergeable_tensors(t1: torch.Tensor, t2: torch.Tensor):
    """Expand a new dimension in t1 and t2 and expand them so that both
    tensors will have the same number of timesteps.

    Arguments:
        t1: tensor with shape (bs, ..., m, d1)
        t2: tensor with shape (bs, ..., n, d2)

    Return:
        tuple of
            torch.Tensor: (bs, ..., m, n, d1),
            torch.Tensor: (bs, ..., m, n, d2)
    """
    assert t1.dim() == t2.dim()
    assert t1.dim() >= 3
    assert t1.shape[:-2] == t2.shape[:-2]
    # new_shape = [-1, ..., m, n, -1]
    new_shape = [-1 for _ in range(t1.dim() + 1)]
    new_shape[-3] = t1.shape[-2]  # m
    new_shape[-2] = t2.shape[-2]  # n
    # (bs, ..., m, d1) -> (bs, ..., m, 1, d1) -> (bs, ..., m, n, d1)
    new_t1 = t1.unsqueeze(-2).expand(new_shape)
    # (bs, ..., n, d2) -> (bs, ..., 1, n, d2) -> (bs, ..., m, n, d2)
    new_t2 = t2.unsqueeze(-3).expand(new_shape)
    return new_t1, new_t2
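
# Usage sketch (illustrative, not part of the original module): both outputs share
# the (m, n) dimensions, so they can be concatenated along the feature axis.
#
#     >>> t1 = torch.randn(2, 5, 10)   # (bs, m, d1)
#     >>> t2 = torch.randn(2, 7, 4)    # (bs, n, d2)
#     >>> a, b = make_mergeable_tensors(t1, t2)
#     >>> torch.cat((a, b), dim=-1).shape
#     torch.Size([2, 5, 7, 14])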


class GradientMul(Function):
    """Identity in the forward pass; scale the gradient by `constant` in the backward pass."""

    @staticmethod
    def forward(ctx, x, constant=0):
        ctx.constant = constant
        return x

    @staticmethod
    def backward(ctx, grad):
        return ctx.constant * grad, None


gradient_mul = GradientMul.apply
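
# Usage sketch (illustrative, not part of the original module): the forward value
# is unchanged, but the gradient is scaled by the constant (0 by default, which
# blocks the gradient entirely).
#
#     >>> x = torch.ones(3, requires_grad=True)
#     >>> gradient_mul(x, 0.5).sum().backward()
#     >>> x.grad
#     tensor([0.5000, 0.5000, 0.5000])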


def retrieve_tokens_mask(input_batch: BatchedSentence):
    """Compute Mask of Tokens for side.

    Migrated from FieldEmbedder.get_mask()

    Arguments:
        input_batch (BatchedSentence): batch of tensors

    Return:
        mask tensor
    """
    assert isinstance(input_batch, BatchedSentence)

    tensor = input_batch.tensor
    lengths = input_batch.lengths
    mask = torch.ones_like(tensor, dtype=torch.int)
    mask[:] = torch.arange(mask.shape[1])
    mask = mask < lengths.unsqueeze(-1).int()
    return mask


def select_positions(tensor, indices):
    """Gather, for each batch row, the positions given by `indices` along dim 1."""
    range_vector = torch.arange(tensor.size(0), device=tensor.device).unsqueeze(1)
    return tensor[range_vector, indices]
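
# Usage sketch (illustrative, not part of the original module): for each batch
# row, the timesteps at the given indices are selected.
#
#     >>> t = torch.arange(12).reshape(2, 3, 2)
#     >>> select_positions(t, torch.tensor([[0, 2], [1, 1]])).tolist()
#     [[[0, 1], [4, 5]], [[8, 9], [8, 9]]]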


def pieces_to_tokens(features_tensor, batch, strategy='first'):
    """Join together pieces of a token back into the original token dimension."""
    if strategy == 'first':
        # Use the bounds for the wordpieces
        # This discards all features not from the first wordpiece of the token
        return select_positions(features_tensor, batch.bounds)
    else:
        raise NotImplementedError('Only available joining strategy is first.')
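
# Usage sketch (illustrative, not part of the original module); a SimpleNamespace
# stands in for the real batch object, assuming only that `batch.bounds` holds,
# for each token, the index of its first wordpiece.
#
#     >>> from types import SimpleNamespace
#     >>> pieces = torch.randn(1, 6, 8)                  # 6 wordpieces, dim 8
#     >>> fake_batch = SimpleNamespace(bounds=torch.tensor([[0, 2, 5]]))
#     >>> pieces_to_tokens(pieces, fake_batch).shape
#     torch.Size([1, 3, 8])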