monarch-initiative/N2V

View on GitHub
embiggen/layers/tensorflow/flat_embedding.py

Summary

Maintainability
A
35 mins
Test Coverage
"""Submodule providing flat embedding layer."""
from typing import Dict
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling1D, Layer, Embedding  # pylint: disable=import-error,no-name-in-module


class FlatEmbedding(Layer):
    """Layer wrapping an Embedding followed by a global average pooling.

    The embedded sequence is collapsed with a `GlobalAveragePooling1D`
    layer, so the "flattening" done here is an average over the sequence
    axis and not a reshape.
    """

    def __init__(
        self,
        vocabulary_size: int,
        dimension: int,
        input_length: int,
        mask_zero: bool = False,
        **kwargs: Dict
    ):
        """Create new FlatEmbedding layer.

        Parameters
        ----------------------
        vocabulary_size: int
            The number of elements in the embedding.
        dimension: int
            The dimensionality of the embedding.
        input_length: int
            The expected input length of the embedding.
        mask_zero: bool = False
            Whether to treat zero inputs as if they are a mask.
        **kwargs: Dict,
            Kwargs to pass to the parent Layer class.
        """
        super().__init__(**kwargs)
        # The sub-layers are instantiated in `build`; until then they are None.
        self._flatten_layer = None
        self._embedding_layer = None
        self._vocabulary_size = vocabulary_size
        self._dimension = dimension
        self._input_length = input_length
        self._mask_zero = mask_zero

    def build(self, input_shape) -> None:
        """Build the flat embedding layer.

        Instantiates the wrapped Embedding layer and the pooling layer
        used to collapse the embedded sequence.

        Parameters
        ------------------------------
        input_shape
            Shape of the output of the previous layer.
        """
        # Despite the attribute name, this is an average pooling over the
        # sequence axis and not a `Flatten` layer.
        self._flatten_layer = GlobalAveragePooling1D()
        self._embedding_layer = Embedding(
            input_dim=self._vocabulary_size,
            output_dim=self._dimension,
            input_length=self._input_length,
            mask_zero=self._mask_zero
        )
        super().build(input_shape)

    def call(
        self,
        inputs: tf.Tensor
    ) -> tf.Tensor:
        """Returns called flattened embedding.

        Parameters
        ---------------------------
        inputs: tf.Tensor
            Tensor to feed to the wrapped Embedding layer.
            Presumably integer indices lower than `vocabulary_size`
            with shape (batch, input_length) - TODO confirm with callers.

        Returns
        ---------------------------
        The output of the Embedding layer after it has been passed through
        the global average pooling layer.
        """
        embedded = self._embedding_layer(inputs)
        return self._flatten_layer(embedded)

    def get_config(self) -> Dict:
        """Return the configuration needed to re-create this layer.

        The constructor has required arguments, so they must be reported
        here for Keras to be able to rebuild the layer from its config,
        for instance when saving, loading or cloning a model.

        Returns
        ---------------------------
        Dictionary with the parent Layer configuration plus the
        constructor arguments of this layer.
        """
        config = super().get_config()
        # The keys must match the names of the `__init__` parameters, as
        # the default `from_config` forwards them as keyword arguments.
        config.update({
            "vocabulary_size": self._vocabulary_size,
            "dimension": self._dimension,
            "input_length": self._input_length,
            "mask_zero": self._mask_zero,
        })
        return config