tensorflow/models

View on GitHub
official/nlp/modeling/layers/position_embedding.py

Summary

Maintainability
A
45 mins
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
import math
from typing import Optional

import tensorflow as tf, tf_keras

from official.modeling import tf_utils

# Short alias for the Keras initializer base class, used in type annotations
# in this module (e.g. `RelativePositionBias(embeddings_initializer=...)`).
Initializer = tf_keras.initializers.Initializer


@tf_keras.utils.register_keras_serializable(package="Text")
class PositionEmbedding(tf_keras.layers.Layer):
  """Learned absolute positional embedding (BERT style).

  The layer owns a trainable table of shape `[max_length, width]`, where
  `width` is the size of the input's last axis. When called, it takes the
  first `seq_len` rows of that table and broadcasts them to the full shape of
  the input. `seq_len` is the dynamic size of `seq_axis`. The output varies
  only along `seq_axis` and the last axis.

  Example:
  ```python
  position_embedding = PositionEmbedding(max_length=100)
  inputs = tf_keras.Input((100, 32), dtype=tf.float32)
  outputs = position_embedding(inputs)
  ```

  Args:
    max_length: Number of rows in the embedding table, i.e. the longest
      sequence length this layer can serve. Must not be `None`.
    initializer: Initializer for the embedding table. Defaults to
      "glorot_uniform".
    seq_axis: Axis of the input tensor that indexes sequence positions.

  Reference: This layer creates a positional embedding as described in
  [BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding](https://arxiv.org/abs/1810.04805).
  """

  def __init__(self,
               max_length,
               initializer="glorot_uniform",
               seq_axis=1,
               **kwargs):
    super().__init__(**kwargs)
    if max_length is None:
      raise ValueError(
          "`max_length` must be an Integer, not `None`."
      )
    self._max_length = max_length
    self._initializer = tf_keras.initializers.get(initializer)
    self._seq_axis = seq_axis

  def get_config(self):
    """Returns the base layer config merged with this layer's arguments."""
    base_config = super().get_config()
    layer_config = {
        "max_length": self._max_length,
        "initializer": tf_keras.initializers.serialize(self._initializer),
        "seq_axis": self._seq_axis,
    }
    # Later keys win, so layer-specific entries override base ones.
    return {**base_config, **layer_config}

  def build(self, input_shape):
    """Creates the `[max_length, width]` table; `width` = last input dim."""
    embedding_width = input_shape[-1]
    self._position_embeddings = self.add_weight(
        "embeddings",
        shape=[self._max_length, embedding_width],
        initializer=self._initializer)
    super().build(input_shape)

  def call(self, inputs):
    """Returns position embeddings broadcast to `tf.shape(inputs)`.

    NOTE(review): the sequence length along `seq_axis` is assumed to be
    <= `max_length`. Longer inputs make the reshape below fail, because the
    slice yields only `max_length` rows.
    """
    dynamic_shape = tf.shape(inputs)
    seq_len = dynamic_shape[self._seq_axis]
    table = self._position_embeddings[:seq_len, :]

    # Target shape is all ones except the sequence axis and the feature axis,
    # so `broadcast_to` can replicate the table over every other axis.
    rank = len(inputs.get_shape().as_list())
    target_shape = [1] * rank
    target_shape[self._seq_axis] = seq_len
    target_shape[-1] = table.get_shape().as_list()[-1]

    return tf.broadcast_to(tf.reshape(table, target_shape), dynamic_shape)


@tf_keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf_keras.layers.Layer):
  """Creates a sinusoidal positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths. Defined and formulized
  in "Attention is All You Need", section 3.5
  (https://arxiv.org/abs/1706.03762).

  Args:
    hidden_size: Width of each position encoding. The first
      `hidden_size // 2` columns hold sines and the next `hidden_size // 2`
      hold cosines. When `hidden_size` is odd, the final column is zero.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size: int,
               min_timescale: float = 1.0,
               max_timescale: float = 1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super().__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the
        second dimension of `inputs`.

    Returns:
      A float32 tensor in shape of `(length, hidden_size)`.

    Raises:
      ValueError: If both `inputs` and `length` are `None`, or if they are
        both given and disagree.
    """
    if inputs is None and length is None:
      raise ValueError("If inputs is None, `length` must be set in "
                       "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal to input_shape[1].")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    # Clamp the denominator to >= 1. With a single timescale (hidden_size of 2
    # or 3) the unclamped value is 0: the increment becomes infinite and
    # `0 * -inf` turns every embedding into NaN. For num_timescales >= 2 the
    # clamp is a no-op, so results are unchanged there.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.cast(num_timescales, tf.float32) - 1, 1.0))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    # Outer product: (length, 1) * (1, num_timescales) -> (length, num_ts).
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    if self._hidden_size % 2:
      # The sin/cos halves cover only 2 * (hidden_size // 2) columns. Pad one
      # zero column so the width matches the documented `hidden_size`.
      position_embeddings = tf.pad(position_embeddings, [[0, 0], [0, 1]])
    return position_embeddings


