# kiwi/systems/encoders/predictor.py (Unbabel/OpenKiwi)
#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging
from collections import OrderedDict
from typing import Dict, Optional

import torch
from pydantic import validator
from torch import nn

from kiwi import constants as const
from kiwi.data.vocabulary import Vocabulary
from kiwi.modules.common.attention import Attention
from kiwi.modules.common.scorer import MLPScorer
from kiwi.modules.token_embeddings import TokenEmbeddings
from kiwi.systems._meta_module import MetaModule
from kiwi.systems.encoders.quetch import InputEmbeddingsConfig
from kiwi.utils.io import BaseConfig
from kiwi.utils.tensors import apply_packed_sequence, pad_zeros_around_timesteps

logger = logging.getLogger(__name__)


class DualSequencesEncoder(nn.Module):
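    """Encode a sequence B conditioned on a sequence A.

    Sequence A is encoded with a bidirectional LSTM; its final hidden states
    initialize two unidirectional LSTMs that run forward and backward over
    sequence B. For each position of B, attention over A's contexts is combined
    with the neighbouring embeddings and contexts of B, followed by a maxout
    layer and a final linear projection (see ``encode_b``).
    """
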
    def __init__(
        self,
        input_size_a,
        input_size_b,
        hidden_size,
        output_size,
        num_layers,
        dropout,
        _use_v0_buggy_strategy=False,
    ):
        super().__init__()
        # See the docstring of PredictorEncoder.Config.use_v0_buggy_strategy
        self._use_v0_buggy_strategy = _use_v0_buggy_strategy

        scorer = MLPScorer(hidden_size * 2, hidden_size * 2)
        self.attention = Attention(scorer)

        self.forward_backward_a = nn.LSTM(
            input_size=input_size_a,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.forward_b = nn.LSTM(
            input_size=input_size_b,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )
        self.backward_b = nn.LSTM(
            input_size=input_size_b,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )

        self.W2 = nn.Parameter(
            torch.zeros(output_size, output_size), requires_grad=True
        )

        self.V = nn.Parameter(
            torch.zeros(2 * input_size_b, 2 * output_size), requires_grad=True
        )
        self.C = nn.Parameter(
            torch.zeros(2 * hidden_size, 2 * output_size), requires_grad=True
        )
        self.S = nn.Parameter(
            torch.zeros(2 * hidden_size, 2 * output_size), requires_grad=True
        )

        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, embeddings_a, lengths_a, mask_a, embeddings_b, lengths_b):
        if self._use_v0_buggy_strategy:
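            # Reproduce the off-by-two slicing of Kiwi<=0.3.4 so that models saved
            # by those versions keep their original behaviour (see Config docstring)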
            embeddings_a = embeddings_a[:, :-2]
            mask_a = mask_a[:, 1:-1]
            lengths_a -= 2  # Equivalent to mask_a.sum(dim=1)
            lengths_b -= 2

        # Encode sequence A
        contexts_a, hidden = apply_packed_sequence(
            self.forward_backward_a, embeddings_a, lengths_a
        )

        # Encode sequence B
        forward_contexts, backward_contexts = self.contextualize_b(
            embeddings_b, lengths_b, hidden
        )
        post_QEFV = torch.cat([forward_contexts, backward_contexts], dim=-1)

        f = self.encode_b(
            embeddings_b, forward_contexts, backward_contexts, contexts_a, mask_a
        )

        return f, post_QEFV

    def contextualize_b(self, embeddings, lengths, hidden):
        h_forward, h_backward = self._split_hidden(hidden)
        # Note: hidden, h_forward and h_backward are not batch_first
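        # The final hidden states of the bi-LSTM over A initialize the forward and
        # backward LSTMs over B, conditioning B's encoding on A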
        forward_contexts, _ = self.forward_b(embeddings, h_forward)
        reversed_embeddings = self._reverse_padded_seq(lengths, embeddings)
        backward_contexts, _ = self.backward_b(reversed_embeddings, h_backward)
        backward_contexts = self._reverse_padded_seq(lengths, backward_contexts)

        return forward_contexts, backward_contexts

    def encode_b(
        self,
        embeddings,
        forward_contexts,
        backward_contexts,
        contexts_a,
        attention_mask,
    ):
        """Encode sequence B.

        Build a feature vector for each position i using left context i-1 and right
        context i+1. In the original implementation, this resulted in a returned tensor
        with -2 timesteps (dim=1). We have now changed it to return the same number
        of timesteps as the input. The consequence is that callers now have to deal
        with BOS and EOS in a different way, but hopefully this new behaviour is more
        consistent and less surprising. The old behaviour can be forced by setting
        ``self._use_v0_buggy_strategy`` to True.
        """
        if not self._use_v0_buggy_strategy:
            # Pad inputs with zeros on both sides to retain the original number
            # of timesteps after the slicing below
            forward_contexts = pad_zeros_around_timesteps(forward_contexts)
            backward_contexts = pad_zeros_around_timesteps(backward_contexts)
            embeddings = pad_zeros_around_timesteps(embeddings)

        # For each position i, concatenate the left context (i-1) and the right
        # context (i+1): (bs, seq_len, d) -> (bs, seq_len - 2, 2 * d)
        contexts_b = torch.cat(
            [forward_contexts[:, :-2], backward_contexts[:, 2:]], dim=-1
        )

        # For each position i, concatenate the embeddings at i-1 and i+1
        context_b_embeddings = torch.cat(
            [embeddings[:, :-2], embeddings[:, 2:]], dim=-1
        )

        # Compute attention over the contexts of A for every position of B (vectorized)
        attns, p_attns = self.attention(
            query=contexts_b,
            keys=contexts_a,
            values=contexts_a,
            mask=attention_mask.unsqueeze(1),
        )

        # Combine attention, embeddings and target context vectors
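        # Each term is projected to 2 * output_size: S projects the neighbouring
        # contexts of B, V the neighbouring embeddings of B, and C the attention
        # context over A; their sum is the input to the maxout layer below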
        S = torch.einsum('bsk,kl->bsl', [contexts_b, self.S])
        V = torch.einsum('bsj,jl->bsl', [context_b_embeddings, self.V])
        C = torch.einsum('bsi,il->bsl', [attns, self.C])
        t_tilde = S + V + C

        # Maxout with pooling size 2: take the max over adjacent feature pairs,
        # halving the last dimension from 2 * output_size to output_size
        t, _ = torch.max(
            t_tilde.view(t_tilde.shape[0], t_tilde.shape[1], t_tilde.shape[-1] // 2, 2),
            dim=-1,
        )
        f = torch.einsum('oh,bso->bsh', [self.W2, t])

        return f

    @staticmethod
    def _reverse_padded_seq(lengths, sequence):
        """Reverse a batch of padded sequences of different length."""
        batch_size, max_length = sequence.shape[:-1]
        reversed_idx = []
        for i in range(batch_size * max_length):
            batch_id = i // max_length
            sent_id = i % max_length
            if sent_id < lengths[batch_id]:
                sent_id_rev = lengths[batch_id] - sent_id - 1
            else:
                sent_id_rev = sent_id  # Padding symbol, don't change order
            reversed_idx.append(max_length * batch_id + sent_id_rev)
        flat_sequence = sequence.contiguous().view(batch_size * max_length, -1)
        reversed_seq = flat_sequence[reversed_idx, :].view(*sequence.shape)
        return reversed_seq

    @staticmethod
    def _split_hidden(hidden):
        """Split hidden state into forward/backward parts."""
        h, c = hidden
        size = h.size(0)
        idx_forward = torch.arange(0, size, 2, dtype=torch.long)
        idx_backward = torch.arange(1, size, 2, dtype=torch.long)
        hidden_forward = (h[idx_forward], c[idx_forward])
        hidden_backward = (h[idx_backward], c[idx_backward])
        return hidden_forward, hidden_backward


@MetaModule.register_subclass
class PredictorEncoder(MetaModule):
    """Bidirectional Conditional Language Model

    Implemented after Kim et al 2017, see: http://www.statmt.org/wmt17/pdf/WMT63.pdf
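
    Example (a rough usage sketch, not from the library docs; it assumes ``vocabs``
    maps ``const.SOURCE``/``const.TARGET`` to built ``Vocabulary`` objects and that
    ``batch_inputs`` provides kiwi batch fields such as ``tensor``, ``lengths`` and
    ``strict_masks``)::

        config = PredictorEncoder.Config(
            hidden_size=400, rnn_layers=2, out_embeddings_dim=200
        )
        encoder = PredictorEncoder(vocabs=vocabs, config=config, pretraining=True)
        features = encoder(batch_inputs, include_target_logits=True)
        # features[const.TARGET]:        (batch, target_len, 200 + 2 * 400)
        # features[const.TARGET_LOGITS]: (batch, target_len, target_vocab_size)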
    """

    class Config(BaseConfig):
        hidden_size: int = 400
        """Size of hidden layers in LSTM."""

        rnn_layers: int = 3
        """Number of RNN layers in the Predictor."""

        dropout: float = 0.0

        share_embeddings: bool = False
        """Tie input and output embeddings for target."""

        out_embeddings_dim: Optional[int] = None
        """Word Embedding in Output layer."""

        use_mismatch_features: bool = False
        """Whether to use Alibaba's mismatch features."""

        embeddings: InputEmbeddingsConfig = InputEmbeddingsConfig()

        use_v0_buggy_strategy: bool = False
        """The Predictor implementation in Kiwi<=0.3.4 had a bug in applying the LSTM
        to encode source (it used lengths too short by 2) and in reversing the target
        embeddings for applying the backward LSTM (also short by 2). This flag is set
        to true when loading a saved model from those versions."""
        v0_start_stop: bool = False
        """Whether pre_qe_f_v is padded on both ends or
        post_qe_f_v is strip on both ends."""

        @validator('dropout', pre=True)
        def dropout_on_rnns(cls, v, values):
            if v > 0.0 and values['rnn_layers'] == 1:
                logger.info(
                    'Dropout on an RNN of one layer has no effect; setting it to zero.'
                )
                return 0.0
            return v

        @validator('use_mismatch_features', pre=True)
        def no_implementation(cls, v):
            if v:
                raise NotImplementedError('Not yet implemented')
            return False

    def __init__(
        self,
        vocabs: Dict[str, Vocabulary],
        config: Config,
        pretraining: bool = False,
        pre_load_model: bool = True,
    ):
        """
        Arguments:
            vocabs: dictionary mapping field names to vocabularies.
            config: a ``PredictorEncoder.Config`` object.
            pretraining: set to True when pretraining with parallel data.
            pre_load_model: not used.
        """
        super().__init__(config=config)
        self.pretraining = pretraining

        # Input embeddings
        self.embeddings = nn.ModuleDict()
        self.embeddings[const.TARGET] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.TARGET]),
            pad_idx=vocabs[const.TARGET].pad_id,
            config=config.embeddings.target,
            vectors=vocabs[const.TARGET].vectors,
        )
        self.embeddings[const.SOURCE] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.SOURCE]),
            pad_idx=vocabs[const.SOURCE].pad_id,
            config=config.embeddings.source,
            vectors=vocabs[const.SOURCE].vectors,
        )
        self.vocabs = {
            const.TARGET: vocabs[const.TARGET],
            const.SOURCE: vocabs[const.SOURCE],
        }

        # Output embeddings
        self.output_embeddings = nn.ModuleDict()
        if self.config.share_embeddings:
            self.output_embeddings[const.TARGET] = self.embeddings[
                const.TARGET
            ].embedding
        else:
            self.output_embeddings[const.TARGET] = nn.Embedding(
                num_embeddings=self.embeddings[const.TARGET].num_embeddings,
                embedding_dim=self.config.out_embeddings_dim,
                padding_idx=self.embeddings[const.TARGET].pad_idx,
            )

        # Encoders
        self.encode_target = DualSequencesEncoder(
            input_size_a=self.embeddings[const.SOURCE].size(),
            input_size_b=self.embeddings[const.TARGET].size(),
            hidden_size=self.config.hidden_size,
            output_size=self.output_embeddings[const.TARGET].embedding_dim,
            num_layers=self.config.rnn_layers,
            dropout=self.config.dropout,
            _use_v0_buggy_strategy=self.config.use_v0_buggy_strategy,
        )

        output_dim = self.output_embeddings[const.TARGET].embedding_dim
        self.start_PreQEFV = nn.Parameter(
            torch.zeros(1, 1, output_dim), requires_grad=True
        )
        self.end_PreQEFV = nn.Parameter(
            torch.zeros(1, 1, output_dim), requires_grad=True
        )

        # total_size = sum(emb.size() for emb in self.embeddings.values())
        self._sizes = {
            const.TARGET: output_dim + 2 * self.config.hidden_size,
            # const.SOURCE: output_dim + 2 * self.config.hidden_size,
            const.TARGET_LOGITS: self.output_embeddings[const.TARGET].num_embeddings,
            const.PE_LOGITS: self.output_embeddings[const.TARGET].num_embeddings,
        }

    @classmethod
    def input_data_encoders(cls, config: Config):
        return None  # Use defaults, i.e., TextEncoder

    def size(self, field=None):
        if field:
            return self._sizes[field]
        return self._sizes

    def forward(self, batch_inputs, include_target_logits=False):
        target_embeddings = self.embeddings[const.TARGET](batch_inputs[const.TARGET])
        source_embeddings = self.embeddings[const.SOURCE](batch_inputs[const.SOURCE])

        target_lengths = batch_inputs[const.TARGET].lengths
        source_lengths = batch_inputs[const.SOURCE].lengths

        source_attention_mask = batch_inputs[const.SOURCE].strict_masks

        f, post_qe_feature_vector = self.encode_target(
            source_embeddings,
            source_lengths,
            source_attention_mask,
            target_embeddings,
            target_lengths,
        )
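        # f holds the Predictor features for each target position; the post-QE
        # feature vector concatenates the forward and backward LSTM contexts over
        # the target (the PreQEFV/PostQEFV split of the Predictor-Estimator)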

        output_embeddings = self.output_embeddings[const.TARGET](
            batch_inputs[const.TARGET].tensor
        )
        if self.config.use_v0_buggy_strategy:
            pre_qe_feature_vector = torch.einsum(
                'bsh,bsh->bsh', [output_embeddings[:, 1:-1], f]
            )
            if self.config.v0_start_stop:
                start = self.start_PreQEFV.expand(output_embeddings.size(0), -1, -1)
                end = self.end_PreQEFV.expand(output_embeddings.size(0), -1, -1)
                pre_qe_feature_vector = torch.cat(
                    (start, pre_qe_feature_vector, end), dim=1
                )
            else:
                post_qe_feature_vector = post_qe_feature_vector[:, 1:-1]
        else:
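            # PreQEFV: elementwise product of each target token's output embedding
            # with its Predictor feature vector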
            pre_qe_feature_vector = torch.einsum('bsh,bsh->bsh', [output_embeddings, f])
            # Using these learnable start and stop parameters seemed to help (according
            #  to Sony).
            start = self.start_PreQEFV.expand(output_embeddings.size(0), -1, -1)
            end = self.end_PreQEFV.expand(output_embeddings.size(0), -1, -1)
            pre_qe_feature_vector = torch.cat(
                (start, pre_qe_feature_vector[:, 1:-1], end), dim=1
            )

        features = torch.cat([pre_qe_feature_vector, post_qe_feature_vector], dim=-1)

        output_features = OrderedDict()
        output_features[const.TARGET] = features

        if include_target_logits or self.pretraining:
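            # Scores over the target vocabulary at each position, obtained by
            # projecting f onto the (possibly tied) output embedding matrix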
            logits = torch.einsum(
                'vh,bsh->bsv', [self.output_embeddings[const.TARGET].weight, f]
            )
            output_features[const.TARGET_LOGITS] = logits

        if const.PE in batch_inputs:
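            # When post-edited sentences are given, re-run the encoder with the
            # post-edit as sequence B and compute logits over the target vocabulary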
            pe_embeddings = self.embeddings[const.TARGET](batch_inputs[const.PE])
            pe_lengths = batch_inputs[const.PE].lengths
            f, _ = self.encode_target(
                source_embeddings,
                source_lengths,
                source_attention_mask,
                pe_embeddings,
                pe_lengths,
            )
            output_features[const.PE_LOGITS] = torch.einsum(
                'vh,bsh->bsv', [self.output_embeddings[const.TARGET].weight, f]
            )

        return output_features