tensorflow/models

View on GitHub
official/vision/modeling/layers/nn_layers.py

Summary

Maintainability
F
4 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for neural networks."""

from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.vision.ops import spatial_transform_ops


# Type annotations.
States = Dict[str, tf.Tensor]
Activation = Union[str, Callable]


def make_divisible(value: float,
                   divisor: int,
                   min_value: Optional[float] = None,
                   round_down_protect: bool = True,
                   ) -> int:
  """This is to ensure that all layers have channels that are divisible by 8.

  Args:
    value: A `float` of original value.
    divisor: An `int` of the divisor that need to be checked upon.
    min_value: A `float` of  minimum value threshold.
    round_down_protect: A `bool` indicating whether round down more than 10%
      will be allowed.

  Returns:
    The adjusted value in `int` that is divisible against divisor.
  """
  if min_value is None:
    min_value = divisor
  new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  # Make sure that round down does not go down by more than 10%.
  if round_down_protect and new_value < 0.9 * value:
    new_value += divisor
  return int(new_value)


def round_filters(filters: int,
                  multiplier: float,
                  divisor: int = 8,
                  min_depth: Optional[int] = None,
                  round_down_protect: bool = True,
                  skip: bool = False) -> int:
  """Rounds number of filters based on width multiplier."""
  orig_f = filters
  if skip or not multiplier:
    return filters

  new_filters = make_divisible(value=filters * multiplier,
                               divisor=divisor,
                               min_value=min_depth,
                               round_down_protect=round_down_protect)

  logging.info('round_filter input=%s output=%s', orig_f, new_filters)
  return int(new_filters)


def get_padding_for_kernel_size(kernel_size):
  """Compute padding size given kernel size."""
  if kernel_size == 7:
    return (3, 3)
  elif kernel_size == 3:
    return (1, 1)
  else:
    raise ValueError('Padding for kernel size {} not known.'.format(
        kernel_size))


@tf_keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf_keras.layers.Layer):
  """Creates a squeeze and excitation layer."""

  def __init__(self,
               in_filters,
               out_filters,
               se_ratio,
               divisible_by=1,
               use_3d_input=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               gating_activation='sigmoid',
               round_down_protect=True,
               **kwargs):
    """Initializes a squeeze and excitation layer.

    Args:
      in_filters: An `int` number of filters of the input tensor.
      out_filters: An `int` number of filters of the output tensor.
      se_ratio: A `float` or None. If not None, se ratio for the squeeze and
        excitation layer.
      divisible_by: An `int` that ensures all inner dimensions are divisible by
        this number.
      use_3d_input: A `bool` of whether input is 2D or 3D image.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2d.
        Default to None.
      activation: A `str` name of the activation function.
      gating_activation: A `str` name of the activation function for final
        gating function.
      round_down_protect: A `bool` of whether round down more than 10% will be
        allowed.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(SqueezeExcitation, self).__init__(**kwargs)

    self._in_filters = in_filters
    self._out_filters = out_filters
    self._se_ratio = se_ratio
    self._divisible_by = divisible_by
    self._round_down_protect = round_down_protect
    self._use_3d_input = use_3d_input
    self._activation = activation
    self._gating_activation = gating_activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    if tf_keras.backend.image_data_format() == 'channels_last':
      if not use_3d_input:
        self._spatial_axis = [1, 2]
      else:
        self._spatial_axis = [1, 2, 3]
    else:
      if not use_3d_input:
        self._spatial_axis = [2, 3]
      else:
        self._spatial_axis = [2, 3, 4]
    self._activation_fn = tf_utils.get_activation(activation)
    self._gating_activation_fn = tf_utils.get_activation(gating_activation)

  def build(self, input_shape):
    num_reduced_filters = make_divisible(
        max(1, int(self._in_filters * self._se_ratio)),
        divisor=self._divisible_by,
        round_down_protect=self._round_down_protect)

    self._se_reduce = tf_keras.layers.Conv2D(
        filters=num_reduced_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

    self._se_expand = tf_keras.layers.Conv2D(
        filters=self._out_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

    super(SqueezeExcitation, self).build(input_shape)

  def get_config(self):
    config = {
        'in_filters': self._in_filters,
        'out_filters': self._out_filters,
        'se_ratio': self._se_ratio,
        'divisible_by': self._divisible_by,
        'use_3d_input': self._use_3d_input,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'gating_activation': self._gating_activation,
        'round_down_protect': self._round_down_protect,
    }
    base_config = super(SqueezeExcitation, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    x = tf.reduce_mean(inputs, self._spatial_axis, keepdims=True)
    x = self._activation_fn(self._se_reduce(x))
    x = self._gating_activation_fn(self._se_expand(x))
    return x * inputs


def get_stochastic_depth_rate(init_rate, i, n):
  """Get drop connect rate for the ith block.

  Args:
    init_rate: A `float` of initial drop rate.
    i: An `int` of order of the current block.
    n: An `int` total number of blocks.

  Returns:
    Drop rate of the ith block.
  """
  if init_rate is not None:
    if init_rate < 0 or init_rate > 1:
      raise ValueError('Initial drop rate must be within 0 and 1.')
    rate = init_rate * float(i) / n
  else:
    rate = None
  return rate


@tf_keras.utils.register_keras_serializable(package='Vision')
class StochasticDepth(tf_keras.layers.Layer):
  """Creates a stochastic depth layer."""

  def __init__(self, stochastic_depth_drop_rate, **kwargs):
    """Initializes a stochastic depth layer.

    Args:
      stochastic_depth_drop_rate: A `float` of drop rate.
      **kwargs: Additional keyword arguments to be passed.

    Returns:
      A output `tf.Tensor` of which should have the same shape as input.
    """
    super(StochasticDepth, self).__init__(**kwargs)
    self._drop_rate = stochastic_depth_drop_rate

  def get_config(self):
    config = {'stochastic_depth_drop_rate': self._drop_rate}
    base_config = super(StochasticDepth, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    if training is None:
      training = tf_keras.backend.learning_phase()
    if not training or self._drop_rate is None or self._drop_rate == 0:
      return inputs

    keep_prob = 1.0 - self._drop_rate
    batch_size = tf.shape(inputs)[0]
    random_tensor = keep_prob
    random_tensor += tf.random.uniform(
        [batch_size] + [1] * (inputs.shape.rank - 1), dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)
    output = tf.math.divide(inputs, keep_prob) * binary_tensor
    return output


@tf_keras.utils.register_keras_serializable(package='Vision')
def pyramid_feature_fusion(inputs, target_level):
  """Fuses all feature maps in the feature pyramid at the target level.

  Args:
    inputs: A dictionary containing the feature pyramid. The size of the input
      tensor needs to be fixed.
    target_level: An `int` of the target feature level for feature fusion.

  Returns:
    A `float` `tf.Tensor` of shape [batch_size, feature_height, feature_width,
      feature_channel].
  """
  # Convert keys to int.
  pyramid_feats = {int(k): v for k, v in inputs.items()}
  min_level = min(pyramid_feats.keys())
  max_level = max(pyramid_feats.keys())
  resampled_feats = []

  for l in range(min_level, max_level + 1):
    if l == target_level:
      resampled_feats.append(pyramid_feats[l])
    else:
      feat = pyramid_feats[l]
      target_size = list(feat.shape[1:3])
      target_size[0] *= 2**(l - target_level)
      target_size[1] *= 2**(l - target_level)
      # Casts feat to float32 so the resize op can be run on TPU.
      feat = tf.cast(feat, tf.float32)
      feat = tf.image.resize(
          feat, size=target_size, method=tf.image.ResizeMethod.BILINEAR)
      # Casts it back to be compatible with the rest opetations.
      feat = tf.cast(feat, pyramid_feats[l].dtype)
      resampled_feats.append(feat)

  return tf.math.add_n(resampled_feats)


class PanopticFPNFusion(tf_keras.Model):
  """Creates a Panoptic FPN feature Fusion layer.

  This implements feature fusion for semantic segmentation head from the paper:
  Alexander Kirillov, Ross Girshick, Kaiming He and Piotr Dollar.
  Panoptic Feature Pyramid Networks.
  (https://arxiv.org/pdf/1901.02446.pdf)
  """

  def __init__(
      self,
      min_level: int = 2,
      max_level: int = 5,
      target_level: int = 2,
      num_filters: int = 128,
      num_fpn_filters: int = 256,
      activation: str = 'relu',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):

    """Initializes panoptic FPN feature fusion layer.

    Args:
      min_level: An `int` of minimum level to use in feature fusion.
      max_level: An `int` of maximum level to use in feature fusion.
      target_level: An `int` of the target feature level for feature fusion.
      num_filters: An `int` number of filters in conv2d layers.
      num_fpn_filters: An `int` number of filters in the FPN outputs
      activation: A `str` name of the activation function.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    Returns:
      A `float` `tf.Tensor` of shape [batch_size, feature_height, feature_width,
        feature_channel].
    """
    if target_level > max_level:
      raise ValueError('target_level should be less than max_level')

    self._config_dict = {
        'min_level': min_level,
        'max_level': max_level,
        'target_level': target_level,
        'num_filters': num_filters,
        'num_fpn_filters': num_fpn_filters,
        'activation': activation,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    norm = tf_keras.layers.GroupNormalization
    conv2d = tf_keras.layers.Conv2D
    activation_fn = tf_utils.get_activation(activation)
    if tf_keras.backend.image_data_format() == 'channels_last':
      norm_axis = -1
    else:
      norm_axis = 1
    inputs = self._build_inputs(num_fpn_filters, min_level, max_level)

    upscaled_features = []
    for level in range(min_level, max_level + 1):
      num_conv_layers = max(1, level - target_level)
      x = inputs[str(level)]
      for i in range(num_conv_layers):
        x = conv2d(
            filters=num_filters,
            kernel_size=3,
            padding='same',
            kernel_initializer=tf_keras.initializers.VarianceScaling(),
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer)(x)
        x = norm(groups=32, axis=norm_axis)(x)
        x = activation_fn(x)
        if level != target_level:
          x = spatial_transform_ops.nearest_upsampling(x, scale=2)
      upscaled_features.append(x)

    fused_features = tf.math.add_n(upscaled_features)
    self._output_specs = {str(target_level): fused_features.get_shape()}

    super(PanopticFPNFusion, self).__init__(
        inputs=inputs, outputs=fused_features, **kwargs)

  def _build_inputs(self, num_filters: int,
                    min_level: int, max_level: int):
    inputs = {}
    for level in range(min_level, max_level + 1):
      inputs[str(level)] = tf_keras.Input(shape=[None, None, num_filters])
    return inputs

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@tf_keras.utils.register_keras_serializable(package='Vision')
class Scale(tf_keras.layers.Layer):
  """Scales the input by a trainable scalar weight.

  This is useful for applying ReZero to layers, which improves convergence
  speed. This implements the paper:
  ReZero is All You Need: Fast Convergence at Large Depth.
  (https://arxiv.org/pdf/2003.04887.pdf).
  """

  def __init__(
      self,
      initializer: tf_keras.initializers.Initializer = 'ones',
      regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a scale layer.

    Args:
      initializer: A `str` of initializer for the scalar weight.
      regularizer: A `tf_keras.regularizers.Regularizer` for the scalar weight.
      **kwargs: Additional keyword arguments to be passed to this layer.

    Returns:
      An `tf.Tensor` of which should have the same shape as input.
    """
    super(Scale, self).__init__(**kwargs)

    self._initializer = initializer
    self._regularizer = regularizer

    self._scale = self.add_weight(
        name='scale',
        shape=[],
        dtype=self.dtype,
        initializer=self._initializer,
        regularizer=self._regularizer,
        trainable=True)

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'initializer': self._initializer,
        'regularizer': self._regularizer,
    }
    base_config = super(Scale, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Calls the layer with the given inputs."""
    scale = tf.cast(self._scale, inputs.dtype)
    return scale * inputs


@tf_keras.utils.register_keras_serializable(package='Vision')
class TemporalSoftmaxPool(tf_keras.layers.Layer):
  """Creates a network layer corresponding to temporal softmax pooling.

  This is useful for multi-class logits (used in e.g., Charades). Modified from
  AssembleNet Charades evaluation from:

  Michael S. Ryoo, AJ Piergiovanni, Mingxing Tan, Anelia Angelova.
  AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
  Architectures.
  (https://arxiv.org/pdf/1905.13209.pdf).
  """

  def call(self, inputs):
    """Calls the layer with the given inputs."""
    assert inputs.shape.rank in (3, 4, 5)
    frames = tf.shape(inputs)[1]
    pre_logits = inputs / tf.sqrt(tf.cast(frames, inputs.dtype))
    activations = tf.nn.softmax(pre_logits, axis=1)
    outputs = inputs * activations
    return outputs


@tf_keras.utils.register_keras_serializable(package='Vision')
class PositionalEncoding(tf_keras.layers.Layer):
  """Creates a network layer that adds a sinusoidal positional encoding.

  Positional encoding is incremented across frames, and is added to the input.
  The positional encoding is first weighted at 0 so that the network can choose
  to ignore it. This implements:

  Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
  Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.
  Attention Is All You Need.
  (https://arxiv.org/pdf/1706.03762.pdf).
  """

  def __init__(self,
               initializer: tf_keras.initializers.Initializer = 'zeros',
               cache_encoding: bool = False,
               state_prefix: Optional[str] = None,
               **kwargs):
    """Initializes positional encoding.

    Args:
      initializer: A `str` of initializer for weighting the positional encoding.
      cache_encoding: A `bool`. If True, cache the positional encoding tensor
        after calling build. Otherwise, rebuild the tensor for every call.
        Setting this to False can be useful when we want to input a variable
        number of frames, so the positional encoding tensor can change shape.
      state_prefix: a prefix string to identify states.
      **kwargs: Additional keyword arguments to be passed to this layer.

    Returns:
      A `tf.Tensor` of which should have the same shape as input.
    """
    super(PositionalEncoding, self).__init__(**kwargs)
    self._initializer = initializer
    self._cache_encoding = cache_encoding
    self._pos_encoding = None
    self._rezero = Scale(initializer=initializer, name='rezero')
    state_prefix = state_prefix if state_prefix is not None else ''
    self._state_prefix = state_prefix
    self._frame_count_name = f'{state_prefix}_pos_enc_frame_count'

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'initializer': self._initializer,
        'cache_encoding': self._cache_encoding,
        'state_prefix': self._state_prefix,
    }
    base_config = super(PositionalEncoding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _positional_encoding(self,
                           num_positions: Union[int, tf.Tensor],
                           hidden_size: Union[int, tf.Tensor],
                           start_position: Union[int, tf.Tensor] = 0,
                           dtype: str = 'float32') -> tf.Tensor:
    """Creates a sequence of sinusoidal positional encoding vectors.

    Args:
      num_positions: the total number of positions (frames).
      hidden_size: the number of channels used for the hidden vectors.
      start_position: the start position.
      dtype: the dtype of the output tensor.

    Returns:
      The positional encoding tensor with shape [num_positions, hidden_size].
    """
    if isinstance(start_position, tf.Tensor) and start_position.shape.rank == 1:
      start_position = start_position[0]

    # Calling `tf.range` with `dtype=tf.bfloat16` results in an error,
    # so we cast afterward.
    positions = tf.range(start_position, start_position + num_positions)
    positions = tf.cast(positions, dtype)[:, tf.newaxis]
    idx = tf.range(hidden_size)[tf.newaxis, :]

    power = tf.cast(2 * (idx // 2), dtype)
    power /= tf.cast(hidden_size, dtype)
    angles = 1. / tf.math.pow(10_000., power)
    radians = positions * angles

    sin = tf.math.sin(radians[:, 0::2])
    cos = tf.math.cos(radians[:, 1::2])
    pos_encoding = tf.concat([sin, cos], axis=-1)

    return pos_encoding

  def _get_pos_encoding(self,
                        input_shape: tf.Tensor,
                        frame_count: int = 0) -> tf.Tensor:
    """Calculates the positional encoding from the input shape.

    Args:
      input_shape: the shape of the input.
      frame_count: a count of frames that indicates the index of the first
        frame.

    Returns:
      The positional encoding tensor with shape [num_positions, hidden_size].

    """
    frames = input_shape[1]
    channels = input_shape[-1]
    pos_encoding = self._positional_encoding(
        frames, channels, start_position=frame_count, dtype=self.dtype)
    pos_encoding = tf.reshape(pos_encoding, [1, frames, 1, 1, channels])
    return pos_encoding

  def build(self, input_shape):
    """Builds the layer with the given input shape.

    Args:
      input_shape: The input shape.

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    if tf_keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    if self._cache_encoding:
      self._pos_encoding = self._get_pos_encoding(input_shape)

    super(PositionalEncoding, self).build(input_shape)

  def call(
      self,
      inputs: tf.Tensor,
      states: Optional[States] = None,
      output_states: bool = True,
  ) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
    """Calls the layer with the given inputs.

    Args:
      inputs: An input `tf.Tensor`.
      states: A `dict` of states such that, if any of the keys match for this
        layer, will overwrite the contents of the buffer(s). Expected keys
        include `state_prefix + '_pos_enc_frame_count'`.
      output_states: A `bool`. If True, returns the output tensor and output
        states. Returns just the output tensor otherwise.

    Returns:
      An output `tf.Tensor` (and optionally the states if `output_states=True`).

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    states = dict(states) if states is not None else {}

    # Keep a count of frames encountered across input iterations in
    # num_frames to be able to accurately update the positional encoding.
    num_frames = tf.shape(inputs)[1]
    frame_count = tf.cast(states.get(self._frame_count_name, [0]), tf.int32)
    states[self._frame_count_name] = frame_count + num_frames

    if self._cache_encoding:
      pos_encoding = self._pos_encoding
    else:
      pos_encoding = self._get_pos_encoding(
          tf.shape(inputs), frame_count=frame_count)
    pos_encoding = tf.cast(pos_encoding, inputs.dtype)
    pos_encoding = self._rezero(pos_encoding)
    outputs = inputs + pos_encoding

    return (outputs, states) if output_states else outputs


@tf_keras.utils.register_keras_serializable(package='Vision')
class GlobalAveragePool3D(tf_keras.layers.Layer):
  """Creates a global average pooling layer with causal mode.

  Implements causal mode, which runs a cumulative sum (with `tf.cumsum`) across
  frames in the time dimension, allowing the use of a stream buffer. Sums any
  valid input state with the current input to allow state to accumulate over
  several iterations.
  """

  def __init__(self,
               keepdims: bool = False,
               causal: bool = False,
               state_prefix: Optional[str] = None,
               **kwargs):
    """Initializes a global average pool layer.

    Args:
      keepdims: A `bool`. If True, keep the averaged dimensions.
      causal: A `bool` of whether to run in causal mode with a cumulative sum
        across frames.
      state_prefix: a prefix string to identify states.
      **kwargs: Additional keyword arguments to be passed to this layer.

    Returns:
      An output `tf.Tensor`.
    """
    super(GlobalAveragePool3D, self).__init__(**kwargs)

    self._keepdims = keepdims
    self._causal = causal
    state_prefix = state_prefix if state_prefix is not None else ''
    self._state_prefix = state_prefix

    self._state_name = f'{state_prefix}_pool_buffer'
    self._frame_count_name = f'{state_prefix}_pool_frame_count'

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'keepdims': self._keepdims,
        'causal': self._causal,
        'state_prefix': self._state_prefix,
    }
    base_config = super(GlobalAveragePool3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           inputs: tf.Tensor,
           states: Optional[States] = None,
           output_states: bool = False
           ) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
    """Calls the layer with the given inputs.

    Args:
      inputs: An input `tf.Tensor`.
      states: A `dict` of states such that, if any of the keys match for this
        layer, will overwrite the contents of the buffer(s).
        Expected keys include `state_prefix + '__pool_buffer'` and
        `state_prefix + '__pool_frame_count'`.
      output_states: A `bool`. If True, returns the output tensor and output
        states. Returns just the output tensor otherwise.

    Returns:
      An output `tf.Tensor` (and optionally the states if `output_states=True`).
      If `causal=True`, the output tensor will have shape
      `[batch_size, num_frames, 1, 1, channels]` if `keepdims=True`. We keep
      the frame dimension in this case to simulate a cumulative global average
      as if we are inputting one frame at a time. If `causal=False`, the output
      is equivalent to `tf_keras.layers.GlobalAveragePooling3D` with shape
      `[batch_size, 1, 1, 1, channels]` if `keepdims=True` (plus the optional
      buffer stored in `states`).

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    states = dict(states) if states is not None else {}

    if tf_keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    # Shape: [batch_size, 1, 1, 1, channels]
    buffer = states.get(self._state_name, None)
    if buffer is None:
      buffer = tf.zeros_like(inputs[:, :1, :1, :1], dtype=inputs.dtype)
      states[self._state_name] = buffer

    # Keep a count of frames encountered across input iterations in
    # num_frames to be able to accurately take a cumulative average across
    # all frames when running in streaming mode
    num_frames = tf.shape(inputs)[1]
    frame_count = states.get(self._frame_count_name, tf.constant([0]))
    frame_count = tf.cast(frame_count, tf.int32)
    states[self._frame_count_name] = frame_count + num_frames

    if self._causal:
      # Take a mean of spatial dimensions to make computation more efficient.
      x = tf.reduce_mean(inputs, axis=[2, 3], keepdims=True)
      x = tf.cumsum(x, axis=1)
      x = x + buffer

      # The last frame will be the value of the next state
      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x[:, -1:]

      # In causal mode, the divisor increments by 1 for every frame to
      # calculate cumulative averages instead of one global average
      mean_divisors = tf.range(num_frames) + frame_count + 1
      mean_divisors = tf.reshape(mean_divisors, [1, num_frames, 1, 1, 1])
      mean_divisors = tf.cast(mean_divisors, x.dtype)

      # Shape: [batch_size, num_frames, 1, 1, channels]
      x = x / mean_divisors
    else:
      # In non-causal mode, we (optionally) sum across frames to take a
      # cumulative average across input iterations rather than individual
      # frames. If no buffer state is passed, this essentially becomes
      # regular global average pooling.
      # Shape: [batch_size, 1, 1, 1, channels]
      x = tf.reduce_sum(inputs, axis=(1, 2, 3), keepdims=True)
      x = x / tf.cast(tf.shape(inputs)[2] * tf.shape(inputs)[3], x.dtype)
      x = x + buffer

      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x

      x = x / tf.cast(frame_count + num_frames, x.dtype)

    if not self._keepdims:
      x = tf.squeeze(x, axis=(1, 2, 3))

    return (x, states) if output_states else x


@tf_keras.utils.register_keras_serializable(package='Vision')
class SpatialAveragePool3D(tf_keras.layers.Layer):
  """Creates a global average pooling layer pooling across spatial dimentions."""

  def __init__(self, keepdims: bool = False, **kwargs):
    """Initializes a global average pool layer.

    Args:
      keepdims: A `bool`. If True, keep the averaged dimensions.
      **kwargs: Additional keyword arguments to be passed to this layer.

    Returns:
      An output `tf.Tensor`.
    """
    super(SpatialAveragePool3D, self).__init__(**kwargs)
    self._keepdims = keepdims

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'keepdims': self._keepdims,
    }
    base_config = super(SpatialAveragePool3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Builds the layer with the given input shape."""
    if tf_keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    super(SpatialAveragePool3D, self).build(input_shape)

  def call(self, inputs, states=None, output_states: bool = False):
    """Calls the layer with the given inputs."""
    if inputs.shape.rank != 5:
      raise ValueError(
          'Input should have rank {}, got {}'.format(5, inputs.shape.rank))

    output = tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
    return (output, states) if output_states else output


class CausalConvMixin:
  """Mixin class to implement CausalConv for `tf_keras.layers.Conv` layers."""

  @property
  def use_buffered_input(self) -> bool:
    return self._use_buffered_input

  @use_buffered_input.setter
  def use_buffered_input(self, variable: bool):
    self._use_buffered_input = variable

  def _compute_buffered_causal_padding(self,
                                       inputs: tf.Tensor,
                                       use_buffered_input: bool = False,
                                       time_axis: int = 1,
                                       ) -> List[List[int]]:
    """Calculates padding for 'causal' option for conv layers.

    Args:
      inputs: An optional input `tf.Tensor` to be padded.
      use_buffered_input: A `bool`. If True, use 'valid' padding along the time
        dimension. This should be set when applying the stream buffer.
      time_axis: An `int` of the axis of the time dimension.

    Returns:
      A list of paddings for `tf.pad`.
    """
    input_shape = tf.shape(inputs)[1:-1]

    if tf_keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    kernel_size_effective = [
        (self.kernel_size[i] +
         (self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
        for i in range(self.rank)
    ]
    pad_total = [kernel_size_effective[0] - 1]
    for i in range(1, self.rank):
      overlap = (input_shape[i] - 1) % self.strides[i] + 1
      pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
    pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
    pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
    padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
    padding = [[0, 0]] + padding + [[0, 0]]

    if use_buffered_input:
      padding[time_axis] = [0, 0]
    else:
      padding[time_axis] = [padding[time_axis][0] + padding[time_axis][1], 0]
    return padding

  def _causal_validate_init(self):
    """Validates the Conv layer initial configuration."""
    # Overriding this method is meant to circumvent unnecessary errors when
    # using causal padding.
    if (self.filters is not None
        and self.filters % self.groups != 0):
      raise ValueError(
          'The number of filters must be evenly divisible by the number of '
          'groups. Received: groups={}, filters={}'.format(
              self.groups, self.filters))

    if not all(self.kernel_size):
      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
                       'Received: %s' % (self.kernel_size,))

  def _buffered_spatial_output_shape(self, spatial_output_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    # When buffer padding, use 'valid' padding across time. The output shape
    # across time should be the input shape minus any padding, assuming
    # the stride across time is 1.
    if self._use_buffered_input and spatial_output_shape[0] is not None:
      padding = self._compute_buffered_causal_padding(
          tf.zeros([1] + spatial_output_shape + [1]), use_buffered_input=False)
      spatial_output_shape[0] -= sum(padding[1])
    return spatial_output_shape


@tf_keras.utils.register_keras_serializable(package='Vision')
class Conv2D(tf_keras.layers.Conv2D, CausalConvMixin):
  """Conv2D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf_keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes conv2d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.

    Returns:
      An output `tf.Tensor` of the Conv2D operation.
    """
    super(Conv2D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'use_buffered_input': self._use_buffered_input,
    }
    base_config = super(Conv2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    shape = super(Conv2D, self)._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(shape)


@tf_keras.utils.register_keras_serializable(package='Vision')
class DepthwiseConv2D(tf_keras.layers.DepthwiseConv2D, CausalConvMixin):
  """DepthwiseConv2D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf_keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes depthwise conv2d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.

    Returns:
      An output `tf.Tensor` of the DepthwiseConv2D operation.
    """
    super(DepthwiseConv2D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

    # Causal padding is unsupported by default for DepthwiseConv2D,
    # so we resort to valid padding internally. However, we handle
    # causal padding as a special case with `self._is_causal`, which is
    # defined by the super class.
    if self.padding == 'causal':
      self.padding = 'valid'

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'use_buffered_input': self._use_buffered_input,
    }
    base_config = super(DepthwiseConv2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Calls the layer with the given inputs."""
    if self._is_causal:
      inputs = tf.pad(inputs, self._compute_causal_padding(inputs))
    return super(DepthwiseConv2D, self).call(inputs)

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    shape = super(DepthwiseConv2D, self)._spatial_output_shape(
        spatial_input_shape)
    return self._buffered_spatial_output_shape(shape)


@tf_keras.utils.register_keras_serializable(package='Vision')
class Conv3D(tf_keras.layers.Conv3D, CausalConvMixin):
  """Conv3D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf_keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes conv3d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.

    Returns:
      An output `tf.Tensor` of the Conv3D operation.
    """
    super(Conv3D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'use_buffered_input': self._use_buffered_input,
    }
    base_config = super(Conv3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Call the layer with the given inputs."""
    # Note: tf.nn.conv3d with depthwise kernels on CPU is currently only
    # supported when compiling with TF graph (XLA) using tf.function, so it
    # is compiled by default here (b/186463870).
    conv_fn = tf.function(super(Conv3D, self).call, jit_compile=True)
    return conv_fn(inputs)

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    shape = super(Conv3D, self)._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(shape)


@tf_keras.utils.register_keras_serializable(package='Vision')
class SpatialPyramidPooling(tf_keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  References:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
  """

  def __init__(
      self,
      output_channels: int,
      dilation_rates: List[int],
      pool_kernel_size: Optional[List[int]] = None,
      use_sync_bn: bool = False,
      batchnorm_momentum: float = 0.99,
      batchnorm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout: float = 0.5,
      kernel_initializer: str = 'GlorotUniform',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      use_depthwise_convolution: bool = False,
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Args:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      pool_kernel_size: A list of integers or None. If None, global average
        pooling is applied, otherwise an average pooling of pool_kernel_size is
        applied.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      activation: A `str` for type of activation to be used. Defaults to 'relu'.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      use_depthwise_convolution: Allows spatial pooling to be separable
        depthwise convolusions. [Encoder-Decoder with Atrous Separable
        Convolution for Semantic Image Segmentation](
         https://arxiv.org/pdf/1802.02611.pdf)
      **kwargs: Other keyword arguments for the layer.
    """
    super().__init__(**kwargs)

    self._output_channels = output_channels
    self._dilation_rates = dilation_rates
    self._use_sync_bn = use_sync_bn
    self._batchnorm_momentum = batchnorm_momentum
    self._batchnorm_epsilon = batchnorm_epsilon
    self._activation = activation
    self._dropout = dropout
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._interpolation = interpolation
    self._pool_kernel_size = pool_kernel_size
    self._use_depthwise_convolution = use_depthwise_convolution
    self._activation_fn = tf_utils.get_activation(activation)
    self._bn_op = tf_keras.layers.BatchNormalization

    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

  def build(self, input_shape):
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    self.aspp_layers = []

    conv1 = tf_keras.layers.Conv2D(
        filters=self._output_channels,
        kernel_size=(1, 1),
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False)
    norm1 = self._bn_op(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
        epsilon=self._batchnorm_epsilon,
        synchronized=self._use_sync_bn)

    self.aspp_layers.append([conv1, norm1])

    for dilation_rate in self._dilation_rates:
      leading_layers = []
      kernel_size = (3, 3)
      if self._use_depthwise_convolution:
        leading_layers += [
            tf_keras.layers.DepthwiseConv2D(
                depth_multiplier=1,
                kernel_size=kernel_size,
                padding='same',
                depthwise_regularizer=self._kernel_regularizer,
                depthwise_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                dilation_rate=dilation_rate,
                use_bias=False)
        ]
        kernel_size = (1, 1)
      conv_dilation = leading_layers + [
          tf_keras.layers.Conv2D(
              filters=self._output_channels,
              kernel_size=kernel_size,
              padding='same',
              kernel_regularizer=self._kernel_regularizer,
              kernel_initializer=tf_utils.clone_initializer(
                  self._kernel_initializer),
              dilation_rate=dilation_rate,
              use_bias=False)
      ]
      norm_dilation = self._bn_op(
          axis=self._bn_axis,
          momentum=self._batchnorm_momentum,
          epsilon=self._batchnorm_epsilon,
          synchronized=self._use_sync_bn)

      self.aspp_layers.append(conv_dilation + [norm_dilation])

    if self._pool_kernel_size is None:
      pooling = [
          tf_keras.layers.GlobalAveragePooling2D(),
          tf_keras.layers.Reshape((1, 1, channels))
      ]
    else:
      pooling = [tf_keras.layers.AveragePooling2D(self._pool_kernel_size)]

    conv2 = tf_keras.layers.Conv2D(
        filters=self._output_channels,
        kernel_size=(1, 1),
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False)
    norm2 = self._bn_op(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
        epsilon=self._batchnorm_epsilon,
        synchronized=self._use_sync_bn)

    self.aspp_layers.append(pooling + [conv2, norm2])

    self._resizing_layer = tf_keras.layers.Resizing(
        height, width, interpolation=self._interpolation, dtype=tf.float32)

    self._projection = [
        tf_keras.layers.Conv2D(
            filters=self._output_channels,
            kernel_size=(1, 1),
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            kernel_regularizer=self._kernel_regularizer,
            use_bias=False),
        self._bn_op(
            axis=self._bn_axis,
            momentum=self._batchnorm_momentum,
            epsilon=self._batchnorm_epsilon,
            synchronized=self._use_sync_bn)
    ]
    self._dropout_layer = tf_keras.layers.Dropout(rate=self._dropout)
    self._concat_layer = tf_keras.layers.Concatenate(axis=-1)

  def call(self,
           inputs: tf.Tensor,
           training: Optional[bool] = None) -> tf.Tensor:
    if training is None:
      training = tf_keras.backend.learning_phase()
    result = []
    for i, layers in enumerate(self.aspp_layers):
      x = inputs
      for layer in layers:
        # Apply layers sequentially.
        x = layer(x, training=training)
      x = self._activation_fn(x)

      # Apply resize layer to the end of the last set of layers.
      if i == len(self.aspp_layers) - 1:
        x = self._resizing_layer(x)

      result.append(tf.cast(x, inputs.dtype))
    x = self._concat_layer(result)
    for layer in self._projection:
      x = layer(x, training=training)
    x = self._activation_fn(x)
    return self._dropout_layer(x)

  def get_config(self):
    config = {
        'output_channels': self._output_channels,
        'dilation_rates': self._dilation_rates,
        'pool_kernel_size': self._pool_kernel_size,
        'use_sync_bn': self._use_sync_bn,
        'batchnorm_momentum': self._batchnorm_momentum,
        'batchnorm_epsilon': self._batchnorm_epsilon,
        'activation': self._activation,
        'dropout': self._dropout,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'interpolation': self._interpolation,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))


@tf_keras.utils.register_keras_serializable(package='Vision')
class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention as described in the paper
  "Attention is all you Need" (Vaswani et al., 2017).
  """

  def __init__(
      self,
      *args,
      partition_dims: Optional[Tuple[int, int, int, int]] = None,
      max_inference_parallelism: Optional[int] = None,
      **kwargs,
  ):
    """Initializes MultiHeadAttention.

    Args:
      *args: Positional arguments passed to super().__init__.
      partition_dims: Spatial partition dimensions.
      max_inference_parallelism: The number of examples to run in parallel
        during inference. Set this limit to reduce the peak memory usage. If
        None, use vectorized operations to run the whole batch in parallel.
      **kwargs: Keyword arguments passed to super().__init__.
    """
    super().__init__(*args, **kwargs)
    self._partition_dims = partition_dims
    self._max_inference_parallelism = max_inference_parallelism

  def get_config(self):
    config = super().get_config()
    config.update({
        'partition_dims': self._partition_dims,
        'max_inference_parallelism': self._max_inference_parallelism,
    })
    return config

  def _compute_attention(
      self,
      query: tf.Tensor,
      key: tf.Tensor,
      value: tf.Tensor,
      attention_mask: Optional[tf.Tensor] = None,
      training: Optional[bool] = None,
  ):
    """Applies dot-product attention with query, key, value tensors.

    Args:
      query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
      key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
      value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
      attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
        attention to certain positions. It is generally not needed if the
        `query` and `value` (and/or `key`) are masked.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    if self._partition_dims is not None:
      strategy = tf.distribute.get_strategy()
      # `query` = [B, T, N ,H]
      query = strategy.experimental_split_to_logical_devices(
          query, self._partition_dims)
      key = strategy.experimental_split_to_logical_devices(
          key, self._partition_dims)
      value = strategy.experimental_split_to_logical_devices(
          value, self._partition_dims)

    batch_size = query.get_shape().as_list()[0]  # None if dynamic.

    if (
        training
        or self._max_inference_parallelism is None
        or self._max_inference_parallelism <= 0
        or (
            # If the whole batch is allowed to be run in parallel, use fully
            # vectorized computation instead of tf.map_fn to make things more
            # efficient.
            batch_size is not None
            and batch_size <= self._max_inference_parallelism
        )
    ):
      return self._compute_attention_delegate(
          query, key, value, attention_mask, training
      )
    else:
      # Sequentialize the inference execution with limited parallelism.
      def _compute_fn(x):
        attention_output, attention_scores = self._compute_attention_delegate(
            query=x[0][tf.newaxis, ...],
            key=x[1][tf.newaxis, ...],
            value=x[2][tf.newaxis, ...],
            attention_mask=x[3][tf.newaxis, ...] if len(x) >= 4 else None,
            training=training,
        )
        attention_output = tf.squeeze(attention_output, axis=0)
        attention_scores = tf.squeeze(attention_scores, axis=0)
        return attention_output, attention_scores

      if attention_mask is not None:
        elems = [query, key, value, attention_mask]
      else:
        elems = [query, key, value]

      return tf.map_fn(
          fn=_compute_fn,
          elems=elems,
          fn_output_signature=(value.dtype, value.dtype),
          parallel_iterations=self._max_inference_parallelism,
      )

  def _compute_attention_delegate(
      self,
      query: tf.Tensor,
      key: tf.Tensor,
      value: tf.Tensor,
      attention_mask: Optional[tf.Tensor] = None,
      training: Optional[bool] = None,
  ):
    """Implements dot-product attention with query, key, value tensors."""
    # Simply calls the implementation of the super class here, while the users
    # can override this function for customizing attention computation.
    return super()._compute_attention(
        query, key, value, attention_mask, training
    )