tensorflow/models
official/nlp/modeling/layers/tn_expand_condense.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ExpandCondense tensor network layer used in TN-BERT."""
# pylint: disable=g-classes-have-attributes
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf, tf_keras

from official.modeling import tf_utils

Layer = tf_keras.layers.Layer
activations = tf_keras.activations
initializers = tf_keras.initializers


@tf_keras.utils.register_keras_serializable(package='Text')
class TNExpandCondense(Layer):
  """A TPU-optimized TensorNetwork layer.

  Designed for use in models that currently use Dense layers to perform an
  up projection followed by a down projection.

  This layer is a TPU-optimized combination of 3 operations:
  Expand, Apply Activation, and Condense. The layer projects up from
  `input_shape[-1]` to `input_shape[-1] * proj_multiplier`, applies
  `self.activation`, and then condenses back to `input_shape[-1]`.

  Note the input shape and output shape will be identical.

  Args:
    proj_multiplier: Positive integer, multiple of `input_shape[-1]` to project
      up to. Must be one of `[2, 4, 6, 8, 10, 12]`.
    use_bias: Boolean, whether the layer uses a bias vector.
    activation: Activation function to use between Expand and Condense.
      Defaults to `'relu'`; pass `None` for no activation
      (i.e. "linear" activation: `a(x) = x`).
    kernel_initializer: Initializer for the weight matrices.
    bias_initializer: Initializer for the bias vector.
  Input shape:
    N-D tensor with shape: `(batch_size, ..., input_shape[-1])`, where
      `input_shape[-1]` must be defined and a multiple of 128.
  Output shape:
    N-D tensor with shape: `(batch_size, ..., input_shape[-1])`.
  """

  def __init__(self,
               proj_multiplier: int,
               use_bias: Optional[bool] = True,
               activation: Optional[Text] = 'relu',
               kernel_initializer: Optional[Text] = 'glorot_uniform',
               bias_initializer: Optional[Text] = 'zeros',
               **kwargs) -> None:

    # Allow specification of input_dim instead of input_shape,
    # for compatibility with Keras layers that support this.
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)

    super().__init__(**kwargs)

    assert proj_multiplier in [
        2, 4, 6, 8, 10, 12
    ], 'proj_multiplier needs to be one of [2, 4, 6, 8, 10, 12]'
    self.proj_multiplier = proj_multiplier

    self.use_bias = use_bias
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

  def build(self, input_shape: List[int]) -> None:
    # Disable the attribute-defined-outside-init violations in this function
    # pylint: disable=attribute-defined-outside-init
    if input_shape[-1] is None:
      raise ValueError(
          'The last dimension of the inputs to `TNExpandCondense` '
          'should be defined. Found `None`.')

    super().build(input_shape)

    self.proj_size = self.proj_multiplier * input_shape[-1]

    assert (self.proj_size // input_shape[-1]) * input_shape[-1] == (
        self.proj_size
    ), f'proj_size {self.proj_size} must be divisible by {input_shape[-1]}'
    assert (input_shape[-1] // 128) * 128 == input_shape[-1], (
        f'input_shape[-1] ({input_shape[-1]}) must be divisible by 128')

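    # The up-projection (input_dim -> proj_size) is factored into w1 and w2,
    # and the down-projection (proj_size -> input_dim) into w3 and w4, with
    # every contraction kept at a multiple of 128 (128 is the TPU MXU tile
    # width, which is presumably why `build` requires input_shape[-1] % 128
    # == 0).
    # w1 is a full (input_dim, input_dim) mixing matrix applied before the
    # block-wise expansion.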
    self.w1 = self.add_weight(
        name='w1',
        shape=(input_shape[-1], input_shape[-1]),
        trainable=True,
        initializer=tf_utils.clone_initializer(self.kernel_initializer))

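    # w2 expands each 128-wide block to 128 * proj_multiplier, so w1 and w2
    # together realize the expansion to `proj_size` features.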
    self.w2 = self.add_weight(
        name='w2',
        shape=(128, (128 * (self.proj_size // input_shape[-1]))),
        trainable=True,
        initializer=tf_utils.clone_initializer(self.kernel_initializer))

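    # w3 condenses each expanded block from 128 * proj_multiplier back to 128.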
    self.w3 = self.add_weight(
        name='w3',
        shape=(128 * (self.proj_size // input_shape[-1]), 128),
        trainable=True,
        initializer=tf_utils.clone_initializer(self.kernel_initializer))
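    # w4 recombines the input_dim // 128 condensed blocks into the final
    # output dimension, completing the condense step.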
    self.w4 = self.add_weight(
        name='w4',
        shape=(input_shape[-1] // 128, 128, input_shape[-1]),
        trainable=True,
        initializer=tf_utils.clone_initializer(self.kernel_initializer))

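    # The bias matches the post-expansion block layout and broadcasts over the
    # flattened batch*sequence axis (the middle dimension of size 1).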
    if self.use_bias:
      self.bias = self.add_weight(
          name='b',
          shape=(input_shape[-1] // 128, 1,
                 128 * (self.proj_size // input_shape[-1])),
          trainable=True,
          initializer=self.bias_initializer)
    else:
      self.bias = None

  def call(self, inputs: tf.Tensor, **kwargs):
    orig_shape = tf.shape(inputs)
    input_dim = inputs.shape[-1]
    tmp = tf.reshape(inputs, (-1, input_dim))
    # Shape is (BatchSeq, input_dim)

    # Expansion network
    tmp = tf.einsum('ab,Qb->aQ', self.w1, tmp)
    # Note: Letter Q will always represent the BatchSeq axis.
    tmp = tf.reshape(tmp, (input_dim // 128, 128, -1))
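    # Shape is (input_dim // 128, 128, BatchSeq)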
    tmp = tf.einsum('abQ,bd->aQd', tmp, self.w2)
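    # Shape is (input_dim // 128, BatchSeq, 128 * proj_multiplier)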

    # Apply activation and then Condense. Guard the bias add so the layer
    # still works when `use_bias=False` (self.bias is None in that case).
    if self.bias is not None:
      tmp += self.bias
    tmp = self.activation(tmp)
    tmp = tf.einsum('aQd,db->aQb', tmp, self.w3)
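    # Shape is (input_dim // 128, BatchSeq, 128)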
    tmp = tf.einsum('aQb,abd->Qd', tmp, self.w4)
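    # Shape is (BatchSeq, input_dim)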

    out = tf.reshape(tmp, orig_shape)
    return out

  def compute_output_shape(self, input_shape: List[int]) -> List[int]:
    return input_shape

  def get_config(self) -> Dict[Any, Any]:
    """Returns the config of the layer.

    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    Returns:
      Python dictionary containing the configuration of the layer.
    """
    config = {}

    # Include the layer-specific arguments
    args = ['proj_multiplier', 'use_bias']
    for arg in args:
      config[arg] = getattr(self, arg)

    # Serialize the activation
    config['activation'] = activations.serialize(getattr(self, 'activation'))

    # Serialize the initializers
    decomp_initializers = ['kernel_initializer', 'bias_initializer']
    for initializer_arg in decomp_initializers:
      config[initializer_arg] = initializers.serialize(
          getattr(self, initializer_arg))

    # Get base config
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
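

# A minimal usage sketch (illustrative only, not part of the layer's API):
# the hidden size 1024 and proj_multiplier=4 below are assumptions chosen to
# satisfy the multiple-of-128 constraint enforced in `build`.
if __name__ == '__main__':
  layer = TNExpandCondense(proj_multiplier=4, activation='gelu')
  x = tf.ones((2, 16, 1024))  # (batch, seq_len, hidden); hidden % 128 == 0.
  y = layer(x)
  # Expand -> activation -> condense preserves the input shape.
  assert y.shape == x.shape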