tensorflow/models

View on GitHub
official/nlp/modeling/layers/transformer_encoder_block_test.py

Summary

Maintainability
F
2 wks
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Keras-based transformer block layer."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from official.nlp.modeling.layers.transformer_encoder_block import TransformerEncoderBlock


@parameterized.named_parameters(('base', TransformerEncoderBlock))
class TransformerEncoderBlockLayerTest(
    tf.test.TestCase, parameterized.TestCase):
  """Shape, invocation and output_range tests for a transformer block class.

  The class-level `named_parameters` decorator passes the layer class under
  test to every test method as `transformer_cls`.
  """

  def tearDown(self):
    super().tearDown()
    # test_layer_invocation_with_float16_dtype switches the global policy to
    # 'mixed_float16'; reset it so the policy does not leak into other tests.
    tf_keras.mixed_precision.set_global_policy('float32')

  def test_layer_creation(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer has the same shape as the input.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  def test_layer_creation_with_mask(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    # Create a (batch, from_seq_len, to_seq_len) attention mask input (the
    # batch dimension is implicit).
    mask_tensor = tf_keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])
    # The default output of a transformer layer has the same shape as the input.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  def test_layer_invocation(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    # Create a model from the test layer.
    model = tf_keras.Model(data_tensor, output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    _ = model.predict(input_data)

  def test_layer_invocation_with_mask(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    # Create a (batch, from_seq_len, to_seq_len) attention mask input (the
    # batch dimension is implicit).
    mask_tensor = tf_keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf_keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    _ = model.predict([input_data, mask_data])

  def test_layer_output_range(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80

    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    output_tensor = test_layer([input_data, mask_data])

    # With output_range=1 the layer only computes the output for the first
    # query position, which must match the first position of the full output.
    new_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu')
    # Call once to build the weights, then copy the reference layer's weights.
    _ = new_layer([input_data, mask_data], output_range=1)
    new_layer.set_weights(test_layer.get_weights())
    new_output_tensor = new_layer([input_data, mask_data], output_range=1)
    self.assertAllClose(
        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

    output_tensor = test_layer([input_data, mask_data], output_range=1)
    self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)

  def test_layer_output_range_without_mask(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        norm_first=True)
    sequence_length = 21
    width = 80

    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    output_tensor = test_layer(input_data)

    # With output_range=1 the layer only computes the output for the first
    # query position, which must match the first position of the full output.
    new_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        norm_first=True)
    # Call once to build the weights, then copy the reference layer's weights.
    _ = new_layer(input_data, output_range=1)
    new_layer.set_weights(test_layer.get_weights())
    new_output_tensor = new_layer(input_data, output_range=1)
    self.assertAllClose(
        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

    # The original layer called with output_range=1 must agree as well, as in
    # the masked variants of this test.
    output_tensor = test_layer(input_data, output_range=1)
    self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)

  def test_layer_output_range_with_pre_norm(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        norm_first=True)
    sequence_length = 21
    width = 80

    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    output_tensor = test_layer([input_data, mask_data])

    # With output_range=1 the layer only computes the output for the first
    # query position, which must match the first position of the full output.
    new_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        norm_first=True)
    # Call once to build the weights, then copy the reference layer's weights.
    _ = new_layer([input_data, mask_data], output_range=1)
    new_layer.set_weights(test_layer.get_weights())
    new_output_tensor = new_layer([input_data, mask_data], output_range=1)
    self.assertAllClose(
        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

    output_tensor = test_layer([input_data, mask_data], output_range=1)
    self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)

  def test_layer_invocation_with_float16_dtype(self, transformer_cls):
    # tearDown restores the 'float32' policy after this test.
    tf_keras.mixed_precision.set_global_policy('mixed_float16')
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    # Create a (batch, from_seq_len, to_seq_len) attention mask input (the
    # batch dimension is implicit).
    mask_tensor = tf_keras.Input(shape=(sequence_length, sequence_length))
    output_tensor = test_layer([data_tensor, mask_tensor])

    # Create a model from the test layer.
    model = tf_keras.Model([data_tensor, mask_tensor], output_tensor)

    # Invoke the model on test data. We can't validate the output data itself
    # (the NN is too complex) but this will rule out structural runtime errors.
    batch_size = 6
    input_data = (10 * np.random.random_sample(
        (batch_size, sequence_length, width)))
    # The attention mask should be of shape (batch, from_seq_len, to_seq_len),
    # which here is (batch, sequence_length, sequence_length)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length, sequence_length))
    _ = model.predict([input_data, mask_data])

  def test_transform_with_initializer(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        kernel_initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02))
    sequence_length = 21
    width = 80
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    output = test_layer(data_tensor)
    # The default output of a transformer layer has the same shape as the input.
    self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())

  def test_dynamic_layer_sequence(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=10,
        inner_dim=2048,
        inner_activation='relu',
        kernel_initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02))
    # Create a 3-dimensional input with an unknown sequence length (the batch
    # dimension is implicit).
    width = 30
    input_tensor = tf_keras.Input(shape=(None, width))
    output_tensor = test_layer(input_tensor)
    model = tf_keras.Model(input_tensor, output_tensor)

    input_length = 17
    input_data = np.ones((1, input_length, width))
    output_data = model.predict(input_data)

    self.assertAllEqual([1, input_length, width], output_data.shape)

  def test_separate_qkv(self, transformer_cls):
    test_layer = transformer_cls(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        kernel_initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02))
    # Forward path with distinct query and key/value sequence lengths.
    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 16], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    inputs = [q_tensor, kv_tensor, dummy_mask]
    output = test_layer(inputs)
    self.assertEqual(output.shape, q_tensor.shape)


