Unbabel/OpenKiwi

View on GitHub
kiwi/systems/encoders/xlmroberta.py

Summary

Maintainability
B
6 hrs
Test Coverage
D
66%
#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, Union

import torch
from pydantic import confloat
from pydantic.class_validators import validator
from torch import Tensor, nn
from transformers import (
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
    AutoTokenizer,
    XLMRobertaConfig,
    XLMRobertaModel,
)
from typing_extensions import Literal

from kiwi import constants as const
from kiwi.data.encoders.field_encoders import TextEncoder
from kiwi.data.vocabulary import Vocabulary
from kiwi.modules.common.scalar_mix import ScalarMixWithDropout
from kiwi.systems._meta_module import MetaModule
from kiwi.utils.data_structures import DefaultFrozenDict
from kiwi.utils.io import BaseConfig
from kiwi.utils.tensors import pieces_to_tokens, retrieve_tokens_mask

logger = logging.getLogger(__name__)


class XLMRobertaTextEncoder(TextEncoder):
    """Text field encoder backed by a pre-trained XLM-RoBERTa tokenizer.

    Sub-tokenization and the special tokens (pad/unk/bos/eos) come from the
    HuggingFace tokenizer, and the vocabulary is copied verbatim from it, so it
    is never fitted to data.
    """

    def __init__(self, tokenizer_name='xlm-roberta-base', is_source=False):
        """Load the tokenizer and mirror its vocabulary.

        Args:
            tokenizer_name: pre-trained tokenizer name. Any value not listed in
                ``XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST`` (e.g. a local
                model path) falls back to ``'xlm-roberta-base'``.
            is_source: accepted for interface compatibility; currently unused.
        """
        resolved_name = (
            tokenizer_name
            if tokenizer_name in XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
            else 'xlm-roberta-base'
        )
        hf_tokenizer = AutoTokenizer.from_pretrained(str(resolved_name))

        super().__init__(
            subtokenize=hf_tokenizer._tokenize,
            pad_token=hf_tokenizer.pad_token,
            unk_token=hf_tokenizer.unk_token,
            bos_token=hf_tokenizer.cls_token,
            eos_token=hf_tokenizer.eos_token,
            specials_first=True,
            include_lengths=True,
            include_bounds=True,
        )

        # Start from an empty vocabulary, then overwrite its mappings with the
        # tokenizer's own token <-> id tables.
        self.vocab = Vocabulary(
            counter=Counter(),
            unk_token=self.unk_token,
            pad_token=self.pad_token,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            specials_first=self.specials_first,
        )

        token_to_id = hf_tokenizer.get_vocab()
        unk_id = token_to_id[self.unk_token]
        # Unknown tokens map to the unk id; unknown ids map back to its token.
        self.vocab.stoi = DefaultFrozenDict(token_to_id, unk_id)
        id_to_token = {idx: tok for tok, idx in self.vocab.stoi.items()}
        self.vocab.itos = DefaultFrozenDict(id_to_token, id_to_token[unk_id])

    def fit_vocab(
        self,
        samples,
        vocab_size=None,
        vocab_min_freq=0,
        embeddings_name=None,
        keep_rare_words_with_embeddings=False,
        add_embeddings_vocab=False,
    ):
        """Skip fitting (the vocabulary is fixed) but honor ``vocab_size``.

        All arguments other than ``vocab_size`` are accepted for interface
        compatibility and ignored.
        """
        logger.info(
            'Vocabulary already defined for XLMRoberta field; '
            'not fitting it to data now'
        )
        if not vocab_size:
            return
        self.vocab.max_size(vocab_size)


