tensorflow/models

View on GitHub
official/nlp/modeling/layers/multi_query_attention.py

Summary

Maintainability
A
35 mins
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based attention layers to support multi-query attention.

Based on https://arxiv.org/pdf/1911.02150.pdf and
https://arxiv.org/pdf/2305.13245.pdf.
"""

import string
from typing import Optional, Sequence, Union

import tensorflow as tf, tf_keras

_CHR_IDX = string.ascii_lowercase


def _build_proj_equation(
    free_dims: int, bound_dims: int, output_dims: int
) -> ...:
  """Builds an einsum equation for projections inside attention layer.

  Args:
    free_dims: The number of free dimensions which are copied from input to
      output.
    bound_dims: The number of bound dimensions part of input which are combined
      with the kernel to produce output.
    output_dims: The number of output dimensions.

  Returns:
    A tuple of einsum equation, bias axes and output rank.
  """

  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = f"{input_str},{kernel_str}->{output_str}"

  return equation, bias_axes, len(output_str)


def _get_output_shape(
    output_rank: int, known_last_dims: Sequence[int]
) -> list[Optional[int]]:
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
  """Multi-query attention layer."""

  def __init__(self, num_kv_heads=None, **kwargs):
    # num_kv_heads defines the number of key/value heads. A value of 1 means
    # that the key/value heads are shared across all query heads. Any other
    # value must be less than num_heads and must divide num_heads exactly. If
    # num_kv_heads is greater than 1, query heads are split into groups of
    # num_kv_heads.
    super().__init__(**kwargs)
    self._num_kv_heads = num_kv_heads or self._num_heads
    assert (
        self._num_kv_heads < self._num_heads
    ), "num_kv_heads must be less than num_heads."
    assert (
        self._num_heads % self._num_kv_heads == 0
    ), "num_kv_heads needs to divide num_heads exactly."

  def _build_from_signature(
      self,
      query: Union[tf.Tensor, tf.TensorShape],
      value: Union[tf.Tensor, tf.TensorShape],
      key: Optional[Union[tf.Tensor, tf.TensorShape]] = None,
  ):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to
    True.

    Args:
        query: Query tensor or TensorShape.
        value: Value tensor or TensorShape.
        key: Key tensor or TensorShape.
    """
    # pytype: disable=attribute-error
    super()._build_from_signature(query=query, value=value, key=key)
    # pytype: enable=attribute-error

    with tf.init_scope():
      # Key, value are shared across heads in multi-query attention.
      # Overwrite the K, V projections, logits & attend einsum equations to
      # remove the number of attention head dimension in K, V related tensors.
      #
      # The following capital letters are used to denote the tensor dimension
      # parameters:
      # B = batch size
      # S = length of the key/value (source)
      # T = length of the query (target)
      # N = number of query attention heads
      # K = number of key/value heads
      # n = N // K
      # H = dimensions of each attention head.
      #
      if self._num_kv_heads == 1:
        output_dims = 1
        key_last_dims = [self._key_dim]
        value_last_dims = [self._value_dim]
        self._dot_product_equation = "...SH,...TNH->...NTS"
        self._combine_equation = "...NTS,...SH->...TNH"
      else:
        output_dims = 2
        key_last_dims = [self._num_kv_heads, self._key_dim]
        value_last_dims = [self._num_kv_heads, self._value_dim]
        self._dot_product_equation = "...SKH,...TKnH->...nKTS"
        self._combine_equation = "...nKTS,...SKH->...TnKH"

      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims=self._key_shape.rank - 1,
          bound_dims=1,
          output_dims=output_dims,
      )
      self._key_dense = tf_keras.layers.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1, key_last_dims),
          bias_axes=bias_axes if self._use_bias else None,
          name="key",
          **self._get_common_kwargs_for_sublayer(),
      )
      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          free_dims=self._value_shape.rank - 1,
          bound_dims=1,
          output_dims=output_dims,
      )
      self._value_dense = tf_keras.layers.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1, value_last_dims),
          bias_axes=bias_axes if self._use_bias else None,
          name="value",
          **self._get_common_kwargs_for_sublayer(),
      )

  def _compute_attention(
      self, query, key, value, attention_mask=None, training=None
  ):
    if self._num_kv_heads > 1:
      query = tf.reshape(
          query,
          [
              tf.shape(query)[0],
              tf.shape(query)[1],
              self._num_kv_heads,
              self._num_heads // self._num_kv_heads,
              tf.shape(query)[-1],
          ],
      )

    # pytype: disable=attribute-error
    attention_output, attention_scores = super()._compute_attention(
        query, key, value, attention_mask=attention_mask, training=training
    )
    # pytype: enable=attribute-error
    if self._num_kv_heads != 1:
      attention_output = tf.reshape(
          attention_output,
          [
              tf.shape(attention_output)[0],
              tf.shape(attention_output)[1],
              self._num_heads,
              tf.shape(attention_output)[-1],
          ],
      )
      attention_scores = tf.reshape(
          attention_scores,
          [
              tf.shape(attention_scores)[0],
              self._num_heads,
              tf.shape(attention_scores)[-2],
              tf.shape(attention_scores)[-1],
          ],
      )
    return attention_output, attention_scores