Unbabel/OpenKiwi

kiwi/modules/token_embeddings.py
#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import math

from torch import nn

from kiwi.data.batch import BatchedSentence
from kiwi.modules.common.layer_norm import TFLayerNorm
from kiwi.modules.common.positional_encoding import PositionalEncoding
from kiwi.utils.io import BaseConfig


class TokenEmbeddings(nn.Module):
    class Config(BaseConfig):
        dim: int = 50  # dimensionality of the embedding vectors
        freeze: bool = False  # do not update embedding weights during training
        dropout: float = 0.0  # dropout applied to the output embeddings
        use_position_embeddings: bool = False  # add positional encodings
        max_position_embeddings: int = 4000  # max length for positional encoding
        sparse_embeddings: bool = False  # use sparse gradients for the embeddings
        scale_embeddings: bool = False  # multiply embeddings by sqrt(dim)
        input_layer_norm: bool = False  # layer-normalize the embeddings

    def __init__(self, num_embeddings: int, pad_idx: int, config: Config, vectors=None):
        """A module for embedding a single type of token.

        Arguments:
            num_embeddings: size of the vocabulary.
            pad_idx: index of the padding token.
            config: configuration options for this module.
            vectors: optional tensor of pretrained embeddings with shape
                (num_embeddings, config.dim).
        """
        super().__init__()
        self.pad_idx = pad_idx

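        # Initialize from pretrained vectors when provided; otherwise create a
        # fresh embedding matrix with Xavier-uniform initialization.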
        if vectors is not None:
            assert num_embeddings == vectors.size(0)

            self.embedding = nn.Embedding(
                num_embeddings=num_embeddings,
                embedding_dim=config.dim,
                padding_idx=pad_idx,
                sparse=config.sparse_embeddings,
                _weight=vectors,
            )
        else:
            self.embedding = nn.Embedding(
                num_embeddings=num_embeddings,
                embedding_dim=config.dim,
                padding_idx=pad_idx,
                sparse=config.sparse_embeddings,
            )
            nn.init.xavier_uniform_(self.embedding.weight)

        self._size = config.dim
        self._pe = config.max_position_embeddings

        if config.freeze:
            self.embedding.weight.requires_grad = False

        self.dropout = nn.Dropout(config.dropout)

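        # Optionally scale embeddings by sqrt(dim), following the Transformer
        # convention.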
        self.embeddings_scale_factor = 1
        if config.scale_embeddings:
            self.embeddings_scale_factor = math.sqrt(self._size)

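        # Optional positional encodings, added to the embeddings in forward().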
        self.positional_encoding = None
        if config.use_position_embeddings:
            self.positional_encoding = PositionalEncoding(self._pe, self._size)

        self.layer_norm = None
        if config.input_layer_norm:
            self.layer_norm = TFLayerNorm(self._size)

    @property
    def num_embeddings(self):
        return self.embedding.num_embeddings

    def size(self):
        return self._size

    def forward(self, batch_input, *args):
        """Embed the token ids in `batch_input.tensor`.

        Return a tensor of shape (batch_size, seq_len, dim).
        """
        assert isinstance(batch_input, BatchedSentence)
        ids = batch_input.tensor

        embeddings = self.embedding(ids)
        embeddings = self.embeddings_scale_factor * embeddings

        if self.positional_encoding is not None:
            embeddings = self.positional_encoding(embeddings)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        embeddings = self.dropout(embeddings)

        return embeddings
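

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of building and applying TokenEmbeddings, assuming a
# vocabulary of 10,000 tokens with padding index 0. Note that forward()
# expects a BatchedSentence from kiwi.data.batch and only reads its `tensor`
# attribute; how that object is constructed (its remaining fields) is elided
# here because it depends on the OpenKiwi data pipeline.
#
#     import torch
#
#     config = TokenEmbeddings.Config(
#         dim=64, dropout=0.1, use_position_embeddings=True, scale_embeddings=True
#     )
#     embedder = TokenEmbeddings(num_embeddings=10000, pad_idx=0, config=config)
#
#     token_ids = torch.randint(low=1, high=10000, size=(8, 20))  # (batch, seq_len)
#     batch = ...  # a BatchedSentence whose `tensor` field holds token_ids
#     embeddings = embedder(batch)  # -> shape (8, 20, 64)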