# tensorflow/models
#
# View on GitHub
# official/nlp/modeling/models/t5_test.py
#
# Summary
#
# Maintainability
# F
# 6 days
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the layers and models in official.nlp.modeling.models.t5."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import t5


def _create_cache(batch_size,
                  init_decode_length,
                  num_heads,
                  head_size,
                  dtype=tf.float32):
  """Builds a zero-filled key/value cache dict for the decoding tests.

  Args:
    batch_size: Size of the first cache dimension.
    init_decode_length: Size of the second cache dimension.
    num_heads: Size of the heads dimension, or None to leave that dimension
      out of the cache shape.
    head_size: Size of the last cache dimension.
    dtype: Dtype of the cache tensors.

  Returns:
    A dict with "key" and "value" entries. Each is a zeros tensor of shape
    [batch_size, init_decode_length, num_heads, head_size], or
    [batch_size, init_decode_length, head_size] when `num_heads` is None.
  """
  heads_dims = [] if num_heads is None else [num_heads]
  cache_shape = [batch_size, init_decode_length, *heads_dims, head_size]
  return {
      entry: tf.zeros(cache_shape, dtype=dtype) for entry in ("key", "value")
  }


class ModulesTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_embed(self, dtype):
    """Checks Embed lookups, `attend`, and a custom embeddings initializer."""
    vocab_size, features = 5, 4
    embed = t5.Embed(
        vocab_size=vocab_size,
        features=features,
        compute_dtype=dtype,
        name="foo")
    token_ids = tf.convert_to_tensor(
        np.array([[2, 3], [1, 2]], dtype=np.int32))
    embedded_shape = (2, 2, features)

    # The one-hot and gather lookup paths must produce the same embeddings.
    via_one_hot = embed(token_ids, one_hot=True)
    via_gather = embed(token_ids, one_hot=False)
    self.assertEqual(via_one_hot.shape, embedded_shape)
    self.assertLen(embed.trainable_variables, 1)
    self.assertAllClose(via_one_hot, via_gather)

    # `attend` turns the trailing feature axis into one of size vocab_size.
    attended = embed.attend(query=tf.zeros(embedded_shape, dtype))
    self.assertEqual(attended.shape, (2, 2, vocab_size))

    # A zeros initializer must yield all-zero embeddings.
    zero_embed = t5.Embed(
        vocab_size=vocab_size,
        features=features,
        compute_dtype=dtype,
        name="foo",
        embeddings_initializer=tf_keras.initializers.Zeros())
    self.assertAllClose(
        zero_embed(token_ids), tf.zeros(embedded_shape, dtype))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_rms_norm(self, dtype):
    """Checks RMSNorm on an all-ones input: values, dtype and variables."""
    norm = t5.RMSNorm(hidden_size=4, epsilon=0.0, name="foo")
    ones = tf.ones((2, 4), dtype=dtype)
    normed = norm(ones)
    # With epsilon=0.0 an all-ones input is expected to come back unchanged.
    self.assertAllEqual(norm(ones), ones)
    self.assertEqual(normed.dtype, dtype)
    # The layer owns a single trainable variable, the scale.
    scale_vars = norm.trainable_variables
    self.assertLen(scale_vars, 1)
    self.assertIn("foo/scale", scale_vars[0].name)

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_linear(self, dtype):
    """Checks Linear keeps input shape and dtype and owns two variables."""
    features = 4
    dense = t5.Linear(
        in_features=features,
        out_features=features,
        w_init=tf_keras.initializers.Ones(),
        name="foo")
    x = tf.ones((2, features), dtype=dtype)
    y = dense(x)
    self.assertEqual(y.shape, x.shape)
    self.assertEqual(y.dtype, dtype)
    self.assertLen(dense.trainable_variables, 2)

  def test_linear3d(self):
    """Checks Linear3D output shapes with `to_3d` enabled and disabled."""
    batch_size = 2
    # Each case: (in_features, to_3d, input shape, expected output shape).
    cases = (
        (4, True, (batch_size, 2, 4), (batch_size, 2, 2, 4)),
        (2, False, (batch_size, 2, 2, 2), (batch_size, 2, 4)),
    )
    for in_features, to_3d, input_shape, expected_shape in cases:
      projection = t5.Linear3D(
          in_features=in_features,
          out_features=4,
          num_heads=2,
          to_3d=to_3d,
          w_init=tf_keras.initializers.Ones(),
          name="foo")
      x = np.ones(input_shape, dtype=np.float32)
      self.assertEqual(projection(x).shape, expected_shape)

  def test_ffn(self):
    """Checks FFN output shape and variable count for several activations."""
    x = np.ones((2, 4), dtype=np.float32)

    # One activation with biases enabled: four trainable variables.
    for activation_name in ("relu", "linear", "gelu", "swish"):
      ffn = t5.FFN(
          d_model=4,
          d_ff=8,
          use_bias=True,
          dropout_rate=0.1,
          activations=[activation_name],
          name="foo")
      y = ffn(x)
      self.assertEqual(y.shape, x.shape)
      self.assertLen(ffn.trainable_variables, 4)

    # Two activations and `use_bias` left unset: three trainable variables.
    two_activation_ffn = t5.FFN(
        d_model=4,
        d_ff=8,
        dropout_rate=0.1,
        activations=["linear", "gelu"],
        name="bar")
    self.assertLen(two_activation_ffn.trainable_variables, 3)
    self.assertEqual(two_activation_ffn(x).shape, x.shape)

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_relative_position(self, dtype):
    """Checks RelativePositionEmbedding output shape and dtype."""
    expected_shape = (1, 4, 4, 2)

    unidirectional = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=False,
        embeddings_initializer=tf_keras.initializers.Ones(),
        compute_dtype=dtype,
        name="foo")
    self.assertEqual(unidirectional(4, 2).shape, expected_shape)

    bidirectional = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        compute_dtype=dtype,
        name="bar")
    bias = bidirectional(4, 2)
    self.assertEqual(bias.shape, expected_shape)
    self.assertEqual(bias.dtype, dtype)

  def test_masks(self):
    """Checks the shape of the mask produced by `make_causal_mask`."""
    token_grid = np.zeros((2, 5))
    mask = t5.make_causal_mask(token_grid)
    self.assertEqual(mask.shape, (2, 1, 5, 5))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_attention(self, distribution):
    """Checks MultiHeadAttention output shapes and cache writes when decoding.

    Args:
      distribution: The distribution strategy used for the cached decode step.
    """
    num_heads, head_size = 2, 4
    from_seq_length, to_seq_length = 4, 6
    batch_size = 2

    # Part 1: plain calls outside the strategy scope.
    rel_pos = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=False,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="pos_embed")
    self_bias = rel_pos(from_seq_length, from_seq_length)
    attention = t5.MultiHeadAttention(
        d_model=4, d_kv=2, num_heads=4, dropout_rate=0.1)
    query = tf.convert_to_tensor(
        np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
    self_outputs = attention(query, position_bias=self_bias)
    self.assertEqual(self_outputs["context"].shape, query.shape)

    # Same layer, now attending over a separate, longer key/value sequence.
    kv = tf.convert_to_tensor(
        np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
    cross_bias = rel_pos(from_seq_length, to_seq_length)
    cross_outputs = attention(query, kv=kv, position_bias=cross_bias)
    self.assertEqual(cross_outputs["context"].shape, query.shape)

    # Part 2: a single cached decode step run through the strategy.
    with distribution.scope():
      decode_position = 2
      decode_attention = t5.MultiHeadAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)

      def _decode_one_step(step_query):
        zero_cache = _create_cache(batch_size, from_seq_length, num_heads,
                                   head_size)
        step_mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
        return decode_attention(
            query=step_query,
            mask=step_mask,
            cache=zero_cache,
            decode_position=decode_position)

      @tf.function
      def step(inputs):
        per_replica = distribution.run(_decode_one_step, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     per_replica)

      step_query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
      local_outputs = step(step_query)
      self.assertEqual(local_outputs["context"][0].shape, (2, 1, 4))
      # The cache starts as zeros, so a non-zero sum at `decode_position`
      # shows that the step wrote keys into that slot.
      written_keys = local_outputs["cache"]["key"][0][:, decode_position, ...]
      self.assertNotEqual(np.sum(written_keys.numpy()), 0.0)


class T5Test(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_attention_layers(self, distribution):
    """Checks SelfAttention cached decoding and CrossAttention output shapes.

    Args:
      distribution: The distribution strategy under which the layers are built
        and the step functions are run.
    """
    num_heads, head_size = 2, 2
    from_seq_length = 4
    batch_size = 2
    with distribution.scope():
      # The position embedding takes a head count, so pass `num_heads` rather
      # than `head_size`; the two only happen to share the same value here.
      pos_embed = t5.RelativePositionEmbedding(
          num_heads=num_heads,
          bidirectional=False,
          embeddings_initializer=tf_keras.initializers.Ones(),
          name="pos_embed")
      l = t5.SelfAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
      decode_position = 2

      @tf.function
      def step(inputs):

        def _step_fn(inputs):
          # TPU decoding should pre-allocate the entire sequence.
          cache = _create_cache(batch_size, from_seq_length, num_heads,
                                head_size)
          mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
          position_bias = pos_embed(from_seq_length, from_seq_length)
          return l(
              hidden_states=inputs,
              cache=cache,
              attention_mask=mask,
              decode_position=decode_position,
              position_bias=position_bias)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
      local_outputs = step(query)
      self.assertEqual(local_outputs["layer_output"][0].shape, (2, 1, 4))
      # The cache starts as zeros, so a non-zero sum at `decode_position`
      # shows that the decode step wrote keys into that slot.
      self.assertNotEqual(
          np.sum(
              local_outputs["cache"]["key"][0][:,
                                               decode_position, :, :].numpy()),
          0.0)

      # Cross-attention over a key/value sequence longer than the query.
      l = t5.CrossAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
      to_seq_length = 6
      query = tf.convert_to_tensor(
          np.ones((2, from_seq_length, 4), dtype=np.float32))
      kv = tf.convert_to_tensor(
          np.ones((2, to_seq_length, 4), dtype=np.float32))

      @tf.function
      def step_cross_attn(inputs):

        def _step_fn(inputs):
          query, kv = inputs
          mask = t5.make_attention_mask(
              tf.ones((batch_size, from_seq_length)),
              tf.ones((batch_size, to_seq_length)))
          return l(hidden_states=query, kv=kv, attention_mask=mask)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      local_outputs = step_cross_attn((query, kv))
      self.assertEqual(local_outputs["layer_output"][0].shape,
                       (2, from_seq_length, 4))

  def test_encoder_block(self):
    """Checks EncoderBlock output has the same shape as its input."""
    batch_size, from_seq_length, d_model = 2, 5, 4
    block = t5.EncoderBlock(
        d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
    rel_pos = t5.RelativePositionEmbedding(
        num_heads=2,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="bar")
    self_mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, from_seq_length)))
    self_bias = rel_pos(from_seq_length, from_seq_length)
    hidden_shape = (batch_size, from_seq_length, d_model)
    hidden = tf.ones(hidden_shape, dtype=tf.float32)
    block_output = block(
        hidden, attention_mask=self_mask, position_bias=self_bias)
    self.assertEqual(block_output.shape, hidden_shape)

  def test_encdec_block(self):
    """Checks the first EncDecoderBlock output has the decoder input shape."""
    batch_size, d_model = 2, 4
    from_seq_length, to_seq_length = 5, 3
    block = t5.EncDecoderBlock(
        d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
    rel_pos = t5.RelativePositionEmbedding(
        num_heads=2,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="bar")
    cross_mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, to_seq_length)))
    self_bias = rel_pos(from_seq_length, from_seq_length)
    decoder_shape = (batch_size, from_seq_length, d_model)
    decoder_hidden = tf.ones(decoder_shape, dtype=tf.float32)
    encoder_hidden = tf.ones(
        (batch_size, to_seq_length, d_model), dtype=tf.float32)
    block_outputs = block(
        decoder_hidden,
        encoder_hidden,
        encoder_decoder_mask=cross_mask,
        position_bias=self_bias)
    self.assertEqual(block_outputs[0].shape, decoder_shape)

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder(self, dtype):
    """Checks Encoder output shape for token-only inputs."""
    batch_size, seq_length = 4, 8
    config_kwargs = dict(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones())
    config = t5.T5TransformerParams(**config_kwargs)
    encoder = t5.Encoder(config, compute_dtype=dtype)
    token_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    encoded = encoder(token_ids)
    self.assertEqual(encoded.shape, (batch_size, seq_length, config.d_model))

  @parameterized.named_parameters(("return_score", True),
                                  ("not_return_score", False))
  def test_encoder_att_scores(self, return_attention_scores):
    """Checks Encoder outputs with and without returned attention scores.

    Args:
      return_attention_scores: Value of the config flag of the same name. When
        true, the encoder result is unpacked as (encoded, scores).
    """
    config = t5.T5TransformerParams(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones(),
        return_attention_scores=return_attention_scores)
    encoder = t5.Encoder(config, compute_dtype=tf.float32)
    result = encoder(tf.zeros((4, 8), dtype=tf.int32))
    encoded_shape = (4, 8, config.d_model)

    if not return_attention_scores:
      self.assertEqual(result.shape, encoded_shape)
      return

    encoded, scores = result
    self.assertEqual(encoded.shape, encoded_shape)
    self.assertIsNotNone(scores)
    # One score tensor per layer (num_layers=2).
    self.assertLen(scores, 2)
    self.assertEqual(scores[0].shape, (4, 4, 8, 8))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder_with_dense(self, dtype):
    """Checks Encoder output shape when tokens and dense inputs are mixed."""
    batch_size, token_length, dense_length = 4, 8, 2
    config = t5.T5TransformerParams(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones())
    encoder = t5.Encoder(config, compute_dtype=dtype)
    token_ids = tf.zeros((batch_size, token_length), dtype=tf.int32)
    dense = tf.ones((batch_size, dense_length, 4), dtype=dtype)
    encoded = encoder(token_ids, dense_inputs=dense)
    # The output length is the token length plus the dense length (8 + 2).
    self.assertEqual(
        encoded.shape,
        (batch_size, token_length + dense_length, config.d_model))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder_only_dense(self, dtype):
    """Checks Encoder output shape when only dense inputs are given."""
    batch_size, dense_length = 4, 2
    config = t5.T5TransformerParams(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones())
    encoder = t5.Encoder(config, compute_dtype=dtype)
    dense = tf.ones((batch_size, dense_length, 4), dtype=dtype)
    encoded = encoder(dense_inputs=dense)
    self.assertEqual(encoded.shape,
                     (batch_size, dense_length, config.d_model))

  def test_decoder(self):
    """Checks Decoder logits for a full sequence and one cached decode step."""
    max_decode_len = 10
    num_layers = 2
    config = t5.T5TransformerParams(
        num_layers=num_layers,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones())
    decoder = t5.Decoder(config)
    batch_size, seq_length = 4, 8

    # Full-sequence call without a cache.
    full_targets = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    encoded = tf.zeros(
        (batch_size, seq_length, config.d_model), dtype=tf.float32)
    full_outputs = decoder(full_targets, encoded)
    self.assertEqual(full_outputs["logits"].shape,
                     (batch_size, seq_length, config.vocab_size))

    # Single-token decode step with one zero-filled cache entry per layer.
    decode_position = 2
    zero_cache = {
        layer: _create_cache(batch_size, max_decode_len, config.num_heads,
                             config.d_kv) for layer in range(num_layers)
    }
    step_targets = tf.zeros((batch_size, 1), dtype=tf.int32)
    step_outputs = decoder(
        step_targets,
        encoded,
        decode_position=decode_position,
        cache=zero_cache,
        decode=True,
        max_decode_len=max_decode_len)
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    # The returned cache slot at the decode position must not be all zeros.
    for layer_cache in step_outputs["cache"].values():
      for cached in layer_cache.values():
        self.assertNotAllEqual(cached.numpy()[:, decode_position, :, :], 0.0)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 26, False, tf.float32),
      ("t5_11", ("gelu", "linear"), False, 29, False, tf.float32),
      ("t5_10_bfloat16", ("relu",), True, 26, False, tf.bfloat16),
      ("t5_11_bfloat16", ("gelu", "linear"), False, 29, False, tf.bfloat16),
      ("t5_10_layer_sharing", ("relu",), True, 26, True, tf.float32),
      ("t5_11_layer_sharing", ("gelu", "linear"), False, 29, True, tf.float32),
      ("t5_10_bfloat16_layer_sharing", ("relu",), True, 26, True, tf.bfloat16),
      ("t5_11_bfloat16_layer_sharing",
       ("gelu", "linear"), False, 29, True, tf.bfloat16))
  def test_transformer(self, ffn_activations, logits_via_embedding,
                       expect_num_variables, layer_sharing, dtype):
    """Checks T5Transformer variable count, forward pass and one decode step.

    Args:
      ffn_activations: Activation names forwarded to the config.
      logits_via_embedding: Value of the config flag of the same name.
      expect_num_variables: Expected number of trainable variables.
      layer_sharing: Value of the config flag of the same name.
      dtype: Compute dtype of the model and dtype of the decode cache.
    """
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)

    token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    full_outputs = transformer(
        encoder_input_tokens=token_ids,
        decoder_input_tokens=token_ids,
        decoder_target_tokens=token_ids,
        encoder_segment_ids=segment_ids,
        decoder_segment_ids=segment_ids)

    # One decode step at position 1, reusing the encoder output from above.
    layer_cache = _create_cache(
        batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
    step_outputs = transformer.decode(
        encoder_input_tokens=token_ids,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache={0: layer_cache})
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))

    # Variables are expected to be float32 whatever the compute dtype is.
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)

  def test_transformer_return_attn_scores(self):
    """Checks T5Transformer returns attention scores when the config asks."""
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=False,
        ffn_activations=("relu",),
        logits_via_embedding=True,
        return_attention_scores=True,
    )
    transformer = t5.T5Transformer(config, compute_dtype=tf.float32)
    self.assertLen(transformer.trainable_variables, 26)

    token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    full_outputs = transformer(
        encoder_input_tokens=token_ids,
        decoder_input_tokens=token_ids,
        decoder_target_tokens=token_ids,
        encoder_segment_ids=segment_ids,
        decoder_segment_ids=segment_ids,
    )

    # The outputs carry one score tensor for the single configured layer.
    self.assertIn("attention_scores", full_outputs)
    attention_scores = full_outputs["attention_scores"]
    self.assertLen(attention_scores, 1)
    self.assertEqual(attention_scores[0].shape, (2, 4, 6, 6))

    # One decode step at position 1, reusing the encoder output from above.
    layer_cache = _create_cache(
        batch_size,
        max_decode_len,
        config.num_heads,
        config.d_kv,
        dtype=tf.float32,
    )
    step_outputs = transformer.decode(
        encoder_input_tokens=token_ids,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache={0: layer_cache})
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)

  @parameterized.named_parameters(
      ("t5_10_dense", ("relu",), True, 26, False, tf.float32),)
  def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
                                  expect_num_variables, layer_sharing, dtype):
    """Checks T5Transformer with both token and dense encoder inputs.

    Args:
      ffn_activations: Activation names forwarded to the config.
      logits_via_embedding: Value of the config flag of the same name.
      expect_num_variables: Expected number of trainable variables.
      layer_sharing: Value of the config flag of the same name.
      dtype: Compute dtype of the model; also used for the dense inputs and
        the decode cache.
    """
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)

    token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    # Random dense encoder inputs; only output shapes are asserted below.
    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
    dense_segment_ids = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))

    full_outputs = transformer(
        encoder_input_tokens=token_ids,
        encoder_dense_inputs=dense_inputs,
        decoder_input_tokens=token_ids,
        decoder_target_tokens=token_ids,
        encoder_segment_ids=segment_ids,
        encoder_dense_segment_ids=dense_segment_ids,
        decoder_segment_ids=segment_ids)

    # One decode step at position 1, reusing the encoder output from above.
    layer_cache = _create_cache(
        batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
    step_outputs = transformer.decode(
        encoder_input_tokens=token_ids,
        encoder_dense_inputs=dense_inputs,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache={0: layer_cache})
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)

  @parameterized.named_parameters(
      ("t5_10_dense_layerwise_relpos",
       ("relu",), True, 26, False, tf.float32, False, 1),
      ("t5_10_dense_shared_relpos_d2",
       ("relu",), True, 39, False, tf.float32, True, 2),
      ("t5_10_dense_layerwise_relpos_d2",
       ("relu",), True, 40, False, tf.float32, False, 2),
  )
  def test_transformer_with_lw_relpos(self, ffn_activations,
                                      logits_via_embedding,
                                      expect_num_variables, layer_sharing,
                                      dtype, use_shared_relpos,
                                      num_decoder_layers):
    """Checks T5Transformer with shared or unshared relative position bias.

    Args:
      ffn_activations: Activation names forwarded to the config.
      logits_via_embedding: Value of the config flag of the same name.
      expect_num_variables: Expected number of trainable variables.
      layer_sharing: Value of the config flag of the same name.
      dtype: Compute dtype of the model; also used for the dense inputs and
        the decode cache.
      use_shared_relpos: Value passed as `use_shared_relative_position_bias`.
      num_decoder_layers: Number of decoder layers; one cache entry is created
        for each of them.
    """
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        num_decoder_layers=num_decoder_layers,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding,
        use_shared_relative_position_bias=use_shared_relpos)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)

    token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    # Random dense encoder inputs; only output shapes are asserted below.
    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
    dense_segment_ids = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))

    full_outputs = transformer(
        encoder_input_tokens=token_ids,
        encoder_dense_inputs=dense_inputs,
        decoder_input_tokens=token_ids,
        decoder_target_tokens=token_ids,
        encoder_segment_ids=segment_ids,
        encoder_dense_segment_ids=dense_segment_ids,
        decoder_segment_ids=segment_ids)

    # One zero-filled cache entry per decoder layer, keyed by layer index.
    decode_cache = {
        layer: _create_cache(
            batch_size,
            max_decode_len,
            config.num_heads,
            config.d_kv,
            dtype=dtype) for layer in range(num_decoder_layers)
    }
    step_outputs = transformer.decode(
        encoder_input_tokens=token_ids,
        encoder_dense_inputs=dense_inputs,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache=decode_cache)
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 26, False, tf.float32),)
  def test_transformer_with_dense_only(self, ffn_activations,
                                       logits_via_embedding,
                                       expect_num_variables, layer_sharing,
                                       dtype):
    """Checks T5Transformer when the encoder receives only dense inputs.

    Args:
      ffn_activations: Activation names forwarded to the config.
      logits_via_embedding: Value of the config flag of the same name.
      expect_num_variables: Expected number of trainable variables.
      layer_sharing: Value of the config flag of the same name.
      dtype: Compute dtype of the model; also used for the dense inputs and
        the decode cache.
    """
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)

    decoder_token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    decoder_segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    # Random dense encoder inputs; only output shapes are asserted below.
    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
    dense_segment_ids = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))

    # No encoder token ids are passed in either call.
    full_outputs = transformer(
        encoder_dense_inputs=dense_inputs,
        encoder_dense_segment_ids=dense_segment_ids,
        decoder_input_tokens=decoder_token_ids,
        decoder_target_tokens=decoder_token_ids,
        decoder_segment_ids=decoder_segment_ids)

    layer_cache = _create_cache(
        batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
    step_outputs = transformer.decode(
        encoder_dense_inputs=dense_inputs,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache={0: layer_cache})
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 39, tf.float32, 2),
      ("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
  def test_transformer_different_num_decoder_layers(self, ffn_activations,
                                                    logits_via_embedding,
                                                    expect_num_variables, dtype,
                                                    num_decoder_layers):
    """Checks T5Transformer when `num_decoder_layers` differs from `num_layers`.

    Args:
      ffn_activations: Activation names forwarded to the config.
      logits_via_embedding: Value of the config flag of the same name.
      expect_num_variables: Expected number of trainable variables.
      dtype: Compute dtype of the model and dtype of the decode cache.
      num_decoder_layers: Number of decoder layers; one cache entry is created
        for each of them.
    """
    batch_size = 2
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_decoder_layers=num_decoder_layers,
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)

    token_ids = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segment_ids = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    full_outputs = transformer(
        encoder_input_tokens=token_ids,
        decoder_input_tokens=token_ids,
        decoder_target_tokens=token_ids,
        encoder_segment_ids=segment_ids,
        decoder_segment_ids=segment_ids)

    # One zero-filled cache entry per decoder layer, keyed by layer index.
    decode_cache = {
        layer: _create_cache(
            batch_size,
            max_decode_len,
            config.num_heads,
            config.d_kv,
            dtype=dtype) for layer in range(num_decoder_layers)
    }
    step_outputs = transformer.decode(
        encoder_input_tokens=token_ids,
        encoded=full_outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache=decode_cache)
    self.assertEqual(step_outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for variable in transformer.trainable_variables:
      self.assertEqual(variable.dtype, tf.float32)


# Run the tests in this module through tf.test.main() when the file is
# executed as a script.
if __name__ == "__main__":
  tf.test.main()