tensorflow/models

View on GitHub
official/nlp/modeling/layers/attention.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import math

import tensorflow as tf, tf_keras

# Module-level shorthands for the tf_keras layer classes. Neither name is
# referenced by the code visible in this part of the file; presumably they are
# kept for code that imports them from this module -- verify before removing.
EinsumDense = tf_keras.layers.EinsumDense
MultiHeadAttention = tf_keras.layers.MultiHeadAttention


@tf_keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf_keras.layers.MultiHeadAttention):
  """Multi-head attention that can reuse cached keys/values while decoding.

  Constructor arguments are the same as for the
  `tf_keras.layers.MultiHeadAttention` layer; `call` additionally accepts
  `cache` and `decode_loop_step`.
  """

  def _update_cache(self, key, value, cache, decode_loop_step):
    """Merges newly projected keys/values into `cache` and returns the result.

    Args:
      key: Projected key tensor produced by the current call.
      value: Projected value tensor produced by the current call.
      cache: Mapping with "key" and "value" entries. Both entries are
        overwritten in place with the merged tensors.
      decode_loop_step: When `None`, the new tensors are concatenated after
        the cached ones along axis 1. Otherwise it is the index along axis 1
        at which the new tensors are added into the fixed-length cache.

    Returns:
      A `(key, value)` tuple with the merged tensors; these are the same
      objects that were stored back into `cache`.
    """

    def _add_at_step(cached, fresh):
      # Builds a one-hot selector over the cache's sequence axis, shaped so it
      # broadcasts against `fresh`, and adds the selected slice to the cache.
      # NOTE(review): this is an addition, not an overwrite, so it presumably
      # relies on the target slot being zero beforehand -- confirm with caller.
      seq_len = cached.shape.as_list()[1]
      selector = tf.one_hot(decode_loop_step, seq_len, dtype=fresh.dtype)
      return cached + fresh * tf.reshape(selector, [1, seq_len, 1, 1])

    if decode_loop_step is None:
      # Growing cache: cast the cached tensors to the dtype of the new ones
      # and append the new ones along the sequence axis.
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
    else:
      # TPU special case: the cache keeps its length and the new step is
      # written at position `decode_loop_step`.
      key = _add_at_step(cache["key"], key)
      value = _add_at_step(cache["value"], value)

    # Persist the merged tensors so the next call sees them.
    cache["key"], cache["value"] = key, value

    return key, value

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_attention_scores=False):
    """Runs attention, optionally extending and using a key/value cache.

    Dimension names used in the comments below:
      B = batch size (number of sequences)
      F = `from_tensor` sequence length
      T = `to_tensor` sequence length
      N = `num_attention_heads`
      H = `size_per_head`

    Args:
      query: Query input tensor.
      value: Value input tensor.
      key: Optional key input tensor; `value` is used when this is `None`.
      attention_mask: Mask forwarded to the masked softmax over the scores.
      cache: Optional mapping with "key" and "value" entries. When truthy it
        is updated in place with the keys/values of this call, and the merged
        tensors are the ones attended over.
      decode_loop_step: Forwarded to `_update_cache`; only used when `cache`
        is truthy.
      return_attention_scores: Whether to also return the attention scores.

    Returns:
      `(attention_output, cache)`, or
      `(attention_output, attention_scores, cache)` when
      `return_attention_scores` is truthy. The returned scores are the ones
      obtained after dropout has been applied.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    key_input = value if key is None else key

    # Per-head projections.
    # `query_proj` = [B, F, N, H]
    query_proj = self._query_dense(query)
    # `key_proj` = [B, T, N, H]
    key_proj = self._key_dense(key_input)
    # `value_proj` = [B, T, N, H]
    value_proj = self._value_dense(value)

    if cache:
      key_proj, value_proj = self._update_cache(
          key_proj, value_proj, cache, decode_loop_step)

    # Scale the queries before the dot product with the keys.
    inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
    query_proj = tf.multiply(query_proj, inverse_sqrt_key_dim)

    # Raw attention scores from the query/key dot product, then normalized to
    # probabilities.
    # `scores` = [B, N, F, T]
    scores = tf.einsum(self._dot_product_equation, key_proj, query_proj)
    scores = self._masked_softmax(scores, attention_mask)

    # Dropout on the probabilities drops entire tokens to attend to; unusual,
    # but it follows the original Transformer paper.
    scores = self._dropout_layer(scores)

    # `context` = [B, F, N, H]
    context = tf.einsum(self._combine_equation, scores, value_proj)
    attention_output = self._output_dense(context)

    if not return_attention_scores:
      return attention_output, cache
    return attention_output, scores, cache