Unbabel/OpenKiwi

kiwi/data/encoders/wmt_qe_data_encoder.py

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Generic, Optional, TypeVar

from pydantic import PositiveInt
from pydantic.class_validators import validator
from pydantic.generics import GenericModel
from typing_extensions import Literal

import kiwi.constants as const
from kiwi.data import tokenizers
from kiwi.data.batch import MultiFieldBatch
from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.data.encoders.base import DataEncoders
from kiwi.data.encoders.field_encoders import (
    AlignmentEncoder,
    BinaryScoreEncoder,
    ScoreEncoder,
    TagEncoder,
    TextEncoder,
)
from kiwi.utils.io import (
    BaseConfig,
    load_torch_file,
    target_gaps_to_gaps,
    target_gaps_to_target,
)

logger = logging.getLogger(__name__)

T = TypeVar('T')


class InputFields(GenericModel, Generic[T]):
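    """Holds one value per input field (source and target)."""
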
    source: T
    target: T


class EmbeddingsConfig(BaseConfig):
    """Paths to word embeddings file for each input field."""

    source: Optional[Path]
    target: Optional[Path]
    post_edit: Optional[Path] = None
    source_pos: Optional[Path] = None
    target_pos: Optional[Path] = None

    format: Literal['polyglot', 'word2vec', 'fasttext', 'glove', 'text'] = 'polyglot'
    """Word embeddings format. See README for specific formatting instructions."""


class VocabularyConfig(BaseConfig):
    min_frequency: InputFields[PositiveInt] = 1
    """Only add to vocabulary words that occur more than this number of times in the
    training dataset (doesn't apply to loaded or pretrained vocabularies)."""

    max_size: InputFields[Optional[PositiveInt]] = None
    """Only create vocabulary with up to this many words (doesn't apply to loaded or
    pretrained vocabularies)."""

    keep_rare_words_with_embeddings: bool = False
    """Keep words that occur less than ``min_frequency`` but are in the embeddings
    vocabulary."""

    add_embeddings_vocab: bool = False
    """Add words from embeddings vocabulary to source/target vocabulary."""

    # TODO: add this feature
    # source_max_length: PositiveInt = float("inf")
    # 'Maximum source sequence length'
    # source_min_length: int = 1
    # 'Truncate source sequence length.'
    # target_max_length: PositiveInt = float("inf")
    # 'Maximum target sequence length to keep.'
    # target_min_length: int = 1
    # 'Truncate target sequence length.'

    @validator('min_frequency', 'max_size', pre=True, always=True)
    def check_nested_options(cls, v):
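        """Broadcast a scalar (or None) value to both source and target fields."""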
        if isinstance(v, int) or v is None:
            return {'source': v, 'target': v}
        return v


class WMTQEDataEncoder(DataEncoders):
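    """Encode WMT-format QE data into batched tensors for training and inference.

    Sets up text encoders for the source, target, and post-edited sentences,
    tag encoders for the word-level outputs, score encoders for the
    sentence-level outputs, and an alignment encoder.

    Usage sketch (``train_dataset`` and ``samples`` are hypothetical names for
    an already loaded ``WMTQEDataset`` and a list of its sample dicts)::

        encoder = WMTQEDataEncoder(WMTQEDataEncoder.Config())
        encoder.fit_vocabularies(train_dataset)
        batch = encoder.collate_fn(samples)
    """
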
    class Config(BaseConfig):
        share_input_fields_encoders: bool = False
        """Share encoding/vocabs between source and target fields."""

        vocab: VocabularyConfig = VocabularyConfig()
        embeddings: Optional[EmbeddingsConfig] = None

        @validator('embeddings', pre=True, always=False)
        def warn_missing_feature(cls, v):
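            """Reject embeddings configuration: not yet supported in this version."""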
            if v is not None:
                raise ValueError('Embeddings are not yet supported in the new version.')
            return None

    def __init__(
        self, config: Config, field_encoders: Optional[Dict[str, TextEncoder]] = None
    ):
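        """Build field encoders for the input texts, word-level tags,
        sentence-level scores, and alignments.

        Pre-built text encoders for the source, target, and post-edit fields
        may be passed in ``field_encoders``; any missing ones are created here.
        """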
        super().__init__()
        self._vocabularies_initialized = False
        self.config = config

        self.field_encoders = {}

        if field_encoders is None:
            field_encoders = {}

        # Inputs
        if const.SOURCE in field_encoders and const.TARGET in field_encoders:
            self.field_encoders[const.SOURCE] = field_encoders[const.SOURCE]
            self.field_encoders[const.TARGET] = field_encoders[const.TARGET]
        elif const.SOURCE not in field_encoders and const.TARGET not in field_encoders:
            self.field_encoders[const.TARGET] = TextEncoder()
            if config.share_input_fields_encoders:
                self.field_encoders[const.SOURCE] = self.field_encoders[const.TARGET]
            else:
                self.field_encoders[const.SOURCE] = TextEncoder()
        elif const.TARGET in field_encoders:
            self.field_encoders[const.TARGET] = field_encoders[const.TARGET]
            if config.share_input_fields_encoders:
                self.field_encoders[const.SOURCE] = self.field_encoders[const.TARGET]
            else:
                raise ValueError(
                    f'QE requires encoders for {const.SOURCE} and {const.TARGET} '
                    f'fields, but only {const.TARGET} was provided and share vocab '
                    f'is false.'
                )
        elif const.SOURCE in field_encoders:
            self.field_encoders[const.SOURCE] = field_encoders[const.SOURCE]
            if config.share_input_fields_encoders:
                self.field_encoders[const.TARGET] = self.field_encoders[const.SOURCE]
            else:
                raise ValueError(
                    f'QE requires encoders for {const.SOURCE} and {const.TARGET} '
                    f'fields, but only {const.SOURCE} was provided and share vocab '
                    f'is false.'
                )
        # TLM multi-task fine-tuning output
        if const.PE in field_encoders:
            self.field_encoders[const.PE] = field_encoders[const.PE]
        else:
            self.field_encoders[const.PE] = self.field_encoders[const.TARGET]

        # Word level outputs
        logger.debug('Assuming WMT18 format for target tags outputs')
        self.field_encoders[const.TARGET_TAGS] = TagEncoder(
            tokenize=lambda x: target_gaps_to_target(tokenizers.tokenize(x))
        )
        self.field_encoders[const.GAP_TAGS] = TagEncoder(
            tokenize=lambda x: target_gaps_to_gaps(tokenizers.tokenize(x))
        )
        self.field_encoders[const.SOURCE_TAGS] = TagEncoder()

        # Sentence level outputs
        # if const.SENTENCE_SCORES in data_source:
        self.field_encoders[const.SENTENCE_SCORES] = ScoreEncoder()
        self.field_encoders[const.BINARY] = BinaryScoreEncoder()

        # Optional fields
        # For optional fields, check whether they exist in data_sources
        # if const.TARGET_POS in data_source:
        self.field_encoders[const.ALIGNMENTS] = AlignmentEncoder()

    def fit_vocabularies(self, dataset: WMTQEDataset):
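        """Fit vocabularies on ``dataset`` for input texts and word-level tags.

        Encoders that already have a vocabulary are left untouched. Also infers
        whether target tags follow the WMT18 format (interleaved gap and word
        tags) or the older per-token format.
        """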
        # Inputs
        if self.config.share_input_fields_encoders:
            target_samples = dataset[const.TARGET] + dataset[const.SOURCE]
            source_samples = None
        else:
            target_samples = dataset[const.TARGET]
            source_samples = dataset[const.SOURCE]

        if self.field_encoders[const.TARGET].vocab is not None:
            logger.warning(
                f'Vocabulary for {const.TARGET} already exists; '
                f'not going to fit it to data'
            )
        else:
            self.field_encoders[const.TARGET].fit_vocab(
                target_samples,
                vocab_size=self.config.vocab.max_size.target,
                vocab_min_freq=self.config.vocab.min_frequency.target,
            )
        if source_samples:
            if self.field_encoders[const.SOURCE].vocab is not None:
                logger.warning(
                    f'Vocabulary for {const.SOURCE} already exists; '
                    f'not going to fit it to data'
                )
            else:
                self.field_encoders[const.SOURCE].fit_vocab(
                    source_samples,
                    vocab_size=self.config.vocab.max_size.source,
                    vocab_min_freq=self.config.vocab.min_frequency.source,
                )

        # Outputs
        for tag_output in [const.TARGET_TAGS, const.SOURCE_TAGS]:
            if tag_output in dataset:
                if self.field_encoders[tag_output].vocab is not None:
                    logger.warning(
                        f'Vocabulary for {tag_output} already exists; '
                        f'not going to fit it to data'
                    )
                else:
                    self.field_encoders[tag_output].fit_vocab(dataset[tag_output])
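        # Gap tags use the same label set as target tags, so share the vocabulary.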
        self.field_encoders[const.GAP_TAGS].vocab = self.field_encoders[
            const.TARGET_TAGS
        ].vocab

        if const.TARGET_TAGS in dataset:
            # With data, we can check whether we really have data in WMT18 format
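            # The WMT18 tags file interleaves gap and word tags (2n + 1 tags for
            # n target tokens); the target-tags encoder above keeps only the word
            # tags. If the file is in the older format (one tag per token), that
            # tokenizer yields roughly half as many tags as target tokens, which
            # is what the check below detects.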
            target_sample = dataset[const.TARGET][0]
            _, target_bounds, *_ = self.field_encoders[const.TARGET].encode(
                target_sample
            )
            extras = 1 if self.field_encoders[const.TARGET].bos_token else 0
            extras += 1 if self.field_encoders[const.TARGET].eos_token else 0
            target_sample_length = len(target_bounds) - extras

            tag_sample = dataset[const.TARGET_TAGS][0]
            tags, *_ = self.field_encoders[const.TARGET_TAGS].encode(tag_sample)
            if len(tags) == target_sample_length // 2:
                # Pre WMT18 format (old format; gap tags in a separate file)
                logger.debug('Inferred old WMT17 format for tags outputs')
                self.field_encoders[const.TARGET_TAGS] = TagEncoder()
                # Fit again
                self.field_encoders[const.TARGET_TAGS].fit_vocab(
                    dataset[const.TARGET_TAGS]
                )
                self.field_encoders[const.GAP_TAGS] = self.field_encoders[
                    const.TARGET_TAGS
                ]

        self._vocabularies_initialized = True

    def load_vocabularies(
        self, load_vocabs_from: Optional[Path] = None, overwrite: bool = False
    ):
        """Load serialized Vocabularies from disk into fields."""
        logger.info(f'Loading vocabularies from: {load_vocabs_from}')
        vocabs_dict = load_torch_file(load_vocabs_from)
        if const.VOCAB not in vocabs_dict:
            raise KeyError(f'File {load_vocabs_from} has no {const.VOCAB}')

        return self.vocabularies_from_dict(vocabs_dict[const.VOCAB], overwrite)

    def vocabularies_from_dict(self, vocabs_dict: Dict, overwrite: bool = False):
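        """Load vocabularies from ``vocabs_dict`` into the field encoders.

        Unknown fields are skipped with a warning; existing vocabularies are
        kept unless ``overwrite`` is True. Returns the resulting vocabularies.
        """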
        # Inputs
        input_fields = [const.TARGET, const.SOURCE, const.PE]
        output_fields = [const.TARGET_TAGS, const.SOURCE_TAGS, const.GAP_TAGS]

        extraneous_fields = [
            field for field in vocabs_dict if field not in input_fields + output_fields
        ]
        if extraneous_fields:
            logger.warning(
                f'Unknown vocabularies will not be loaded: {extraneous_fields}'
            )

        fields_for_loading = [
            field for field in vocabs_dict if field in input_fields + output_fields
        ]
        for field in fields_for_loading:
            if self.field_encoders[field].vocab is not None and not overwrite:
                logger.warning(
                    f'Vocabulary for {field} already exists; not loading it again'
                )
            else:
                self.field_encoders[field].vocab = vocabs_dict[field]

        self._vocabularies_initialized = True
        return self.vocabularies

    @property
    def vocabularies(self):
        """Return the vocabularies for all encoders that have one.

        Return:
            A dict mapping encoder names to Vocabulary instances.
        """
        if not self._vocabularies_initialized:
            return None

        vocabs = {}
        for name, encoder in self.field_encoders.items():
            if encoder is not None and encoder.vocabulary:
                vocabs[name] = encoder.vocabulary
        return vocabs

    def collate_fn(self, samples, device=None):
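        """Collate a list of sample dicts into a ``MultiFieldBatch``.

        Each column is batch-encoded with its corresponding field encoder;
        vocabularies must have been fitted or loaded beforehand.
        """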
        if not self._vocabularies_initialized:
            raise ValueError(
                f'Vocabularies have not been initialized; need to first call '
                f'``{self.__class__.__name__}'
                f'.fit_vocabularies({WMTQEDataset.__name__})`` or '
                f'``{self.__class__.__name__}.load_vocabularies(Path)``'
            )

        columns_data = defaultdict(list)
        for sample in samples:
            for column, value in sample.items():
                # data = self.data_encoders[column].encode(value)
                columns_data[column].append(value)
            # FIXME: hack to have gap_tags in the batch
            if const.TARGET_TAGS in sample and const.GAP_TAGS not in sample:
                columns_data[const.GAP_TAGS].append(sample[const.TARGET_TAGS])
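            # The binary sentence-level column is filled from the sentence scores
            # and encoded by the BinaryScoreEncoder when the batch is built.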
            if const.SENTENCE_SCORES in sample and const.BINARY in self.field_encoders:
                columns_data[const.BINARY].append(sample[const.SENTENCE_SCORES])

        batch = {
            column: self.field_encoders[column].batch_encode(examples)
            for column, examples in columns_data.items()
        }

        return MultiFieldBatch(batch)