def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Maps relative positions to bucket ids for T5-style relative attention.

  A relative position is `memory_position - query_position`: the signed token
  distance from the attending (query) position to the attended-to (memory)
  position.

  Small distances each get their own bucket. Larger distances share buckets
  whose width grows logarithmically. Every distance of `max_distance` or more
  (in a given direction) lands in the outermost bucket for that direction.
  Sharing buckets at long range is meant to let the model generalize to
  sequences longer than those it was trained on.

  When `bidirectional=False`, positive relative positions are not
  distinguished: they are clamped to distance 0 and share its bucket.

  Args:
    relative_position: An int32 Tensor of relative positions.
    bidirectional: Whether attention is bidirectional. If True, the upper
      half of the buckets is reserved for positive relative positions.
    num_buckets: Total number of buckets.
    max_distance: Distance at and beyond which positions share the last
      bucket.

  Returns:
    An int32 Tensor with the same shape as `relative_position`, holding values
    in `[0, num_buckets)`.
  """
  distance = -relative_position
  bucket_offset = 0
  if bidirectional:
    num_buckets //= 2
    # distance < 0 means memory comes after the query; send those to the
    # upper half of the bucket range.
    bucket_offset = tf.cast(tf.math.less(distance, 0), tf.int32) * num_buckets
    distance = tf.math.abs(distance)
  else:
    distance = tf.math.maximum(distance, 0)

  # From here on `distance` is non-negative.
  max_exact = num_buckets // 2
  # Log-scale position of the distance between `max_exact` and
  # `max_distance`. For distance 0 this is -inf, but such entries are
  # discarded by the `tf.where` below (0 < max_exact).
  log_ratio = tf.math.log(tf.cast(distance, tf.float32) / max_exact)
  log_range = math.log(max_distance / max_exact)
  large_bucket = max_exact + tf.dtypes.cast(
      log_ratio / log_range * (num_buckets - max_exact),
      tf.int32,
  )
  large_bucket = tf.math.minimum(large_bucket, num_buckets - 1)

  is_exact = tf.math.less(distance, max_exact)
  return bucket_offset + tf.where(is_exact, distance, large_bucket)


@tf_keras.utils.register_keras_serializable(package="Text")
class RelativePositionBias(tf_keras.layers.Layer):
  """Relative position embedding via per-head bias in T5 style.

  Reference implementation in MeshTF:
  https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000

  This layer implements the relative position bias used in "Exploring the Limits
  of Transfer Learning with a Unified Text-to-Text Transformer"
  (https://arxiv.org/abs/1910.10683)

  Args:
    num_heads: Number of attention heads; the bias table has one column per
      head.
    relative_attention_num_buckets: Number of relative-position buckets (rows
      of the bias table).
    relative_attention_max_distance: Relative distance at and beyond which
      positions share the outermost bucket.
    bidirectional: Whether queries distinguish memory positions before and
      after them. If False, memory positions at or after the query all map to
      bucket 0.
    embeddings_initializer: Initializer for the bias table. Accepts an
      `Initializer` instance, an identifier string, or a serialized config.
      Defaults to `TruncatedNormal(stddev=1.0)`.
  """

  def __init__(self,
               num_heads: int,
               relative_attention_num_buckets: int = 32,
               relative_attention_max_distance: int = 128,
               bidirectional: bool = True,
               embeddings_initializer: Optional[Initializer] = None,
               **kwargs):
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.bidirectional = bidirectional
    self.relative_attention_max_distance = relative_attention_max_distance
    if embeddings_initializer is None:
      embeddings_initializer = tf_keras.initializers.TruncatedNormal(stddev=1.0)
    # Resolve identifier strings and serialized configs (e.g. the dict that
    # `from_config` passes back in) into an Initializer instance, so that
    # `get_config` can always re-serialize it. Instances pass through as-is.
    self._embed_init = tf_keras.initializers.get(embeddings_initializer)
    with tf.name_scope(self.name):
      # Bias table: one row per bucket, one column per head.
      self._relative_attention_bias = self.add_weight(
          "rel_embedding",
          shape=[self.relative_attention_num_buckets, self.num_heads],
          initializer=self._embed_init,
          dtype=self.dtype,
          trainable=True)

  def get_config(self):
    config = {
        "num_heads":
            self.num_heads,
        "relative_attention_num_buckets":
            self.relative_attention_num_buckets,
        "relative_attention_max_distance":
            self.relative_attention_max_distance,
        "bidirectional":
            self.bidirectional,
        "embeddings_initializer":
            tf_keras.initializers.serialize(self._embed_init),
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, query: tf.Tensor, key: tf.Tensor):
    """Implements the forward pass.

    Args:
      query: query input tensor shape [batch, query length, hidden size].
      key: key input tensor shape [batch, key length, hidden size].

    Returns:
      A tensor in shape of [batch, heads, query length, key length].
    """
    batch_size, qlen = tf_utils.get_shape_list(query)[:2]
    klen = tf_utils.get_shape_list(key)[1]
    context_position = tf.range(qlen)[:, None]
    memory_position = tf.range(klen)[None, :]
    # (qlen, klen) matrix of memory_position - query_position.
    relative_position = memory_position - context_position
    rp_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=self.bidirectional,
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance)
    # (qlen, klen, num_heads) per-head biases.
    values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
    values = tf.expand_dims(
        tf.transpose(values, [2, 0, 1]),
        axis=0)  # shape (1, num_heads, qlen, klen)
    values = tf.tile(values, [batch_size, 1, 1, 1])
    return values