tensorflow/models
official/projects/roformer/roformer_attention.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Roformer attention layer."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
import tf_keras

EinsumDense = tf_keras.layers.EinsumDense
MultiHeadAttention = tf_keras.layers.MultiHeadAttention


def _build_trig_vector(length, key_dim):
  """Builds the sin/cos position vectors used for rotary embedding."""
  tf_dtype = tf_keras.mixed_precision.global_policy().compute_dtype
  position_ids = tf.cast(tf.range(length), dtype=tf_dtype)
  position_ids = tf.expand_dims(position_ids, axis=0)
  steps = key_dim // 2
  # The paper's wavenumber exponent is -2(i - 1) / key_dim for 1-indexed i;
  # with zero-indexed i and steps = key_dim // 2 this simplifies to -i / steps.
  wavenumber_exponent = -tf.cast(tf.range(steps), dtype=tf_dtype) / steps
  wavenumbers = tf.pow(
      tf.constant(10000.0, dtype=tf_dtype), wavenumber_exponent
  )
  vec = tf.einsum('bl,d->bld', position_ids, wavenumbers)
  # Repeat each frequency so the (even, odd) channel pairs rotated in
  # roformer_recompute_qkv() share an angle, then add a singleton axis that
  # broadcasts over attention heads: both outputs are [1, length, 1, key_dim].
  sin_vec = tf.repeat(tf.sin(vec), repeats=2, axis=-1)
  cos_vec = tf.repeat(tf.cos(vec), repeats=2, axis=-1)
  sin_vec, cos_vec = tf.expand_dims(sin_vec, 2), tf.expand_dims(cos_vec, 2)
  return sin_vec, cos_vec
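
# Shape sketch (an illustrative aside, not part of the upstream module):
#
#   sin_vec, cos_vec = _build_trig_vector(length=4, key_dim=8)
#   # sin_vec.shape == cos_vec.shape == [1, 4, 1, 8]
#   # axis 0 broadcasts over the batch; axis 2 broadcasts over heads.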


@tf_keras.utils.register_keras_serializable(package='Text')
class RoformerAttention(MultiHeadAttention):
  """Roformer Attention."""

  def __init__(self,
               q_max_sequence_length,
               kv_max_sequence_length,
               output_range=None,
               **kwargs):
    """Instantiates a roformer attention layer.

    Roformer paper: https://arxiv.org/abs/2104.09864

    Args:
      q_max_sequence_length: maximum length in input for the query
      kv_max_sequence_length: maximum length in input for key and value, can be
        different from q_max_sequence_length
      output_range: length of the query tensor to consider.
      **kwargs: other keyword arguments.
    """
    super().__init__(**kwargs)
    self._q_max_sequence_length = q_max_sequence_length
    self._kv_max_sequence_length = kv_max_sequence_length
    assert self._key_dim % 2 == 0, 'key_dim must be even to rotate in pairs.'
    q_sin_vec, q_cos_vec = _build_trig_vector(self._q_max_sequence_length,
                                              self._key_dim)
    k_sin_vec, k_cos_vec = _build_trig_vector(self._kv_max_sequence_length,
                                              self._key_dim)
    if output_range is None:
      self.q_sin_vec, self.q_cos_vec = q_sin_vec, q_cos_vec
    else:
      self.q_sin_vec = q_sin_vec[:, 0:output_range, ...]
      self.q_cos_vec = q_cos_vec[:, 0:output_range, ...]
    self.k_sin_vec, self.k_cos_vec = k_sin_vec, k_cos_vec
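    # Note (descriptive aside): the cached sin/cos tensors have shape
    # [1, max_sequence_length, 1, key_dim]; the query pair is pre-sliced to
    # output_range positions when output_range is set.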

  def roformer_recompute_qkv(self, q, k, v):
    """Applies rotary position embedding to the projected queries and keys."""
    q_shape = tf.shape(q)
    q_len = q_shape[1]
    k_shape = tf.shape(k)
    k_len = k_shape[1]

    # Interleave each (even, odd) channel pair (x[2i], x[2i+1]) into
    # (-x[2i+1], x[2i]) so that the sums below compute, for position m,
    #   ret[2i]     = x[2i] * cos(m * theta_i) - x[2i + 1] * sin(m * theta_i)
    #   ret[2i + 1] = x[2i + 1] * cos(m * theta_i) + x[2i] * sin(m * theta_i)
    # i.e. a rotation by angle m * theta_i in each two-channel plane.
    q2 = tf.stack([-q[..., 1::2], q[..., ::2]], axis=4)
    q2 = tf.reshape(q2, q_shape)
    k2 = tf.stack([-k[..., 1::2], k[..., ::2]], axis=4)
    k2 = tf.reshape(k2, k_shape)
    ret_q = (q * self.q_cos_vec[:, 0:q_len, ...] +
             q2 * self.q_sin_vec[:, 0:q_len, ...])
    ret_k = (k * self.k_cos_vec[:, 0:k_len, ...] +
             k2 * self.k_sin_vec[:, 0:k_len, ...])
    return ret_q, ret_k, v

  def call(self,  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
           query,
           value,
           key=None,
           attention_mask=None,
           return_attention_scores=False,
           training=None):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    query = self._query_dense(query)
    key = self._key_dense(key)
    value = self._value_dense(value)

    # Rotate queries and keys by their absolute positions; the resulting
    # attention logits then depend only on relative positions.
    query, key, value = self.roformer_recompute_qkv(query, key, value)

    attention_output, attention_scores = self._compute_attention(
        query, key, value, attention_mask, training)
    attention_output = self._output_dense(attention_output)

    if return_attention_scores:
      return attention_output, attention_scores
    return attention_output
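

# A minimal smoke test (a hedged sketch, not part of the upstream file): the
# layer sizes and batch shape below are illustrative assumptions, not values
# required by the Roformer paper or this module.
if __name__ == '__main__':
  layer = RoformerAttention(
      q_max_sequence_length=16,
      kv_max_sequence_length=16,
      num_heads=4,
      key_dim=8)  # key_dim must be even; see the assert in __init__.
  query = tf.random.normal([2, 16, 32])
  value = tf.random.normal([2, 16, 32])
  output = layer(query=query, value=value)
  print(output.shape)  # Expected: (2, 16, 32)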