tensorflow/models

View on GitHub
official/nlp/modeling/layers/gated_feedforward.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based gated feedforward layer."""
# pylint: disable=g-classes-have-attributes

import gin
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import util


@tf_keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class GatedFeedforward(tf_keras.layers.Layer):
  """Gated linear feedforward layer.

  This layer follows the paper "GLU Variants Improve Transformer"
  (https://arxiv.org/abs/2002.05202). In addition, it allows stacking
  multiple feedforward blocks and specifying the position of the dropout layer.

  The dense projections use the einsum equation `abc,cd->abd`, so the layer
  expects rank-3 inputs and projects along the last axis.

  Args:
    inner_dim: Size of the intermediate (inner) layer. The legacy keyword
      `intermediate_size` is also accepted and takes precedence when given.
    inner_activation: Activation for the intermediate (inner) layer. The legacy
      keyword `intermediate_activation` is also accepted and takes precedence
      when given.
    dropout: Dropout probability for the output dropout.
    use_gate: Whether to use gated linear units. If True, assuming `GELU` as the
      activation and omitting bias, will apply
      `GEGLU(x, W, V, W_2) = (GELU(xW) * xV)W_2`; if False, will follow
      "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper and
      apply `FFN(x, W_1, W_2) = GELU(xW_1)W_2`.
    apply_output_layer_norm: Whether to add the residual connection and apply
      layer normalization to the output of each block. If False, a block's
      output is the (optionally dropped-out) dense output, without the residual
      add.
    num_blocks: The number of feedforward blocks to stack. Each block contains a
      (gated) linear layer and a fully connected layer followed by dropout,
      layer norm and residual.
    dropout_position: Where to apply the dropout, the value can be either
      `before_residual` or `after_residual`. If `before_residual`, will apply
      `layer_output = layer_norm(dropout(layer_output) + layer_input)`; if
      `after_residual`, will apply
      `layer_output = dropout(layer_norm(layer_output + layer_input))`.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.

  Raises:
    ValueError: If `dropout_position` is neither `before_residual` nor
      `after_residual`.
  """

  def __init__(self,
               inner_dim=768,
               inner_activation=tf_utils.get_activation("gelu"),
               dropout=0.0,
               use_gate=True,
               apply_output_layer_norm=True,
               num_blocks=1,
               dropout_position="before_residual",
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    # Legacy argument names override the new ones when they are supplied.
    inner_dim = kwargs.pop("intermediate_size", inner_dim)
    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
    util.filter_kwargs(kwargs)
    super().__init__(**kwargs)
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._dropout = dropout
    self._use_gate = use_gate
    self._num_blocks = num_blocks
    self._apply_output_layer_norm = apply_output_layer_norm
    self._dropout_position = dropout_position
    if self._dropout_position not in ("before_residual", "after_residual"):
      raise ValueError(
          "The dropout_position should be either `before_residual` or "
          "`after_residual`, got: %s" % self._dropout_position)

    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)

  def _make_dense(self, name, output_dim):
    """Creates one `abc,cd->abd` projection using the layer-wide settings."""
    return tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, output_dim),
        bias_axes="d",
        name=name,
        # Each sub-layer gets its own copy of the initializers.
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(
            self._bias_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

  def build(self, input_shape):
    """Creates the per-block sub-layers.

    Args:
      input_shape: Shape of the input; a `tf.TensorShape` or anything
        convertible to one (e.g. a tuple or list). Its last dimension is used
        as the output size of every block.
    """
    # Normalize so that a plain tuple/list shape is accepted as well.
    hidden_size = tf.TensorShape(input_shape).as_list()[-1]

    self._intermediate_dense = []
    self._inner_activation_layers = []
    self._gate_dense = []
    self._output_dense = []
    self._output_dropout = []
    self._output_layer_norm = []
    activation_policy = tf_keras.mixed_precision.global_policy()
    if activation_policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      activation_policy = tf.float32
    for i in range(self._num_blocks):
      self._intermediate_dense.append(
          self._make_dense("intermediate_%d" % i, self._inner_dim))
      self._inner_activation_layers.append(
          tf_keras.layers.Activation(
              self._inner_activation, dtype=activation_policy))
      if self._use_gate:
        self._gate_dense.append(
            self._make_dense("gate_%d" % i, self._inner_dim))
      self._output_dense.append(
          self._make_dense("output_%d" % i, hidden_size))
      self._output_dropout.append(tf_keras.layers.Dropout(rate=self._dropout))
      # Use float32 in layernorm for numeric stability.
      if self._apply_output_layer_norm:
        self._output_layer_norm.append(
            tf_keras.layers.LayerNormalization(
                name="output_layer_norm_%d" % i,
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32))

  def get_config(self):
    """Returns the constructor arguments of this layer as a config dict.

    Initializers, regularizers and constraints are serialized with the
    corresponding `tf_keras` serializers; `inner_activation` is stored as it
    was passed to the constructor.
    """
    config = {
        "inner_dim":
            self._inner_dim,
        "inner_activation":
            self._inner_activation,
        "dropout":
            self._dropout,
        "use_gate":
            self._use_gate,
        # Needed so that `from_config` restores a layer built without the
        # residual/layer-norm step instead of silently re-enabling it.
        "apply_output_layer_norm":
            self._apply_output_layer_norm,
        "num_blocks":
            self._num_blocks,
        "dropout_position":
            self._dropout_position,
        "kernel_initializer":
            tf_keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf_keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf_keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf_keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf_keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf_keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf_keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super().get_config()
    # Keys defined here take precedence over the base layer's keys.
    return {**base_config, **config}

  def call(self, inputs):
    """Applies the stacked feedforward blocks to `inputs`.

    Args:
      inputs: The input tensor; the output of each block becomes the input of
        the next one.

    Returns:
      The output tensor of the last block.
    """
    layer_output = inputs
    for i in range(self._num_blocks):
      layer_input = layer_output
      intermediate_output = self._intermediate_dense[i](layer_input)
      intermediate_output = self._inner_activation_layers[i](
          intermediate_output)
      if self._use_gate:
        # Element-wise gate: activation(xW) * xV.
        gated_linear = self._gate_dense[i](layer_input)
        intermediate_output = intermediate_output * gated_linear

      layer_output = self._output_dense[i](intermediate_output)
      if self._dropout_position == "before_residual":
        layer_output = self._output_dropout[i](layer_output)

      # During mixed precision training, `layer_input` may be from layer norm.
      # If so, it is always fp32. Cast layer_output to fp32 for the subsequent
      # add.
      if layer_input.dtype == tf.float32:
        layer_output = tf.cast(layer_output, tf.float32)
      # The residual add only happens together with the output layer norm.
      if self._apply_output_layer_norm:
        layer_output = self._output_layer_norm[i](layer_output + layer_input)
      if self._dropout_position == "after_residual":
        layer_output = self._output_dropout[i](layer_output)

    return layer_output