class TransformerEncoderBlockLayerTestWithoutParams(
    tf.test.TestCase, parameterized.TestCase):
  """Tests for TransformerEncoderBlock with separate query and key/value."""

  def tearDown(self):
    super().tearDown()
    # Reset the global dtype policy so no test can leak a modified policy.
    tf_keras.mixed_precision.set_global_policy('float32')

  def test_raises_invalid_arg_error_when_q_kv_dims_are_different(self):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        norm_first=True)
    # Forward path: query last dim (16) differs from key/value last dim (32).
    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    inputs = [q_tensor, kv_tensor, dummy_mask]
    with self.assertRaises(tf.errors.InvalidArgumentError):
      test_layer(inputs)

  @parameterized.named_parameters(('output_range_not_none', 2),
                                  ('output_range_none', None))
  def test_needs_diff_q_kv_att_layer_norm_to_be_true_for_diff_q_and_kv_dims(
      self, output_range):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        norm_first=True)
    # Forward path: query last dim (16) differs from key/value last dim (32).
    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    inputs = [q_tensor, kv_tensor, dummy_mask]
    with self.assertRaises(tf.errors.InvalidArgumentError):
      test_layer(inputs, output_range=output_range)

    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=128,
        inner_activation='relu',
        diff_q_kv_att_layer_norm=True,
        norm_first=True)
    # With diff_q_kv_att_layer_norm=True the mismatched dims must be accepted
    # on the same code path (with or without output_range) that failed above.
    output = test_layer(inputs, output_range=output_range)
    expected_seq_len = output_range or q_tensor.shape[1]
    self.assertEqual(output.shape, (2, expected_seq_len, 16))

  @parameterized.named_parameters(('norm_first_is_true', True),
                                  ('norm_first_is_false', False))
  def test_use_query_residual_false_removes_add_op(self, norm_first):
    # Build the same layer in two isolated graphs, differing only in
    # use_query_residual, and compare the op names each graph contains.
    graph_with_res = tf.Graph()
    with graph_with_res.as_default():
      layer = TransformerEncoderBlock(
          num_attention_heads=2,
          inner_dim=128,
          inner_activation='relu',
          norm_first=norm_first)
      inputs = tf_keras.Input(shape=(None, None, 2))
      outputs = layer(inputs)
      tf_keras.Model(inputs=inputs, outputs=outputs)

    graph_without_res = tf.Graph()
    with graph_without_res.as_default():
      layer = TransformerEncoderBlock(
          num_attention_heads=2,
          inner_dim=128,
          inner_activation='relu',
          norm_first=norm_first,
          use_query_residual=False)
      inputs = tf_keras.Input(shape=(None, None, 2))
      outputs = layer(inputs)
      tf_keras.Model(inputs=inputs, outputs=outputs)
    graph_with_res_names = {x.name for x in graph_with_res.get_operations()}
    graph_without_res_names = {
        x.name for x in graph_without_res.get_operations()
    }

    removed_ops = graph_with_res_names - graph_without_res_names
    # Check every removed op name rather than an arbitrary element of the set,
    # so the result does not depend on set iteration order.
    self.assertTrue(
        any('transformer_encoder_block/add' in name for name in removed_ops),
        msg=f'Residual add op not among removed ops: {sorted(removed_ops)}')
    # Disabling the residual must not introduce any new ops.
    self.assertEmpty(graph_without_res_names - graph_with_res_names)

  @parameterized.named_parameters(('key_dim_is_none', None, 128, 2, 128 // 2),
                                  ('key_dim_is_not_none', 30, 128, 2, 30))
  def test_key_dim(self, key_dim, q_tensor_last_dim, some_num_attention_heads,
                   expected):
    some_inner_dim = 32
    some_inner_activation = 'relu'
    test_layer = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        key_dim=key_dim)

    q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    test_layer([q_tensor, kv_tensor, dummy_mask])

    # When key_dim is None the expected value is query_last_dim // num_heads.
    self.assertEqual(expected,
                     test_layer._attention_layer.get_config()['key_dim'])

  @parameterized.named_parameters(
      ('output_last_dim_is_none_use_query_residual_false', False, None, 128,
       128),
      ('output_last_dim_is_none_use_query_residual_true', True, None, 128, 128),
      ('output_last_dim_is_not_none', False, 30, 128, 30))
  def test_output_last_dim(self, use_query_residual, output_last_dim,
                           q_tensor_last_dim, expected):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    test_layer = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        # Must be false for multi-head output to be different from
        # first input's last dim
        use_query_residual=use_query_residual,
        output_last_dim=output_last_dim)

    q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    output = test_layer([q_tensor, kv_tensor, dummy_mask])

    self.assertEqual(output.shape[-1], expected)

  @parameterized.named_parameters(('value_dim_is_none', None, 128, 2, 128 // 2),
                                  ('value_dim_is_not_none', 30, 128, 2, 30))
  def test_value_dim(self, value_dim, q_tensor_last_dim,
                     some_num_attention_heads, expected):
    some_inner_dim = 32
    some_inner_activation = 'relu'
    test_layer = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        value_dim=value_dim)

    q_tensor = tf.zeros([2, 4, q_tensor_last_dim], dtype=tf.float32)
    kv_tensor = tf.zeros([2, 8, 32], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
    test_layer([q_tensor, kv_tensor, dummy_mask])

    # When value_dim is None the expected value is query_last_dim // num_heads.
    self.assertEqual(expected,
                     test_layer._attention_layer.get_config()['value_dim'])


class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
  """Constructor-argument and config round-trip tests for the encoder block."""

  def test_use_bias_norm_first(self):
    num_attention_heads = 2
    hidden_size = 16
    encoder_block = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=tf_keras.initializers.RandomUniform(
            minval=0., maxval=1.))
    # Forward path.
    dummy_tensor = tf.zeros([2, 4, hidden_size], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
    inputs = [dummy_tensor, dummy_mask]
    output = encoder_block(inputs)
    self.assertEqual(output.shape, (2, 4, hidden_size))

  def test_use_rms_norm(self):
    num_attention_heads = 2
    hidden_size = 16
    encoder_block = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        use_rms_norm=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=tf_keras.initializers.RandomUniform(
            minval=0., maxval=1.))
    # Forward path.
    dummy_tensor = tf.zeros([2, 4, hidden_size], dtype=tf.float32)
    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
    inputs = [dummy_tensor, dummy_mask]
    output = encoder_block(inputs)
    self.assertEqual(output.shape, (2, 4, hidden_size))

  def test_norm_first_false_and_diff_q_kv_att_layer_norm_true_raises(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    with self.assertRaises(ValueError):
      TransformerEncoderBlock(
          num_attention_heads=some_num_attention_heads,
          inner_dim=some_inner_dim,
          inner_activation=some_inner_activation,
          norm_first=False,
          diff_q_kv_att_layer_norm=True)

  def test_diff_q_kv_att_layer_norm_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        norm_first=False)
    self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
    self.assertFalse(encoder.get_config()['diff_q_kv_att_layer_norm'])

  def test_diff_q_kv_att_layer_norm_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        norm_first=True,
        diff_q_kv_att_layer_norm=True)
    self.assertIn('diff_q_kv_att_layer_norm', encoder.get_config())
    self.assertTrue(encoder.get_config()['diff_q_kv_att_layer_norm'])

  def test_use_query_residual_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('use_query_residual', encoder.get_config())
    self.assertTrue(encoder.get_config()['use_query_residual'])

  def test_use_query_residual_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        use_query_residual=False)
    self.assertIn('use_query_residual', encoder.get_config())
    self.assertFalse(encoder.get_config()['use_query_residual'])

  def test_key_dim_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('key_dim', encoder.get_config())
    self.assertIsNone(encoder.get_config()['key_dim'])

  def test_key_dim_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    key_dim = 10
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        key_dim=key_dim)
    self.assertIn('key_dim', encoder.get_config())
    self.assertEqual(key_dim, encoder.get_config()['key_dim'])

  def test_value_dim_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('value_dim', encoder.get_config())
    self.assertIsNone(encoder.get_config()['value_dim'])

  def test_value_dim_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    value_dim = 10
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        value_dim=value_dim)
    self.assertIn('value_dim', encoder.get_config())
    self.assertEqual(value_dim, encoder.get_config()['value_dim'])

  def test_output_last_dim_is_part_of_config_1(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation)
    self.assertIn('output_last_dim', encoder.get_config())
    self.assertIsNone(encoder.get_config()['output_last_dim'])

  def test_output_last_dim_is_part_of_config_2(self):
    some_num_attention_heads = 2
    some_inner_dim = 32
    some_inner_activation = 'relu'
    output_last_dim = 10
    encoder = TransformerEncoderBlock(
        num_attention_heads=some_num_attention_heads,
        inner_dim=some_inner_dim,
        inner_activation=some_inner_activation,
        output_last_dim=output_last_dim)
    self.assertIn('output_last_dim', encoder.get_config())
    self.assertEqual(output_last_dim, encoder.get_config()['output_last_dim'])

  def test_get_config(self):
    num_attention_heads = 2
    encoder_block = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        attention_initializer=tf_keras.initializers.RandomUniform(
            minval=0., maxval=1.),
        use_query_residual=False,
        key_dim=20,
        value_dim=30,
        output_last_dim=40,
        diff_q_kv_att_layer_norm=True)
    # A config round trip through from_config must reproduce the same config.
    encoder_block_config = encoder_block.get_config()
    new_encoder_block = TransformerEncoderBlock.from_config(
        encoder_block_config)
    self.assertEqual(encoder_block_config, new_encoder_block.get_config())

  @parameterized.parameters({'attention_axes': None}, {'attention_axes': [1]},
                            {'attention_axes': [2]}, {'attention_axes': [1, 2]})
  def test_several_attention_axes(self, attention_axes):
    test_layer = TransformerEncoderBlock(
        inner_dim=32,
        inner_activation='relu',
        output_dropout=0.1,
        attention_dropout=0.1,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        inner_dropout=0.1,
        num_attention_heads=10,
        attention_axes=attention_axes)
    num_rows = 21
    num_cols = 13
    width = 80
    # Create a 4-dimensional (batch, rows, cols, width) input (the batch
    # dimension is implicit).
    data_tensor = tf_keras.Input(shape=(num_rows, num_cols, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer has the same shape as the input.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  @parameterized.parameters(
      {
          'output_dropout': 0.1,
          'attention_dropout': 0.2,
          'inner_dropout': 0.3
      }, {
          'output_dropout': 0.0,
          'attention_dropout': 0.2,
          'inner_dropout': 0.3
      }, {
          'output_dropout': 0.1,
          'attention_dropout': 0.0,
          'inner_dropout': 0.3
      }, {
          'output_dropout': 0.1,
          'attention_dropout': 0.2,
          'inner_dropout': 0.0
      })
  def test_dropout_config(self, output_dropout, attention_dropout,
                          inner_dropout):
    test_layer = TransformerEncoderBlock(
        num_attention_heads=2,
        inner_dim=32,
        inner_activation='relu',
        output_dropout=output_dropout,
        attention_dropout=attention_dropout,
        inner_dropout=inner_dropout)
    seq_len = 21
    hidden_size = 512
    input_tensor = tf_keras.Input(shape=(seq_len, hidden_size))
    # Calling the layer builds its sublayers so their configs can be read.
    _ = test_layer(input_tensor)

    true_output_dropout = test_layer._output_dropout.get_config()['rate']
    true_attention_dropout = test_layer._attention_dropout.get_config()['rate']
    true_inner_dropout = test_layer._inner_dropout_layer.get_config()['rate']
    self.assertEqual(true_output_dropout, output_dropout)
    self.assertEqual(true_attention_dropout, attention_dropout)
    self.assertEqual(true_inner_dropout, inner_dropout)

  @parameterized.named_parameters(
      (
          'return_attention_scores_is_false',
          False,
      ),
      (
          'return_attention_scores_is_true',
          True,
      ),
  )
  def test_return_attention_scores(self, return_attention_scores):
    num_attention_heads = 7
    sequence_length = 21
    width = 80

    test_layer = TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=2048,
        inner_activation='relu',
        return_attention_scores=return_attention_scores)
    # Create a 3-dimensional input (the batch dimension is implicit).
    data_tensor = tf_keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    expected_layer_output_shape = [None, sequence_length, width]
    expected_attention_scores_shape = [
        None, num_attention_heads, sequence_length, sequence_length
    ]

    if return_attention_scores:
      self.assertIsInstance(output_tensor, tuple)
      self.assertLen(output_tensor, 2)
      # First is the standard output.
      self.assertEqual(output_tensor[0].shape.as_list(),
                       expected_layer_output_shape)
      # Second is the attention scores.
      self.assertEqual(output_tensor[1].shape.as_list(),
                       expected_attention_scores_shape)
    else:
      # Only the standard layer output.
      self.assertEqual(output_tensor.shape.as_list(),
                       expected_layer_output_shape)


if __name__ == '__main__':
  # Run every test case defined in this module via TensorFlow's test runner.
  tf.test.main()