tensorflow/models

View on GitHub
official/projects/pix2seq/modeling/transformer.py

Summary

Maintainability
F
3 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized Transformers for Pix2Seq.

the position embeddings are added to the query and key for every self- and
cross-attention layer.
"""

import tensorflow as tf, tf_keras


class TransformerEncoder(tf_keras.layers.Layer):
  """Transformer encoder."""

  def __init__(
      self,
      num_layers,
      dim,
      mlp_ratio,
      num_heads,
      drop_path=0.1,
      drop_units=0.1,
      drop_att=0.0,
      self_attention=True,
      use_ffn_ln=False,
      ln_scale_shift=True,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._dim = dim
    self._mlp_ratio = mlp_ratio
    self._num_heads = num_heads
    self._drop_path = drop_path
    self._drop_units = drop_units
    self._drop_att = drop_att
    self._self_attention = self_attention
    self._use_ffn_ln = use_ffn_ln
    self._ln_scale_shift = ln_scale_shift

    self.enc_layers = [
        TransformerEncoderLayer(  # pylint: disable=g-complex-comprehension
            dim,
            mlp_ratio,
            num_heads,
            drop_path,
            drop_units,
            drop_att,
            self_attention=self_attention,
            use_ffn_ln=use_ffn_ln,
            ln_scale_shift=ln_scale_shift,
            name='transformer_encoder' + suffix_id(i),
        )
        for i in range(num_layers)
    ]

  def call(self, x, mask, training, ret_list=False):
    x_list = [x]
    for i in range(self._num_layers):
      x = self.enc_layers[i](x, mask, training)
      x_list.append(x)
    return (x, x_list) if ret_list else x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'dim': self._dim,
        'mlp_ratio': self._mlp_ratio,
        'num_heads': self._num_heads,
        'drop_path': self._drop_path,
        'drop_units': self._drop_units,
        'drop_att': self._drop_att,
        'self_attention': self._self_attention,
        'use_ffn_ln': self._use_ffn_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config


class TransformerEncoderLayer(tf_keras.layers.Layer):  # pylint: disable=missing-docstring

  def __init__(
      self,
      dim,
      mlp_ratio,
      num_heads,
      drop_path=0.1,
      drop_units=0.1,
      drop_att=0.0,
      self_attention=True,
      use_ffn_ln=False,
      ln_scale_shift=True,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._dim = dim
    self._mlp_ratio = mlp_ratio
    self._num_heads = num_heads
    self._drop_path = drop_path
    self._drop_units = drop_units
    self._drop_att = drop_att
    self.self_attention = self_attention
    self._use_ffn_ln = use_ffn_ln
    self._ln_scale_shift = ln_scale_shift

    if self_attention:
      self.mha_ln = tf_keras.layers.LayerNormalization(
          epsilon=1e-6,
          center=ln_scale_shift,
          scale=ln_scale_shift,
          name='mha/ln',
      )
      self.mha = tf_keras.layers.MultiHeadAttention(
          num_heads, dim // num_heads, dropout=drop_att, name='mha'
      )
    self.mlp = MLP(
        1,
        dim,
        mlp_ratio,
        drop_path,
        drop_units,
        use_ffn_ln=use_ffn_ln,
        ln_scale_shift=ln_scale_shift,
        name='mlp',
    )
    self.dropp = DropPath(drop_path)

  def call(self, x, mask, training):
    # x shape (bsz, seq_len, dim_att), mask shape (bsz, seq_len, seq_len).
    if self.self_attention:
      x_ln = self.mha_ln(x)
      x_residual = self.mha(x_ln, x_ln, x_ln, mask, training=training)
      x = x + self.dropp(x_residual, training)
    x = self.mlp(x, training)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'dim': self._dim,
        'mlp_ratio': self._mlp_ratio,
        'num_heads': self._num_heads,
        'drop_path': self._drop_path,
        'drop_units': self._drop_units,
        'drop_att': self._drop_att,
        'self_attention': self._self_attention,
        'use_ffn_ln': self._use_ffn_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config


def suffix_id(i):
  """Return suffix id for layer/variable name."""
  return '' if i == 0 else '_%d' % i


class DropPath(tf_keras.layers.Layer):
  """For stochastic depth."""

  def __init__(self, drop_rate=0.0, **kwargs):
    """Initializes a drop path layer."""
    super().__init__(**kwargs)
    self._drop_rate = drop_rate
    if self._drop_rate < 0 or self._drop_rate >= 1.0:
      raise ValueError('drop_rate {} is outside [0, 1)'.format(self._drop_rate))

  def call(self, x, training=False):
    """Performs a forward pass.

    Args:
      x: An input tensor of type tf.Tensor with shape [batch, height, width,
        channels].
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      The output tensor.
    """
    if self._drop_rate == 0.0 or not training:
      return x

    keep_rate = 1.0 - self._drop_rate
    xshape = tf.shape(x)
    drop_mask_shape = [xshape[0]] + [1] * (len(xshape) - 1)
    drop_mask = keep_rate + tf.random.uniform(drop_mask_shape, dtype=x.dtype)
    drop_mask = tf.math.divide(tf.floor(drop_mask), keep_rate)
    return x * drop_mask

  def get_config(self):
    config = super().get_config()
    updates = {
        'drop_rate': self._drop_rate,
    }
    config.update(updates)
    return config


class FeedForwardLayer(tf_keras.layers.Layer):  # pylint: disable=missing-docstring

  def __init__(
      self,
      dim_att=256,
      dim_mlp=1024,
      drop_units=0.1,
      use_ln=False,
      ln_scale_shift=False,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._dim_att = dim_att
    self._dim_mlp = dim_mlp
    self._drop_units = drop_units
    self._use_ln = use_ln
    self._ln_scale_shift = ln_scale_shift

    self.dense1 = tf_keras.layers.Dense(
        dim_mlp, activation=tf.nn.gelu, name='dense1'
    )
    self.dropout = tf_keras.layers.Dropout(drop_units)
    self.dense2 = tf_keras.layers.Dense(dim_att, name='dense2')
    if use_ln:
      self.ln = tf_keras.layers.LayerNormalization(
          epsilon=1e-6,
          center=ln_scale_shift,
          scale=ln_scale_shift,
          name='mlp_ln',
      )
    else:
      self.ln = lambda x: x

  def call(self, x, training):
    return self.dense2(self.dropout(self.ln(self.dense1(x)), training=training))

  def get_config(self):
    config = super().get_config()
    updates = {
        'dim_att': self._dim_att,
        'dim_mlp': self._dim_mlp,
        'drop_units': self._drop_units,
        'use_ln': self._use_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config


class MLP(tf_keras.layers.Layer):  # pylint: disable=missing-docstring

  def __init__(
      self,
      num_layers,
      dim,
      mlp_ratio,
      drop_path=0.1,
      drop_units=0.0,
      use_ffn_ln=False,
      ln_scale_shift=True,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._dim = dim
    self._mlp_ratio = mlp_ratio
    self._drop_path = drop_path
    self._drop_units = drop_units
    self._use_ffn_ln = use_ffn_ln
    self._ln_scale_shift = ln_scale_shift

    self.mlp_layers = []
    self.layernorms = []
    for i in range(num_layers):
      self.mlp_layers.append(
          FeedForwardLayer(
              dim,
              dim * mlp_ratio,
              drop_units,
              use_ln=use_ffn_ln,
              ln_scale_shift=ln_scale_shift,
              name='ffn' + suffix_id(i),
          )
      )
      self.layernorms.append(
          tf_keras.layers.LayerNormalization(
              epsilon=1e-6,
              center=ln_scale_shift,
              scale=ln_scale_shift,
              name='ffn/ln' + suffix_id(i),
          )
      )
    self.dropp = DropPath(drop_path)

  def call(self, x, training, ret_list=False):
    x_list = [x]
    for i in range(self._num_layers):
      x_residual = self.mlp_layers[i](self.layernorms[i](x), training)
      x = x + self.dropp(x_residual, training)
      x_list.append(x)
    return (x, x_list) if ret_list else x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'dim': self._dim,
        'mlp_ratio': self._mlp_ratio,
        'drop_path': self._drop_path,
        'drop_units': self._drop_units,
        'use_ffn_ln': self._use_ffn_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config


class TransformerDecoderLayer(tf_keras.layers.Layer):  # pylint: disable=missing-docstring

  def __init__(
      self,
      dim,
      mlp_ratio,
      num_heads,
      drop_path=0.1,
      drop_units=0.1,
      drop_att=0.0,
      dim_x_att=None,
      self_attention=True,
      cross_attention=True,
      use_mlp=True,
      use_enc_ln=False,
      use_ffn_ln=False,
      ln_scale_shift=True,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._dim = dim
    self._mlp_ratio = mlp_ratio
    self._num_heads = num_heads
    self._drop_path = drop_path
    self._drop_units = drop_units
    self._drop_att = drop_att
    self._dim_x_att = dim_x_att
    self._self_attention = self_attention
    self._cross_attention = cross_attention
    self._use_mlp = use_mlp
    self._use_enc_ln = use_enc_ln
    self._use_ffn_ln = use_ffn_ln
    self._ln_scale_shift = ln_scale_shift

    if self_attention:
      self.self_ln = tf_keras.layers.LayerNormalization(
          epsilon=1e-6,
          center=ln_scale_shift,
          scale=ln_scale_shift,
          name='self_mha/ln',
      )
      self.self_mha = tf_keras.layers.MultiHeadAttention(
          num_heads, dim // num_heads, dropout=drop_att, name='self_mha'
      )
    if cross_attention:
      self.cross_ln = tf_keras.layers.LayerNormalization(
          epsilon=1e-6,
          center=ln_scale_shift,
          scale=ln_scale_shift,
          name='cross_mha/ln',
      )
      if use_enc_ln:
        self.enc_ln = tf_keras.layers.LayerNormalization(
            epsilon=1e-6,
            center=ln_scale_shift,
            scale=ln_scale_shift,
            name='cross_mha/enc_ln',
        )
      else:
        self.enc_ln = lambda x: x
      dim_x_att = dim if dim_x_att is None else dim_x_att
      self.cross_mha = tf_keras.layers.MultiHeadAttention(
          num_heads, dim_x_att // num_heads, dropout=drop_att, name='cross_mha'
      )
    if use_mlp:
      self.mlp = MLP(
          1,
          dim,
          mlp_ratio,
          drop_path,
          drop_units,
          use_ffn_ln=use_ffn_ln,
          ln_scale_shift=ln_scale_shift,
          name='mlp',
      )
    self.dropp = DropPath(drop_path)

  def call(self, x, enc, cache, mask_self, mask_cross, training):
    """x in (bsz, seq, d), enc in (bsz, seq', d)."""
    x_for_cache = []
    if self._self_attention:
      x_for_cache = x_ln = kv_ln = self.self_ln(x)
      if cache is not None:  # Augment kv_ln with cache in (bsz, c_size, d).
        q_size, k_size = tf.shape(x)[1], tf.shape(cache)[1]
        mask_self = tf.concat([tf.ones([1, 1, q_size, k_size]), mask_self], -1)
        kv_ln = tf.concat([cache, x_ln], axis=1)
      x_res = self.self_mha(x_ln, kv_ln, kv_ln, mask_self, training=training)
      x = x + self.dropp(x_res, training)
    if self._cross_attention:
      x_ln = self.cross_ln(x)
      enc = self.enc_ln(enc)
      x_res = self.cross_mha(x_ln, enc, enc, mask_cross, training=training)
      x = x + self.dropp(x_res, training)
    if self._use_mlp:
      x = self.mlp(x, training)
    return x, x_for_cache

  def get_config(self):
    config = super().get_config()
    updates = {
        'dim': self._dim,
        'mlp_ratio': self._mlp_ratio,
        'num_heads': self._num_heads,
        'drop_path': self._drop_path,
        'drop_units': self._drop_units,
        'drop_att': self._drop_att,
        'dim_x_att': self._dim_x_att,
        'self_attention': self._self_attention,
        'cross_attention': self._cross_attention,
        'use_mlp': self._use_mlp,
        'use_enc_ln': self._use_enc_ln,
        'use_ffn_ln': self._use_ffn_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config


class TransformerDecoder(tf_keras.layers.Layer):  # pylint: disable=missing-docstring

  def __init__(
      self,
      num_layers,
      dim,
      mlp_ratio,
      num_heads,
      drop_path=0.1,
      drop_units=0.1,
      drop_att=0.0,
      dim_x_att=None,
      self_attention=True,
      cross_attention=True,
      use_mlp=True,
      use_enc_ln=False,
      use_ffn_ln=False,
      ln_scale_shift=True,
      **kwargs
  ):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._dim = dim
    self._mlp_ratio = mlp_ratio
    self._num_heads = num_heads
    self._drop_path = drop_path
    self._drop_units = drop_units
    self._drop_att = drop_att
    self._dim_x_att = dim_x_att
    self._self_attention = self_attention
    self._cross_attention = cross_attention
    self._use_mlp = use_mlp
    self._use_enc_ln = use_enc_ln
    self._use_ffn_ln = use_ffn_ln
    self._ln_scale_shift = ln_scale_shift

    self.dec_layers = [
        TransformerDecoderLayer(  # pylint: disable=g-complex-comprehension
            dim,
            mlp_ratio,
            num_heads,
            drop_path,
            drop_units,
            drop_att,
            dim_x_att=dim_x_att,
            self_attention=self_attention,
            cross_attention=cross_attention,
            use_mlp=use_mlp,
            use_enc_ln=use_enc_ln,
            use_ffn_ln=use_ffn_ln,
            ln_scale_shift=ln_scale_shift,
            name='transformer_decoder_layer' + suffix_id(i),
        )
        for i in range(num_layers)
    ]

  def call(self, x, enc, caches, mask_self, mask_cross, training):
    """x in (bsz, seq, d), enc in (bsz, seq', d)."""
    presents = []
    for i in range(self._num_layers):
      cache = None if caches is None else caches[i]
      x, x_for_cache = self.dec_layers[i](
          x, enc, cache, mask_self, mask_cross, training
      )
      presents.append(x_for_cache)

    return x, tf.stack(presents)

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'dim': self._dim,
        'mlp_ratio': self._mlp_ratio,
        'num_heads': self._num_heads,
        'drop_path': self._drop_path,
        'drop_units': self._drop_units,
        'drop_att': self._drop_att,
        'dim_x_att': self._dim_x_att,
        'self_attention': self._self_attention,
        'cross_attention': self._cross_attention,
        'use_mlp': self._use_mlp,
        'use_enc_ln': self._use_enc_ln,
        'use_ffn_ln': self._use_ffn_ln,
        'ln_scale_shift': self._ln_scale_shift,
    }
    config.update(updates)
    return config