tensorflow/models

View on GitHub
official/nlp/modeling/networks/albert_encoder.py

Summary

Maintainability
A
3 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ALBERT (https://arxiv.org/abs/1909.11942) text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf, tf_keras

from official.modeling import activations
from official.modeling import tf_utils
from official.nlp.modeling import layers


@tf_keras.utils.register_keras_serializable(package='Text')
class AlbertEncoder(tf_keras.Model):
  """ALBERT (https://arxiv.org/abs/1909.11942) text encoder network.

  This network implements the encoder described in the paper "ALBERT: A Lite
  BERT for Self-supervised Learning of Language Representations"
  (https://arxiv.org/abs/1909.11942).

  Compared with BERT (https://arxiv.org/abs/1810.04805), ALBERT refactorizes
  embedding parameters into two smaller matrices and shares parameters
  across layers.

  The default values for this object are taken from the ALBERT-Base
  implementation described in the paper.

  *Note* that the network is constructed by Keras Functional API.

  Inputs, in order: `input_word_ids`, `input_mask` and `input_type_ids`, all
  int32 Keras inputs declared with per-example shape `(None,)`, i.e.
  `(batch_size, seq_length)` once batched.

  Outputs: when `dict_outputs` is False, the list
  `[sequence_output, pooled_output]`; otherwise a dict with keys
  `sequence_output`, `encoder_outputs` (one entry per application of the
  shared transformer block, `num_layers` entries in total) and
  `pooled_output` (a tanh-activated dense projection of the first token of
  `sequence_output`).

  Args:
    vocab_size: The size of the token vocabulary.
    embedding_width: The width of the word embeddings. If the embedding width is
      not equal to hidden size, embedding parameters will be factorized into two
      matrices in the shape of `(vocab_size, embedding_width)` and
      `(embedding_width, hidden_size)`, where `embedding_width` is usually much
      smaller than `hidden_size`. If None, `hidden_size` is used.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers. All layers share the same
      weights.
    num_attention_heads: The number of attention heads for each transformer. The
      hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, max_sequence_length uses the value from sequence length.
      This determines the variable shape for positional embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    activation: The activation to use for the transformer layers.
    dropout_rate: The dropout rate to use for the transformer layers.
    attention_dropout_rate: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
    dict_outputs: Whether to use a dictionary as the model outputs. This value
      is recorded by `get_config()`, so `from_config()` rebuilds an encoder
      with the same output structure.
    **kwargs: Additional keyword arguments forwarded to
      `tf_keras.Model.__init__` (e.g. `name`). These are not recorded by
      `get_config()`.
  """

  def __init__(self,
               vocab_size,
               embedding_width=128,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_sequence_length=512,
               type_vocab_size=16,
               intermediate_size=3072,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02),
               dict_outputs=False,
               **kwargs):
    # Accept either identifiers (e.g. 'relu') or objects for these arguments.
    activation = tf_keras.activations.get(activation)
    initializer = tf_keras.initializers.get(initializer)

    word_ids = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    # Without a separate embedding width there is no factorization.
    if embedding_width is None:
      embedding_width = hidden_size
    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=tf_utils.clone_initializer(initializer),
        name='word_embeddings')
    word_embeddings = embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    # NOTE(review): the docstring says `max_sequence_length=None` is allowed;
    # confirm that `layers.PositionEmbedding` accepts `max_length=None`.
    position_embedding_layer = layers.PositionEmbedding(
        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embeddings = (
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(initializer),
            use_one_hot=True,
            name='type_embeddings')(type_ids))

    embeddings = tf_keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    # Layer norm is pinned to float32 regardless of any mixed-precision policy.
    embeddings = (
        tf_keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
    embeddings = (tf_keras.layers.Dropout(rate=dropout_rate)(embeddings))
    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'. This is the second matrix of the ALBERT factorization.
    if embedding_width != hidden_size:
      embeddings = tf_keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')(
              embeddings)

    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)
    # ALBERT's cross-layer parameter sharing: a single block instance is
    # applied `num_layers` times, so all "layers" share the same weights.
    shared_layer = layers.TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=intermediate_size,
        inner_activation=activation,
        output_dropout=dropout_rate,
        attention_dropout=attention_dropout_rate,
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='transformer')
    encoder_outputs = []
    for _ in range(num_layers):
      data = shared_layer([data, attention_mask])
      encoder_outputs.append(data)

    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = data[:, 0, :]
    cls_output = tf_keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')(
            first_token_tensor)
    if dict_outputs:
      outputs = dict(
          sequence_output=data,
          encoder_outputs=encoder_outputs,
          pooled_output=cls_output,
      )
    else:
      outputs = [data, cls_output]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super().__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
    config_dict = {
        'vocab_size': vocab_size,
        'embedding_width': embedding_width,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'activation': tf_keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf_keras.initializers.serialize(initializer),
        # Recorded so that from_config(get_config()) reproduces the same
        # output structure (dict vs. list); previously this was dropped.
        'dict_outputs': dict_outputs,
    }

    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._embedding_layer = embedding_layer
    self._position_embedding_layer = position_embedding_layer

  def get_embedding_table(self):
    """Returns the `embeddings` weight of the word-embedding layer.

    Presumably shaped `(vocab_size, embedding_width)`, per the factorization
    described in the class docstring; confirm against `OnDeviceEmbedding`.
    """
    return self._embedding_layer.embeddings

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this encoder.

    `activation` and `initializer` are in their Keras-serialized form. Extra
    `**kwargs` passed to the constructor (such as `name`) are not included.
    """
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config):
    """Creates an encoder from a dictionary produced by `get_config()`."""
    return cls(**config)