@MetaModule.register_subclass
class XLMRobertaEncoder(MetaModule):
    """XLM-RoBERTa model, using HuggingFace's implementation.

    Target and source word-piece ids are merged into one input sequence and run
    through XLM-R. All hidden states (embeddings + every layer) are combined by a
    learned scalar mix, optionally projected by a small MLP, split back into
    target/source parts and converted from word pieces to tokens. A
    sentence-level representation is built according to ``Config.pooling``.
    """

    class Config(BaseConfig):
        model_name: Union[str, Path] = 'xlm-roberta-base'
        """Pre-trained XLMRoberta model to use."""

        interleave_input: bool = False
        """Concatenate SOURCE and TARGET without internal padding
        (111222000 instead of 111002220)"""

        use_mlp: bool = True
        """Apply a linear layer on top of XLMRoberta."""

        hidden_size: int = 100
        """Size of the linear layer on top of XLMRoberta."""

        pooling: Literal['first_token', 'mean', 'll_mean', 'mixed'] = 'mixed'
        """Type of pooling used to extract features from the encoder. Options are:
            first_token: CLS_token is used for sentence representation
            mean: Use avg pooling for sentence representation using scalar mixed layers
            ll_mean: Mean pool of only last layer embeddings
            mixed: Concat CLS token with mean_pool"""

        # Passed as ``dropout`` to ScalarMixWithDropout.
        scalar_mix_dropout: confloat(ge=0.0, le=1.0) = 0.1
        # Passed as ``do_layer_norm`` to ScalarMixWithDropout.
        scalar_mix_layer_norm: bool = True

        freeze: bool = False
        """Freeze XLMRoberta during training."""

        freeze_for_number_of_steps: int = 0
        """Freeze XLMR during training for this number of steps."""

        @validator('model_name', pre=True)
        def fix_relative_path(cls, v):
            # Anything that is not a known pre-trained name is treated as a
            # filesystem path and made absolute relative to the working dir.
            if v not in XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST:
                v = Path(v)
                if not v.is_absolute():
                    v = Path.cwd().joinpath(v)
            return v

    def __init__(
        self, vocabs: Dict[str, Vocabulary], config: Config, pre_load_model: bool = True
    ):
        """Build the encoder.

        Args:
            vocabs: must contain ``const.TARGET`` and ``const.SOURCE`` entries.
            config: encoder configuration.
            pre_load_model: if True, load pre-trained weights for
                ``config.model_name``; otherwise only its configuration is
                loaded and the model is freshly initialized (presumably so that
                weights can then be restored from a checkpoint).
        """
        super().__init__(config=config)

        # Submodules are created in a fixed order (xlm_roberta, mlp,
        # scalar_mix, output_embeddings); keep it to preserve parameter
        # initialization order and state-dict layout.
        if pre_load_model:
            self.xlm_roberta = XLMRobertaModel.from_pretrained(
                self.config.model_name, output_hidden_states=True
            )
        else:
            xlm_roberta_config = XLMRobertaConfig.from_pretrained(
                self.config.model_name, output_hidden_states=True
            )
            self.xlm_roberta = XLMRobertaModel(xlm_roberta_config)

        self.vocabs = {
            const.TARGET: vocabs[const.TARGET],
            const.SOURCE: vocabs[const.SOURCE],
        }

        xlmr_hidden_size = self.xlm_roberta.config.hidden_size
        if self.config.use_mlp:
            self.mlp = nn.Sequential(
                nn.Linear(xlmr_hidden_size, self.config.hidden_size), nn.Tanh()
            )
            output_size = self.config.hidden_size
        else:
            self.mlp = None
            output_size = xlmr_hidden_size

        # 'mixed' pooling concatenates two output_size vectors.
        sentence_size = output_size
        if self.config.pooling == 'mixed':
            sentence_size *= 2

        # One mixing weight per transformer layer plus one for the embeddings.
        self.scalar_mix = ScalarMixWithDropout(
            mixture_size=self.xlm_roberta.config.num_hidden_layers + 1,
            do_layer_norm=self.config.scalar_mix_layer_norm,
            dropout=self.config.scalar_mix_dropout,
        )

        self._sizes = {
            const.TARGET: output_size,
            const.TARGET_LOGITS: output_size,
            const.TARGET_SENTENCE: sentence_size,
            const.SOURCE: output_size,
        }

        self.output_embeddings = self.xlm_roberta.embeddings.word_embeddings

        self._training_steps_ran = 0
        self._is_frozen = False
        if self.config.freeze:
            logger.info(
                'Freezing XLMRoberta encoder weights; training will not update them'
            )
            self._set_encoder_trainable(False)
        # `freeze_for_number_of_steps` is applied inside `forward()` (see
        # `_check_freezing`) to guarantee we can unfreeze: if the optimizer were
        # built after freezing here, unfreezing would require
        # `optimizer.add_param_group({'params': self.xlm_roberta.parameters()})`.

    def load_state_dict(
        self,
        state_dict: Dict[str, Tensor],
        strict: bool = True,
    ):
        """Load weights, tolerating checkpoints affected by Transformers 3.1.

        If loading fails because of the ``position_ids`` buffer, retry once with
        that buffer temporarily marked as non-persistent. The marker is always
        restored afterwards, even if the retry fails.
        """
        try:
            return super().load_state_dict(state_dict, strict)
        except RuntimeError as e:
            if 'position_ids' not in str(e):
                raise
        # FIXME: hack to get around Transformers 3.1 breaking changes
        # https://github.com/huggingface/transformers/issues/6882
        non_persistent = self.xlm_roberta.embeddings._non_persistent_buffers_set
        already_marked = 'position_ids' in non_persistent
        non_persistent.add('position_ids')
        try:
            return super().load_state_dict(state_dict, strict)
        finally:
            # Only undo what we changed.
            if not already_marked:
                non_persistent.discard('position_ids')

    @classmethod
    def input_data_encoders(cls, config: Config):
        """Return the source and target field encoders for this model."""
        return {
            const.SOURCE: XLMRobertaTextEncoder(tokenizer_name=config.model_name),
            const.TARGET: XLMRobertaTextEncoder(tokenizer_name=config.model_name),
        }

    def size(self, field=None):
        """Return the output size of ``field``, or the whole size mapping."""
        if field:
            return self._sizes[field]
        return self._sizes

    def _set_encoder_trainable(self, trainable: bool):
        """Toggle ``requires_grad`` on every XLM-R parameter and track it."""
        for param in self.xlm_roberta.parameters():
            param.requires_grad = trainable
        self._is_frozen = not trainable

    def _check_freezing(self):
        """Apply the ``freeze_for_number_of_steps`` schedule (once per forward).

        Only forward passes in training mode count as steps, so validation or
        inference passes neither change the freeze state nor advance the
        schedule. When ``Config.freeze`` is set, the encoder was frozen in
        ``__init__`` and must stay frozen, so the schedule is not applied.
        """
        if not self.training:
            return

        n_frozen_steps = self.config.freeze_for_number_of_steps
        if not self.config.freeze:
            if self._training_steps_ran == 0 and n_frozen_steps > 0:
                logger.info(
                    f'Freezing XLMRoberta encoder weights for '
                    f'{self.config.freeze_for_number_of_steps} steps'
                )
                self._set_encoder_trainable(False)
            elif self._is_frozen and self._training_steps_ran >= n_frozen_steps:
                logger.info(
                    f'Unfreezing XLMRoberta encoder '
                    f'({self._training_steps_ran} steps have passed)'
                )
                self._set_encoder_trainable(True)

        self._training_steps_ran += 1

    @staticmethod
    def _masked_mean(features, mask):
        """Average ``features`` over dim 1, counting only positions set in ``mask``."""
        return (features * mask[:, :, None]).sum(1) / mask.sum(1)[:, None]

    def forward(self, batch_inputs, *args, include_logits=False):
        """Encode a batch of source/target word pieces.

        Args:
            batch_inputs: mapping holding ``const.SOURCE`` and ``const.TARGET``
                batches (objects exposing ``.tensor`` and ``.lengths``).
            include_logits: logits for multi-task fine-tuning are not
                implemented; passing True raises ``NotImplementedError``.

        Returns:
            dict with token-level features for ``const.TARGET`` and
            ``const.SOURCE`` (word pieces merged back into tokens) and the
            sentence-level features under ``const.TARGET_SENTENCE``.
        """
        if include_logits:
            # FIXME: logits for multi-task fine-tuning are not implemented; fail
            #  before running the model rather than after.
            raise NotImplementedError('Logic not implemented for the XLMR encoder.')

        self._check_freezing()

        # Input is a concatenation of both sequences or an interleave
        if self.config.interleave_input:
            merge_input_fn = self.interleave_input
        else:
            merge_input_fn = self.concat_input

        input_ids, token_type_ids, attention_mask = merge_input_fn(
            batch_inputs[const.SOURCE],
            batch_inputs[const.TARGET],
            pad_id=self.vocabs[const.TARGET].pad_id,
        )

        model_output = self.xlm_roberta(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Index positionally: works for plain tuples (older transformers) and
        # for ModelOutput objects (newer transformers), which unpack as keys.
        #   [0] last hidden layer
        #   [1] pooler output (HF pooler applied to the first token)
        #   [2] all hidden states; the first entry is the embedding layer
        last_layer = model_output[0]
        pooled_output = model_output[1]
        encoded_layers = model_output[2]

        features = self.scalar_mix(encoded_layers, attention_mask)
        if self.config.use_mlp:
            features = self.mlp(features)
            pooled_output = self.mlp(pooled_output)

        interleaved = self.config.interleave_input
        split_features = self.split_outputs(
            features, batch_inputs, interleaved=interleaved
        )
        mask = retrieve_tokens_mask(batch_inputs[const.TARGET])

        # Only compute the representation the configured pooling needs.
        pooling = self.config.pooling
        if pooling == 'mixed':
            # Mean of the scalar-mixed target tokens, concatenated with the
            # pooled first token.
            average_pooling = self._masked_mean(split_features[const.TARGET], mask)
            sentence_features = torch.cat((average_pooling, pooled_output), 1)
        elif pooling == 'mean':
            sentence_features = self._masked_mean(split_features[const.TARGET], mask)
        elif pooling == 'll_mean':
            if self.config.use_mlp:
                last_layer = self.mlp(last_layer)
            split_last_layer = self.split_outputs(
                last_layer, batch_inputs, interleaved=interleaved
            )
            sentence_features = self._masked_mean(
                split_last_layer[const.TARGET], mask
            )
        else:  # 'first_token'
            sentence_features = pooled_output

        # Convert pieces to tokens
        return {
            const.TARGET: pieces_to_tokens(
                split_features[const.TARGET], batch_inputs[const.TARGET]
            ),
            const.SOURCE: pieces_to_tokens(
                split_features[const.SOURCE], batch_inputs[const.SOURCE]
            ),
            const.TARGET_SENTENCE: sentence_features,
        }

    @staticmethod
    def concat_input(source_batch, target_batch, pad_id):
        """Concatenate target and source ids along the sequence dimension.

        Each side keeps its own padding, so rows look like ``111002220``.
        ``pad_id`` is unused; it is accepted for signature parity with
        ``interleave_input``.

        Return:
            ``(input_ids, token_type_ids, attention_mask)`` where ``input_ids``
            is ``[target | source]``, ``token_type_ids`` is all zeros, and
            ``attention_mask`` is the concatenation of both token masks.
        """
        input_ids = torch.cat((target_batch.tensor, source_batch.tensor), dim=1)

        # XLMR does not use NSP
        token_type_ids = torch.zeros_like(input_ids)

        attention_mask = torch.cat(
            (retrieve_tokens_mask(target_batch), retrieve_tokens_mask(source_batch)),
            dim=1,
        )
        return input_ids, token_type_ids, attention_mask

    @staticmethod
    def split_outputs(features, batch_inputs, interleaved=False):
        """Split concatenated encoder outputs back into target and source parts.

        Arguments:
            features (tensor): encoder output over the ``concat_input`` layout,
                i.e. the padded target pieces followed by the padded source
                pieces along dim 1 (assumed shape (bs, seq_len, hidden)).
            batch_inputs: mapping whose ``const.TARGET`` batch exposes
                ``.lengths``.
            interleaved (bool): whether the concat strategy was 'interleaved'
                (not supported).

        Return:
            OrderedDict with ``const.TARGET`` and ``const.SOURCE`` tensors.

        Raises:
            NotImplementedError: if ``interleaved`` is True.
        """
        target_lengths = batch_inputs[const.TARGET].lengths

        if interleaved:
            # TODO: use the per-sample target lengths to shift each source
            #  sample back, so interleaved outputs can be split.
            raise NotImplementedError('interleaving not supported.')

        # The target block is padded to the longest target in the batch; the
        # source block starts right after it. Word pieces are kept as-is.
        boundary = target_lengths.max()

        outputs = OrderedDict()
        outputs[const.TARGET] = features[:, :boundary]
        outputs[const.SOURCE] = features[:, boundary:]
        return outputs

    @staticmethod
    def interleave_input(source_batch, target_batch, pad_id):
        """Pack target and source into one tensor without internal padding.

        Row ``i`` holds the first ``target_lengths[i]`` target ids immediately
        followed by the first ``source_lengths[i]`` source ids, then ``pad_id``
        up to ``source_width + target_width`` (``111222000`` rather than
        ``111002220``).

        Return:
            ``(input_ids, token_type_ids, attention_mask)`` where
            ``token_type_ids`` is all zeros and ``attention_mask`` is 1 exactly
            on the copied target/source positions and 0 on padding.
        """
        source_ids = source_batch.tensor
        target_ids = target_batch.tensor

        batch_size = source_ids.size(0)
        max_pair_length = source_ids.size(1) + target_ids.size(1)

        input_ids = torch.full(
            (batch_size, max_pair_length),
            pad_id,
            dtype=torch.long,
            device=source_ids.device,
        )
        # XLMR does not use NSP, so token types stay zero everywhere.
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)

        # One host transfer per lengths tensor instead of one per sample.
        target_lengths = target_batch.lengths.tolist()
        source_lengths = source_batch.lengths.tolist()
        for i, (t_len, s_len) in enumerate(zip(target_lengths, source_lengths)):
            end = t_len + s_len
            input_ids[i, :t_len] = target_ids[i, :t_len]
            input_ids[i, t_len:end] = source_ids[i, :s_len]
            attention_mask[i, :end] = 1

        return input_ids, token_type_ids, attention_mask

    @staticmethod
    def get_mismatch_features(logits, target, pred):
        """Build per-position features comparing gold and predicted ids.

        Args:
            logits: scores whose last dimension is indexed by ids.
            target: gold ids; presumably ``logits``' shape without the last dim.
            pred: predicted ids, same shape as ``target``.

        Returns:
            Tensor with a trailing dimension of 4: the score of the target id,
            the score of the predicted id, their difference, and a float flag
            that is 1.0 where target and prediction differ.
        """
        t_max = torch.gather(logits, -1, target.unsqueeze(-1))
        p_max = torch.gather(logits, -1, pred.unsqueeze(-1))
        diff_max = t_max - p_max
        diff_arg = (target != pred).float().unsqueeze(-1)
        return torch.cat((t_max, p_max, diff_max, diff_arg), dim=-1)