tensorflow/models

View on GitHub
official/nlp/modeling/layers/block_sparse_attention_test.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for block sparse attention layer."""

import math

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from official.nlp.modeling.layers import block_sparse_attention


class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ("key_value_same_proj", None, None, [40, 80]),
      ("key_value_different_proj", 32, 60, [40, 60]),
  )
  def test_non_masked_attention(self, value_dim, output_shape, output_dims):
    """Checks the symbolic output shape of cross-attention without a mask.

    Args:
      value_dim: `value_dim` forwarded to the layer constructor.
      output_shape: `output_shape` forwarded to the layer constructor.
      output_dims: Expected non-batch dimensions of the layer output, as a
        list.
    """
    layer_kwargs = dict(
        num_heads=12,
        key_dim=64,
        value_dim=value_dim,
        output_shape=output_shape,
        src_block_size=10,
        tgt_block_size=5,
    )
    attention = block_sparse_attention.MultiHeadAttention(**layer_kwargs)

    # Symbolic rank-3 inputs; the leading batch dimension is implicit.
    query_in = tf_keras.Input(shape=(40, 80))
    value_in = tf_keras.Input(shape=(20, 80))
    result = attention(query=query_in, value=value_in)

    expected_shape = [None, *output_dims]
    self.assertEqual(result.shape.as_list(), expected_shape)

  def test_non_masked_self_attention(self):
    """Checks the output shape for self-attention with no mask tensor.

    The same symbolic tensor is passed as both query and value.
    """
    attention = block_sparse_attention.MultiHeadAttention(
        src_block_size=10, tgt_block_size=10, num_heads=12, key_dim=64
    )
    # One symbolic rank-3 input; the leading batch dimension is implicit.
    sequence = tf_keras.Input(shape=(40, 80))
    result = attention(sequence, sequence)

    expected_shape = [None, 40, 80]
    self.assertEqual(result.shape.as_list(), expected_shape)

  @parameterized.named_parameters(
      ("with_bias", True),
      ("no_bias", False),
      ("with_sigmoid_attn", True, True),
  )
  def test_masked_attention(
      self,
      use_bias,
      use_sigmoid_attn=False,
  ):
    """Checks that supplying an attention mask changes the layer output.

    The layer is evaluated with a random 0/1 mask and with an all-ones mask,
    first with (query, value) inputs and then with an explicit key input; the
    two predictions must differ in both setups. Finally the number of
    trainable variables in the query and output projections is compared with
    what `use_bias` implies.

    Args:
      use_bias: Forwarded to the layer; two trainable variables per checked
        projection are expected when true, one otherwise.
      use_sigmoid_attn: Forwarded to the layer; when true, a sigmoid attention
        bias of `-log(2)` is passed as well.
    """
    attn_bias = -math.log(2) if use_sigmoid_attn else None
    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=4,
        key_dim=2,
        use_bias=use_bias,
        src_block_size=2,
        tgt_block_size=1,
        use_sigmoid_attn=use_sigmoid_attn,
        sigmoid_attn_bias=attn_bias,
    )

    # Symbolic rank-3 inputs; the leading batch dimension is implicit.
    num_examples = 3
    query_in = tf_keras.Input(shape=(4, 8))
    value_in = tf_keras.Input(shape=(2, 8))
    mask_in = tf_keras.Input(shape=(4, 2))
    symbolic_out = layer(
        query=query_in, value=value_in, attention_mask=mask_in
    )

    # Wrap the layer in a functional model so it can be run on numpy data.
    two_input_model = tf_keras.Model(
        [query_in, value_in, mask_in], symbolic_out
    )

    # Random data for the non-mask inputs.
    query_data = 10 * np.random.random_sample((num_examples, 4, 8))
    value_data = 10 * np.random.random_sample((num_examples, 2, 8))

    # A random 0/1 mask; it is expected to mask at least one position.
    random_mask = np.random.randint(2, size=(num_examples, 4, 2))
    masked_out = two_input_model.predict(
        [query_data, value_data, random_mask]
    )

    # An all-ones mask leaves every position unmasked.
    all_ones_mask = np.ones((num_examples, 4, 2))
    unmasked_out = two_input_model.predict(
        [query_data, value_data, all_ones_mask]
    )

    # One run is masked and the other is not, so the results must differ.
    self.assertNotAllClose(masked_out, unmasked_out)

    # Repeat the comparison with three inputs: query, value and key.
    key_in = tf_keras.Input(shape=(2, 8))
    symbolic_out = layer(
        query_in, value=value_in, key=key_in, attention_mask=mask_in
    )
    three_input_model = tf_keras.Model(
        [query_in, value_in, key_in, mask_in], symbolic_out
    )

    # The value data doubles as the key data.
    masked_out = three_input_model.predict(
        [query_data, value_data, value_data, random_mask]
    )
    unmasked_out = three_input_model.predict(
        [query_data, value_data, value_data, all_ones_mask]
    )
    self.assertNotAllClose(masked_out, unmasked_out)

    # Each checked projection holds two trainable variables with a bias and
    # one without.
    expected_var_count = 2 if use_bias else 1
    for projection in (layer._query_dense, layer._output_dense):
      self.assertLen(projection.trainable_variables, expected_var_count)

  @parameterized.named_parameters(
      ("default_with_softmax", False),
      ("default_with_sigmoid", True),
  )
  def test_default_masked_attention(
      self,
      use_sigmoid_attn=False,
  ):
    """Checks masking when both block sizes equal the sequence length.

    The layer is evaluated with a random 0/1 mask and with an all-ones mask,
    first with (query, value) inputs and then with an explicit key input; the
    two predictions must differ in both setups. The query and output
    projections are then each expected to hold two trainable variables
    (`use_bias=True`).

    Args:
      use_sigmoid_attn: Forwarded to the layer; when true, a sigmoid attention
        bias of `-log(seq_len)` is passed as well.
    """
    seq_len = 8
    attn_bias = -math.log(seq_len) if use_sigmoid_attn else None
    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=4,
        key_dim=2,
        use_bias=True,
        src_block_size=seq_len,
        tgt_block_size=seq_len,
        use_sigmoid_attn=use_sigmoid_attn,
        sigmoid_attn_bias=attn_bias,
    )

    # Symbolic rank-3 inputs; the leading batch dimension is implicit.
    num_examples = 3
    query_in = tf_keras.Input(shape=(seq_len, 8))
    value_in = tf_keras.Input(shape=(seq_len, 8))
    mask_in = tf_keras.Input(shape=(seq_len, seq_len))
    symbolic_out = layer(
        query=query_in, value=value_in, attention_mask=mask_in
    )

    # Wrap the layer in a functional model so it can be run on numpy data.
    two_input_model = tf_keras.Model(
        [query_in, value_in, mask_in], symbolic_out
    )

    # Random data for the non-mask inputs.
    query_data = 10 * np.random.random_sample((num_examples, seq_len, 8))
    value_data = 10 * np.random.random_sample((num_examples, seq_len, 8))

    # A random 0/1 mask; it is expected to mask at least one position.
    random_mask = np.random.randint(2, size=(num_examples, seq_len, seq_len))
    masked_out = two_input_model.predict(
        [query_data, value_data, random_mask]
    )

    # An all-ones mask leaves every position unmasked.
    all_ones_mask = np.ones((num_examples, seq_len, seq_len))
    unmasked_out = two_input_model.predict(
        [query_data, value_data, all_ones_mask]
    )

    # One run is masked and the other is not, so the results must differ.
    self.assertNotAllClose(masked_out, unmasked_out)

    # Repeat the comparison with three inputs: query, value and key.
    key_in = tf_keras.Input(shape=(seq_len, 8))
    symbolic_out = layer(
        query_in, value=value_in, key=key_in, attention_mask=mask_in
    )
    three_input_model = tf_keras.Model(
        [query_in, value_in, key_in, mask_in], symbolic_out
    )

    # The value data doubles as the key data.
    masked_out = three_input_model.predict(
        [query_data, value_data, value_data, random_mask]
    )
    unmasked_out = three_input_model.predict(
        [query_data, value_data, value_data, all_ones_mask]
    )
    self.assertNotAllClose(masked_out, unmasked_out)

    # The layer was built with a bias, so two variables per projection.
    for projection in (layer._query_dense, layer._output_dense):
      self.assertLen(projection.trainable_variables, 2)

  def test_masked_attention_with_scores(self):
    """Checks masked attention when attention scores are returned as well.

    A model returning only the output and a model returning
    (output, scores) are run on the same data with a random 0/1 mask and with
    an all-ones mask. The outputs of both models must agree for each mask,
    while outputs and scores must differ between the two masks.
    """
    layer = block_sparse_attention.MultiHeadAttention(
        src_block_size=2, tgt_block_size=1, num_heads=4, key_dim=2
    )

    # Symbolic rank-3 inputs; the leading batch dimension is implicit.
    num_examples = 3
    query_in = tf_keras.Input(shape=(4, 8))
    value_in = tf_keras.Input(shape=(2, 8))
    mask_in = tf_keras.Input(shape=(4, 2))
    symbolic_out = layer(
        query=query_in, value=value_in, attention_mask=mask_in
    )

    # Model that exposes only the attention output.
    output_model = tf_keras.Model(
        [query_in, value_in, mask_in], symbolic_out
    )

    # Random data for the non-mask inputs.
    query_data = 10 * np.random.random_sample((num_examples, 4, 8))
    value_data = 10 * np.random.random_sample((num_examples, 2, 8))

    # A random 0/1 mask; it is expected to mask at least one position.
    random_mask = np.random.randint(2, size=(num_examples, 4, 2))
    masked_out = output_model.predict([query_data, value_data, random_mask])

    # An all-ones mask leaves every position unmasked.
    all_ones_mask = np.ones((num_examples, 4, 2))
    unmasked_out = output_model.predict(
        [query_data, value_data, all_ones_mask]
    )

    # One run is masked and the other is not, so the results must differ.
    self.assertNotAllClose(masked_out, unmasked_out)

    # Model that exposes the attention scores next to the output.
    symbolic_out, symbolic_scores = layer(
        query=query_in,
        value=value_in,
        attention_mask=mask_in,
        return_attention_scores=True,
    )
    scores_model = tf_keras.Model(
        [query_in, value_in, mask_in], [symbolic_out, symbolic_scores]
    )
    masked_out_with_scores, masked_scores = scores_model.predict(
        [query_data, value_data, random_mask]
    )
    unmasked_out_with_scores, unmasked_scores = scores_model.predict(
        [query_data, value_data, all_ones_mask]
    )

    self.assertNotAllClose(masked_out_with_scores, unmasked_out_with_scores)
    # Requesting the scores must not alter the attention output itself.
    self.assertAllClose(masked_out, masked_out_with_scores)
    self.assertAllClose(unmasked_out, unmasked_out_with_scores)
    self.assertNotAllClose(masked_scores, unmasked_scores)

  def test_initializer(self):
    """Checks that one kernel initializer still yields distinct kernels.

    After building the layer on a symbolic self-attention input, the query
    projection kernel must differ from the key, value and output kernels.
    """
    initializer = tf_keras.initializers.TruncatedNormal(stddev=0.02)
    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=12,
        key_dim=64,
        src_block_size=10,
        kernel_initializer=initializer,
    )
    # One symbolic rank-3 input; the leading batch dimension is implicit.
    sequence = tf_keras.Input(shape=(40, 80))
    result = layer(sequence, sequence)
    self.assertEqual(result.shape.as_list(), [None, 40, 80])

    # The sublayers must not end up with the same initial kernel values,
    # i.e. the initializer result must not be reused across them.
    query_kernel = tf_keras.backend.eval(layer._query_dense.kernel)
    for other in (layer._key_dense, layer._value_dense, layer._output_dense):
      self.assertNotAllClose(
          query_kernel, tf_keras.backend.eval(other.kernel)
      )

  @parameterized.named_parameters(
      ("bfloat16", tf.bfloat16),
      ("float16", tf.float16),
      ("float32", tf.float32),
      ("float64", tf.float64),
  )
  def test_sublayer_dtypes(self, dtype):
    """Checks that the projection sublayers report the layer's dtype.

    Args:
      dtype: Dtype given to both the attention layer and its symbolic input.
    """
    layer = block_sparse_attention.MultiHeadAttention(
        dtype=dtype, num_heads=12, key_dim=64, src_block_size=10
    )

    # Calling the layer once on a symbolic input builds its sublayers.
    sequence = tf_keras.Input(shape=(40, 80), dtype=dtype)
    layer(query=sequence, value=sequence)

    projections = (
        layer._query_dense,
        layer._key_dense,
        layer._value_dense,
        layer._output_dense,
    )
    for projection in projections:
      self.assertEqual(projection.dtype, dtype)

  def test_dropout(self):
    """Checks that the dropout-enabled output differs from the plain one."""
    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=2,
        key_dim=2,
        dropout=0.5,
        src_block_size=2,
        tgt_block_size=1,
    )

    # Constant all-ones tensors stand in for the query and value inputs.
    ones_query = tf_keras.backend.ones(shape=(32, 4, 8))
    ones_value = tf_keras.backend.ones(shape=(32, 2, 8))

    # The calls are fully positional: three optional arguments are left as
    # None and only the final flag differs between the two invocations
    # (presumably `training` -- confirm against the layer's call signature).
    with_dropout = layer(ones_query, ones_value, None, None, None, True)
    without_dropout = layer(ones_query, ones_value, None, None, None, False)

    # Only the inequality between the two modes is verified here.
    self.assertNotAllClose(
        tf_keras.backend.eval(with_dropout),
        tf_keras.backend.eval(without_dropout),
    )

  def test_query_mask_progagation(self):
    """Checks that the query's Keras mask is carried over to the output."""
    layer = block_sparse_attention.MultiHeadAttention(
        src_block_size=2, tgt_block_size=1, num_heads=2, key_dim=2
    )
    self.assertTrue(layer.supports_masking)

    # Token ids for the query; the embedding below is built with
    # `mask_zero=True` and attaches a `_keras_mask` to its output.
    token_ids = tf.constant(
        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
    )
    embedding = tf_keras.layers.Embedding(4, 8, mask_zero=True)
    query_with_mask = embedding(token_ids)
    plain_value = tf.random.normal((3, 3, 8))

    result = layer(query=query_with_mask, value=plain_value)

    self.assertTrue(hasattr(result, "_keras_mask"))
    self.assertAllEqual(query_with_mask._keras_mask, result._keras_mask)

  def test_value_mask(self):
    """Test that the implicit query/value masks are taken into account.

    The layer is first called with embeddings that carry implicit Keras masks
    (`mask_zero=True`). The implicit masks are then removed and the layer is
    called again with an explicit `attention_mask`; both outputs must match.
    """
    test_layer = block_sparse_attention.MultiHeadAttention(
        num_heads=2,
        key_dim=2,
        src_block_size=2,
        tgt_block_size=1,
    )
    query = tf.constant(
        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
    )
    masked_query = tf_keras.layers.Embedding(4, 8, mask_zero=True)(query)
    value = tf.constant([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
    masked_value = tf_keras.layers.Embedding(6, 8, mask_zero=True)(value)
    output = test_layer(
        query=masked_query,
        value=masked_value,
    )
    # Explicit mask of shape [batch, query_len, value_len] = [3, 6, 3]: entry
    # (b, t, s) is True only when query id t and value id s of example b are
    # both non-zero. It has one row per query position (six), so it describes
    # exactly the positions that the two `mask_zero` embeddings mark as valid.
    mask = tf.constant(
        [[[True, True, False]] * 3 + [[False, False, False]] * 3]
        + [[[True, False, False]] * 5 + [[False, False, False]]]
        + [[[True, True, True]] * 2 + [[False, False, False]] * 4]
    )
    # Drop the implicit masks so that only the explicit mask is in effect.
    del masked_query._keras_mask
    del masked_value._keras_mask
    output_with_manual_mask = test_layer(
        query=masked_query, value=masked_value, attention_mask=mask
    )
    self.assertAllClose(output, output_with_manual_mask)

  def test_masks_are_cast_to_bool(self):
    """Checks that float-typed implicit and explicit masks are accepted.

    The implicit Keras masks of the query and value embeddings are replaced
    by float32 copies and a float explicit mask is supplied; the call is
    expected to complete without raising.
    """
    layer = block_sparse_attention.MultiHeadAttention(
        src_block_size=2, tgt_block_size=1, num_heads=2, key_dim=2
    )
    query_ids = np.array(
        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
    )
    value_ids = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])

    query_emb = tf_keras.layers.Embedding(4, 8, mask_zero=True)(query_ids)
    value_emb = tf_keras.layers.Embedding(6, 8, mask_zero=True)(value_ids)
    # Swap each implicit mask for a float32 copy of itself.
    for embedded in (query_emb, value_emb):
      embedded._keras_mask = tf.cast(embedded._keras_mask, tf.float32)

    explicit_float_mask = tf.constant([[[1.0]]])
    # No assertion is needed: the test passes if this call does not raise.
    _ = layer(
        query=query_emb,
        value=value_emb,
        attention_mask=explicit_float_mask,
    )


# Hand control to the TensorFlow test runner when this file is executed as a
# script.
if __name__ == "__main__":
  tf.test.main()