# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-fstring-interpolation
r"""MaxViT layers and model class."""

import functools
from typing import Any, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf, tf_keras

from official.projects.maxvit.modeling import common_ops as ops
from official.projects.maxvit.modeling import layers
from official.vision.modeling.backbones import factory


MAXVIT_SPECS = {
    'maxvit-tiny-for-test': dict(
        survival_prob=None,
        stem_hsize=(8, 8),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 3, 3, 2),
        hidden_size=(32, 32, 32, 768),
    ),
    'maxvit-tiny': dict(
        survival_prob=0.8,
        stem_hsize=(64, 64),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 2, 5, 2),
        hidden_size=(64, 128, 256, 512),
    ),
    'maxvit-small': dict(
        survival_prob=0.7,
        stem_hsize=(64, 64),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 2, 5, 2),
        hidden_size=(96, 192, 384, 768),
    ),
    'maxvit-base': dict(
        survival_prob=0.6,
        stem_hsize=(64, 64),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 6, 14, 2),
        hidden_size=(96, 192, 384, 768),
    ),
    'maxvit-large': dict(
        survival_prob=0.4,
        stem_hsize=(128, 128),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 6, 14, 2),
        hidden_size=(128, 256, 512, 1024),
    ),
    'maxvit-xlarge': dict(
        survival_prob=0.3,
        stem_hsize=(192, 192),
        block_type=('maxvit', 'maxvit', 'maxvit', 'maxvit'),
        num_blocks=(2, 6, 14, 2),
        hidden_size=(192, 384, 768, 1536),
    ),
}
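
# Each entry in `MAXVIT_SPECS` holds per-variant defaults. In `build_maxvit`
# below, `backbone_cfg.model_name` selects one of these entries, and any field
# explicitly set on the backbone config overrides the predefined value (see
# `override_predefined_spec_and_build_maxvit`). A minimal sketch of direct
# construction that bypasses the config system:
#
#   model = MaxViT(**MAXVIT_SPECS['maxvit-tiny-for-test'])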


class MaxViTBlock(tf_keras.layers.Layer):
  """MaxViT block = MBConv + Block-Attention + FFN + Grid-Attention + FFN."""

  def __init__(
      self,
      hidden_size: int,
      head_size: int,
      window_size: int,
      grid_size: int,
      num_heads: Optional[int] = None,
      downsample_loc: str = 'depth_conv',
      data_format: str = 'channels_last',
      kernel_size: int = 3,
      expansion_rate: int = 4,
      se_ratio: float = 0.25,
      activation: str = 'gelu',
      pool_type: str = '2d:avg',
      pool_stride: int = 1,
      dropcnn: Optional[float] = None,
      dropatt: Optional[Union[float, tf.Tensor]] = None,
      dropout: Optional[Union[float, tf.Tensor]] = None,
      rel_attn_type: Optional[str] = None,
      scale_ratio: Optional[str] = None,
      survival_prob: Optional[Union[float, tf.Tensor]] = None,
      ln_epsilon: float = 1e-5,
      ln_dtype: Optional[tf.DType] = None,
      norm_type: str = 'sync_batch_norm',
      bn_epsilon: float = 1e-3,
      bn_momentum: float = 0.99,
      kernel_initializer: Optional[str] = 'glorot_uniform',
      bias_initializer: Optional[str] = 'zeros',
      name: str = 'maxvit_block',
  ) -> None:
    super().__init__(name=name)

    self._hidden_size = hidden_size
    self._head_size = head_size
    self._window_size = window_size
    self._grid_size = grid_size
    self._num_heads = num_heads
    self._downsample_loc = downsample_loc
    self._data_format = data_format
    self._kernel_size = kernel_size
    self._expansion_rate = expansion_rate
    self._se_ratio = se_ratio
    self._dropcnn = dropcnn
    self._activation = activation
    self._norm_type = norm_type
    self._bn_epsilon = bn_epsilon
    self._bn_momentum = bn_momentum
    self._pool_type = pool_type
    self._pool_stride = pool_stride
    self._dropatt = dropatt
    self._dropout = dropout
    self._rel_attn_type = rel_attn_type
    self._scale_ratio = scale_ratio
    self._survival_prob = survival_prob
    self._ln_epsilon = ln_epsilon
    self._ln_dtype = ln_dtype
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  def build(self, input_shape: tf.TensorShape) -> None:
    input_size = input_shape.as_list()[-1]

    if input_size != self._hidden_size:
      self._shortcut_proj = layers.TrailDense(
          self._hidden_size,
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          name='shortcut_proj',
      )
    else:
      self._shortcut_proj = None

    self._block_attn_layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        epsilon=self._ln_epsilon,
        dtype=self._ln_dtype,
        name='attn_layer_norm',
    )

    self._grid_attn_layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        epsilon=self._ln_epsilon,
        dtype=self._ln_dtype,
        name='attn_layer_norm_1',
    )

    self._block_attention = layers.Attention(
        self._hidden_size,
        self._head_size,
        num_heads=self._num_heads,
        dropatt=self._dropatt,
        rel_attn_type=self._rel_attn_type,
        scale_ratio=self._scale_ratio,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        name='attention',
    )

    self._grid_attention = layers.Attention(
        self._hidden_size,
        self._head_size,
        num_heads=self._num_heads,
        dropatt=self._dropatt,
        rel_attn_type=self._rel_attn_type,
        scale_ratio=self._scale_ratio,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        name='attention_1',
    )

    self._block_ffn_layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        epsilon=self._ln_epsilon,
        dtype=self._ln_dtype,
        name='ffn_layer_norm',
    )

    self._grid_ffn_layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        epsilon=self._ln_epsilon,
        dtype=self._ln_dtype,
        name='ffn_layer_norm_1',
    )

    self._block_ffn = layers.FFN(
        self._hidden_size,
        dropout=self._dropout,
        expansion_rate=self._expansion_rate,
        activation=self._activation,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        name='ffn',
    )

    self._grid_ffn = layers.FFN(
        self._hidden_size,
        dropout=self._dropout,
        expansion_rate=self._expansion_rate,
        activation=self._activation,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        name='ffn_1',
    )

    self._mbconv = layers.MBConvBlock(
        self._hidden_size,
        downsample_loc=self._downsample_loc,
        data_format=self._data_format,
        kernel_size=self._kernel_size,
        expansion_rate=self._expansion_rate,
        se_ratio=self._se_ratio,
        activation=self._activation,
        pool_type='avg' if self._pool_type == '2d:avg' else 'max',
        pool_stride=self._pool_stride,
        dropcnn=self._dropcnn,
        survival_prob=self._survival_prob,
        norm_type=self._norm_type,
        bn_epsilon=self._bn_epsilon,
        bn_momentum=self._bn_momentum,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        name='mbconv',
    )

  def downsample(self, inputs, name):
    output = inputs
    if self._pool_stride > 1:
      output = ops.maybe_reshape_to_2d(output)
      output = ops.pooling_2d(
          output,
          self._pool_type,
          self._pool_stride,
          padding='same',
          data_format='channels_last',
          name=name,
      )
    return output

  def window_partition(self, features: tf.Tensor) -> tf.Tensor:
    """Partition the input feature maps into non-overlapping windows.

    Note that unsuitable feature or window sizes may be costly on TPU due to
    padding sizes:
    https://docs.google.com/document/d/1GojE1Q7hR2qyi0mIfnTHgERfl7Dmsj6xPQ31MQo3xUk/edit#

    Args:
      features: [B, H, W, C] feature maps.

    Returns:
      Partitioned features of shape [B * nH * nW, wSize, wSize, C].

    Raises:
      ValueError: If the feature map sizes are not divisible by window sizes.
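
    Example (illustrative sketch, not executed; shapes assume
    `window_size=7`):
      x = tf.ones([2, 14, 14, 32])
      y = self.window_partition(x)  # y.shape == (8, 7, 7, 32)
      # Each of the 2x2 non-overlapping 7x7 windows per image becomes its own
      # batch entry for block attention.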
    """

    _, h, w, c = features.shape
    window_size = self._window_size

    if h % window_size != 0 or w % window_size != 0:
      raise ValueError(
          f'Feature map sizes {(h, w)} '
          f'not divisible by window size ({window_size}).'
      )

    features = tf.reshape(
        features,
        (-1, h // window_size, window_size, w // window_size, window_size, c),
    )
    features = tf.transpose(features, (0, 1, 3, 2, 4, 5))
    features = tf.reshape(features, (-1, window_size, window_size, c))
    return features

  def window_stitch_back(
      self, features: tf.Tensor, window_size: int, h: int, w: int
  ) -> tf.Tensor:
    """Reverse window_partition."""
    features = tf.reshape(
        features,
        [
            -1,
            h // window_size,
            w // window_size,
            window_size,
            window_size,
            features.shape[-1],
        ],
    )
    return tf.reshape(
        tf.transpose(features, (0, 1, 3, 2, 4, 5)),
        [-1, h, w, features.shape[-1]],
    )

  def grid_partition(self, features: tf.Tensor) -> tf.Tensor:
    """Partition the input feature maps into non-overlapping windows.

    Note that unsuitable feature or window sizes may be costly on TPU due to
    padding sizes:
    https://docs.google.com/document/d/1GojE1Q7hR2qyi0mIfnTHgERfl7Dmsj6xPQ31MQo3xUk/edit#

    Args:
      features: [B, H, W, C] feature maps.

    Returns:
      Partitioned features of shape [B * nH * nW, gSize, gSize, C].

    Raises:
      ValueError: If the feature map sizes are not divisible by the grid size.
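
    Example (illustrative sketch, not executed; shapes assume `grid_size=7`):
      x = tf.ones([2, 14, 14, 32])
      y = self.grid_partition(x)  # y.shape == (8, 7, 7, 32)
      # Unlike window_partition, each 7x7 tile gathers pixels strided by
      # h // grid_size (here 2) across the map, which is what makes grid
      # attention sparse and global.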
    """
    _, h, w, c = features.shape
    grid_size = self._grid_size
    if h % grid_size != 0 or w % grid_size != 0:
      raise ValueError(
          f'Feature map sizes {(h, w)} '
          f'not divisible by grid size ({grid_size}).'
      )
    features = tf.reshape(
        features, (-1, grid_size, h // grid_size, grid_size, w // grid_size, c)
    )
    features = tf.transpose(features, (0, 2, 4, 1, 3, 5))
    features = tf.reshape(features, (-1, grid_size, grid_size, c))
    return features

  def grid_stitch_back(
      self, features: tf.Tensor, grid_size: int, h: int, w: int
  ) -> tf.Tensor:
    """Reverse window_partition."""
    features = tf.reshape(
        features,
        [
            -1,
            h // grid_size,
            w // grid_size,
            grid_size,
            grid_size,
            features.shape[-1],
        ],
    )
    return tf.reshape(
        tf.transpose(features, (0, 3, 1, 4, 2, 5)),
        [-1, h, w, features.shape[-1]],
    )

  def block_attn_branch(
      self, inputs: tf.Tensor, training: bool, attn_mask: tf.Tensor
  ) -> tf.Tensor:
    output = self._block_attn_layer_norm(inputs)
    # Note: if grid-attention were applied first, downsampling here would be
    # unnecessary.
    # Apply local block-attention.
    _, h, w, _ = output.shape
    output = self.window_partition(output)
    output = ops.maybe_reshape_to_1d(output)
    output = self._block_attention(output, training, attn_mask=attn_mask)
    output = self.window_stitch_back(output, self._window_size, h, w)
    return output

  def grid_attn_branch(
      self, inputs: tf.Tensor, training: bool, attn_mask: tf.Tensor
  ) -> tf.Tensor:
    output = self._grid_attn_layer_norm(inputs)
    # output = self.downsample(output, 'residual_pool')
    # Apply global grid attention.
    _, h, w, _ = output.shape
    output = self.grid_partition(output)
    output = ops.maybe_reshape_to_1d(output)
    output = self._grid_attention(output, training, attn_mask=attn_mask)
    output = self.grid_stitch_back(output, self._grid_size, h, w)
    return output

  def block_ffn_branch(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    output = self._block_ffn_layer_norm(inputs)
    output = self._block_ffn(output, training)
    return output

  def grid_ffn_branch(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    output = self._grid_ffn_layer_norm(inputs)
    output = self._grid_ffn(output, training)
    return output

  def mbconv_branch(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    output = self._mbconv(inputs, training=training)
    return output

  def call(
      self,
      inputs: tf.Tensor,
      training: bool,
      attn_mask: Optional[tf.Tensor] = None,
  ) -> tf.Tensor:
    logging.debug(
        'Block %s input shape: %s (%s)', self.name, inputs.shape, inputs.dtype
    )

    # MBConv
    output = self.mbconv_branch(inputs, training)

    # block self-attention
    shortcut = output
    output = self.block_attn_branch(output, training, attn_mask)
    if self._dropout:
      output = tf_keras.layers.Dropout(
          self._dropout, name='after_block_attn_drop'
      )(output, training=training)
    output = ops.residual_add(output, shortcut, self._survival_prob, training)

    shortcut = output
    output = self.block_ffn_branch(output, training)
    if self._dropout:
      output = tf_keras.layers.Dropout(
          self._dropout, name='after_block_ffn_drop_1'
      )(output, training=training)
    output = ops.residual_add(output, shortcut, self._survival_prob, training)

    # grid self-attention
    shortcut = output
    output = self.grid_attn_branch(output, training, attn_mask)
    if self._dropout:
      output = tf_keras.layers.Dropout(
          self._dropout, name='after_grid_attn_drop'
      )(output, training=training)
    output = ops.residual_add(output, shortcut, self._survival_prob, training)

    shortcut = output
    output = self.grid_ffn_branch(output, training)
    if self._dropout:
      output = tf_keras.layers.Dropout(
          self._dropout, name='after_grid_ffn_drop'
      )(output, training=training)
    output = ops.residual_add(output, shortcut, self._survival_prob, training)

    return output


class MaxViT(tf_keras.Model):
  """MaxViT's backbone that outputs the pre-global-pooled features."""

  def __init__(
      self,
      block_type: Tuple[str, ...],
      num_blocks: Tuple[int, ...],
      hidden_size: Tuple[int, ...],
      stem_hsize: Tuple[int, ...],
      head_size: int = 32,
      num_heads: Optional[int] = None,
      dropatt: Optional[float] = None,
      dropout: Optional[float] = None,
      rel_attn_type: str = '2d_multi_head',
      window_size: int = 7,
      grid_size: int = 7,
      scale_ratio: Optional[str] = None,
      ln_epsilon: float = 1e-5,
      ln_dtype: Optional[tf.DType] = None,
      downsample_loc: str = 'depth_conv',
      kernel_size: int = 3,
      se_ratio: float = 0.25,
      dropcnn: Optional[float] = None,
      data_format: str = 'channels_last',
      norm_type: str = 'sync_batch_norm',
      bn_epsilon: float = 1e-3,
      bn_momentum: float = 0.99,
      add_pos_enc: bool = False,
      pool_type: str = '2d:avg',
      pool_stride: int = 2,
      expansion_rate: int = 4,
      activation: str = 'gelu',
      survival_prob: Optional[float] = None,
      survival_prob_anneal: bool = True,
      representation_size: Optional[int] = None,
      add_gap_layer_norm: bool = False,
      kernel_initializer: Optional[str] = 'glorot_uniform',
      bias_initializer: Optional[str] = 'zeros',
      name: str = 'maxvit',
      **kwargs,
  ):
    """Initializes MaxViT backbone.

    Args:
      block_type: a tuple of `str`, specifying the block type of each stage.
      num_blocks: a tuple of `int`, specifying the number of blocks in each
        stage.
      hidden_size: a tuple of `int`, specifying the hidden size of the blocks
        in each stage.
      stem_hsize: a tuple of `int`, specifying the hidden sizes of the stem
        network.
      head_size: embedding size of each attention head.
      num_heads: number of attention heads.
      dropatt: an optional float of attention dropout rate.
      dropout: an optional float of dropout rate for dropout regularization.
      rel_attn_type: a `str` specifying the type of relative attention head,
        possible values are ['2d_multi_head', '2d_single_head'].
      window_size: window size for conducting block attention module.
      grid_size: grid size for conducting sparse global grid attention.
      scale_ratio: an optional string for finetuning at a different window
        size, e.g. '14/7'.
      ln_epsilon: layer normalization epsilon.
      ln_dtype: layer normalization data type.
      downsample_loc: location at which to downsample the feature maps.
      kernel_size: stem convolution kernel size.
      se_ratio: se ratio for `mbconv` block.
      dropcnn: an optional float of CNN dropout rate.
      data_format: image data format, usually 'channels_last'.
      norm_type: normalization type, one of ['batch_norm', 'sync_batch_norm',
        'layer_norm'].
      bn_epsilon: batch normalization epsilon.
      bn_momentum: batch normalization momentum.
      add_pos_enc: whether to add position embedding.
      pool_type: pooling operation type, one of ['2d:avg', '2d:max', '1d:avg',
        '1d:max'].
      pool_stride: pooling stride size.
      expansion_rate: expansion rate value.
      activation: activation function.
      survival_prob: survival probability.
      survival_prob_anneal: whether to anneal the survival probability.
      representation_size: an optional `int` of representation size.
      add_gap_layer_norm: whether to add a layer norm to the GAP of the
        backbone's final output.
      kernel_initializer: kernel initializer.
      bias_initializer: bias initializer.
      name: specify module name.
      **kwargs: extra keyword arguments to be passed.
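
    Example (illustrative sketch, not executed; uses the predefined
    'maxvit-tiny' spec from `MAXVIT_SPECS`):

      backbone = MaxViT(**MAXVIT_SPECS['maxvit-tiny'])
      endpoints = backbone(tf.ones([1, 224, 224, 3]), training=False)
      # endpoints['2'] ... endpoints['5'] hold the four stage outputs; with
      # the default pool_stride=2 and the stride-2 stem they sit at strides
      # 4, 8, 16 and 32.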
    """

    super().__init__(name=name)
    self._block_type = block_type
    self._num_blocks = num_blocks
    self._hidden_size = hidden_size
    self._stem_hsize = stem_hsize
    self._head_size = head_size
    self._num_heads = num_heads
    self._dropatt = dropatt
    self._dropout = dropout
    self._rel_attn_type = rel_attn_type
    self._window_size = window_size
    self._grid_size = grid_size
    self._scale_ratio = scale_ratio
    self._ln_epsilon = ln_epsilon
    self._ln_dtype = ln_dtype
    self._downsample_loc = downsample_loc
    self._kernel_size = kernel_size
    self._se_ratio = se_ratio
    self._dropcnn = dropcnn
    self._data_format = data_format
    self._norm_type = norm_type
    self._bn_epsilon = bn_epsilon
    self._bn_momentum = bn_momentum
    self._add_pos_enc = add_pos_enc
    self._pool_type = pool_type
    self._pool_stride = pool_stride
    self._expansion_rate = expansion_rate
    self._activation = activation
    self._survival_prob = survival_prob
    self._survival_prob_anneal = survival_prob_anneal
    self._representation_size = representation_size
    self._add_gap_layer_norm = add_gap_layer_norm
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._output_specs = {}

  def build(self, input_shape: tf.TensorShape) -> None:
    if self._norm_type == 'layer_norm':
      bn_class = functools.partial(
          tf_keras.layers.LayerNormalization, epsilon=self._ln_epsilon
      )
    elif self._norm_type == 'batch_norm':
      bn_class = functools.partial(
          tf_keras.layers.BatchNormalization,
          momentum=self._bn_momentum,
          epsilon=self._bn_epsilon,
      )
    elif self._norm_type == 'sync_batch_norm':
      bn_class = functools.partial(
          tf_keras.layers.BatchNormalization,
          momentum=self._bn_momentum,
          epsilon=self._bn_epsilon,
          synchronized=True,
      )
    else:
      raise ValueError(f'Unsupported norm_type {self._norm_type}.')

    _, self.height, self.width, _ = input_shape.as_list()
    logging.info(
        f'Build backbone with input size: ({self.height}, {self.width}).'
    )

    # Stem
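    # The first stem conv uses stride 2, so the stem halves the input
    # resolution before the stages below.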
    stem_layers = []
    for i, _ in enumerate(self._stem_hsize):
      conv_layer = tf_keras.layers.Conv2D(
          filters=self._stem_hsize[i],
          kernel_size=self._kernel_size,
          strides=2 if i == 0 else 1,
          padding='same',
          data_format=self._data_format,
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          use_bias=True,
          name='conv_{}'.format(i),
      )
      stem_layers.append(conv_layer)
      if i < len(self._stem_hsize) - 1:
        stem_layers.append(bn_class(name='norm_{}'.format(i)))
        stem_layers.append(
            tf_keras.layers.Activation(
                ops.get_act_fn(self._activation), name=f'act_{i}'
            )
        )
    self._stem = tf_keras.Sequential(layers=stem_layers, name='stem')

    # Backbone
    self._blocks = []
    total_num_blocks = sum(self._num_blocks)
    bid = 0
    for i, _ in enumerate(self._block_type):
      self._blocks.append([])
      for j in range(self._num_blocks[i]):
        # block name
        block_name = f'block_{i:0>2d}_{j:0>2d}'

        ##### Update per-block config
        # No pooling if not the first block in the stage
        if j == 0:
          pool_stride = self._pool_stride
        else:
          pool_stride = 1

        # anneal the survival prob
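        # E.g., with survival_prob=0.8 and 10 blocks in total, block 0 keeps
        # survival_prob=1.0 and the deepest block gets
        # 1.0 - 0.2 * 9 / 10 = 0.82.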
        survival_prob = self._survival_prob
        if survival_prob and self._survival_prob_anneal:
          drop_rate = 1.0 - survival_prob
          survival_prob = 1.0 - drop_rate * bid / total_num_blocks
          logging.info(
              '[%02d/%02d] %s survival_prob: %.4f',
              bid,
              total_num_blocks,
              block_name,
              survival_prob,
          )

        ##### Init block
        if self._block_type[i] == 'tfm':
          block = layers.TransformerBlock(
              hidden_size=self._hidden_size[i],
              head_size=self._head_size,
              input_origin_height=self.height,
              input_origin_width=self.width,
              num_heads=self._num_heads,
              expansion_rate=self._expansion_rate,
              activation=self._activation,
              pool_type=self._pool_type,
              pool_stride=pool_stride,
              dropatt=self._dropatt,
              dropout=self._dropout,
              rel_attn_type=self._rel_attn_type,
              scale_ratio=self._scale_ratio,
              survival_prob=survival_prob,
              ln_epsilon=self._ln_epsilon,
              ln_dtype=self._ln_dtype,
              kernel_initializer=self._kernel_initializer,
              bias_initializer=self._bias_initializer,
              name=block_name,
          )
        elif self._block_type[i] == 'mbconv':
          assert self._pool_type in ['2d:max', '2d:avg'], (
              'Invalid pool_type %s for MBConv block' % self._pool_type
          )
          pool_type = self._pool_type.split(':')[-1]
          block = layers.MBConvBlock(
              hidden_size=self._hidden_size[i],
              downsample_loc=self._downsample_loc,
              data_format=self._data_format,
              kernel_size=self._kernel_size,
              expansion_rate=self._expansion_rate,
              se_ratio=self._se_ratio,
              activation=self._activation,
              pool_type=pool_type,
              pool_stride=pool_stride,
              dropcnn=self._dropcnn,
              survival_prob=survival_prob,
              norm_type=self._norm_type,
              bn_epsilon=self._bn_epsilon,
              bn_momentum=self._bn_momentum,
              kernel_initializer=self._kernel_initializer,
              bias_initializer=self._bias_initializer,
              name=block_name,
          )
        elif self._block_type[i] == 'maxvit':
          block = MaxViTBlock(
              hidden_size=self._hidden_size[i],
              head_size=self._head_size,
              window_size=self._window_size,
              grid_size=self._grid_size,
              num_heads=self._num_heads,
              downsample_loc=self._downsample_loc,
              data_format=self._data_format,
              kernel_size=self._kernel_size,
              expansion_rate=self._expansion_rate,
              se_ratio=self._se_ratio,
              activation=self._activation,
              pool_type=self._pool_type,
              pool_stride=pool_stride,
              dropcnn=self._dropcnn,
              dropatt=self._dropatt,
              dropout=self._dropout,
              rel_attn_type=self._rel_attn_type,
              scale_ratio=self._scale_ratio,
              survival_prob=survival_prob,
              ln_epsilon=self._ln_epsilon,
              ln_dtype=self._ln_dtype,
              norm_type=self._norm_type,
              bn_epsilon=self._bn_epsilon,
              bn_momentum=self._bn_momentum,
              kernel_initializer=self._kernel_initializer,
              bias_initializer=self._bias_initializer,
              name=block_name,
          )
        else:
          raise ValueError(f'Unsupported block_type {self._block_type[i]}')
        self._blocks[-1].append(block)
        bid += 1

    if self._representation_size and self._representation_size > 0:
      self._dense = tf_keras.layers.Dense(
          self._representation_size, name='pre_logits')
      if self._add_gap_layer_norm:
        self._final_layer_norm = tf_keras.layers.LayerNormalization(
            epsilon=self._ln_epsilon, name='final_layer_norm')

  def _add_absolute_position_encoding(self, inputs: tf.Tensor) -> tf.Tensor:
    """Add absolute sinusoid position encoding, which is computed on the fly."""
    output = ops.maybe_reshape_to_2d(inputs)
    h, w = tf.shape(output)[1], tf.shape(output)[2]
    enc_size = output.shape.as_list()[-1] // 2
    # sinusoid positional encoding that can be generated online
    h_seq = tf.range(-h / 2, h / 2)
    w_seq = tf.range(-w / 2, w / 2)
    pos_enc_h = ops.absolute_position_encoding(
        h_seq, enc_size, dtype=output.dtype
    )
    pos_enc_w = ops.absolute_position_encoding(
        w_seq, enc_size, dtype=output.dtype
    )
    abs_pos_enc = tf.concat(
        [
            tf.tile(pos_enc_h[:, None, :], [1, w, 1]),
            tf.tile(pos_enc_w[None, :, :], [h, 1, 1]),
        ],
        axis=-1,
    )
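    # abs_pos_enc has shape [h, w, 2 * enc_size], i.e. [h, w, C] for an even
    # channel size, and broadcasts over the batch dimension in the add below.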
    output += abs_pos_enc
    if inputs.shape.rank == 3:
      output = ops.maybe_reshape_to_1d(output)
    return output

  def call(
      self,
      inputs: tf.Tensor,
      mask: Optional[Any] = None,
      training: Optional[bool] = None,
  ) -> Mapping[str, tf.Tensor]:
    logging.info(
        'MaxViT inputs: shape %s, dtype %s.', inputs.shape, inputs.dtype
    )
    output = self._stem(inputs, training=training)
    logging.info(
        'Stage 0 (stem) output: shape %s, dtype %s.', output.shape, output.dtype
    )

    endpoints = {}
    add_pos_enc = self._add_pos_enc
    for idx, stage_blocks in enumerate(self._blocks):
      # Add position encoding
      # Note: the position encoding is usually added to the input of the first
      # transformer block. For MaxViT, it is the first block of stage 3.
      if (isinstance(add_pos_enc, (tuple, list)) and add_pos_enc[idx]) or (
          isinstance(add_pos_enc, bool) and add_pos_enc
      ):
        logging.info('Add position encoding at stage %d.', idx + 1)
        output = self._add_absolute_position_encoding(output)

      # Blocks forward
      for block in stage_blocks:
        output = block(output, training=training)

      if self._block_type[idx] == 'tfm':
        height, width = ops.get_shape_from_length(
            output.shape[1], self.height, self.width
        )
        output = tf.reshape(output, [-1, height, width, output.shape[-1]])

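      # Endpoint keys '2'..'5' follow the usual feature-level convention:
      # with the default pool_stride=2 and the stride-2 stem, stage idx
      # produces features at stride 2 ** (idx + 2).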
      endpoints[str(idx + 2)] = output
      logging.info(
          'Stage %d output: feature level %s shape %s, dtype %s.',
          idx + 1,
          idx + 2,
          output.shape,
          output.dtype,
      )

    self._output_specs = {
        idx: endpoint.get_shape() for idx, endpoint in endpoints.items()
    }

    if self._representation_size and self._representation_size > 0:
      # Backbone's output is [batch_size, height, width, channel_size].
      output = tf_keras.layers.GlobalAveragePooling2D()(output)
      # Maybe add a layer_norm after global average pooling.
      if self._add_gap_layer_norm:
        output = self._final_layer_norm(output)
      endpoints['pre_logits'] = tf.nn.tanh(self._dense(output))

    return endpoints

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


def override_predefined_spec_and_build_maxvit(
    predefined_maxvit_spec, backbone_cfg, norm_activation_config
):
  """Builds a MaxViT backbone.

  Args:
    predefined_maxvit_spec: a dict of predefined MaxViT specifications.
    backbone_cfg: the MaxViT backbone config.
    norm_activation_config: normalization and activation config.

  Returns:
    The built MaxViT backbone.
  """
  survival_prob = (
      predefined_maxvit_spec['survival_prob']
      if backbone_cfg.survival_prob is None
      else backbone_cfg.survival_prob
  )
  stem_hsize = (
      predefined_maxvit_spec['stem_hsize']
      if backbone_cfg.stem_hsize is None
      else backbone_cfg.stem_hsize
  )
  block_type = (
      predefined_maxvit_spec['block_type']
      if backbone_cfg.block_type is None
      else backbone_cfg.block_type
  )
  num_blocks = (
      predefined_maxvit_spec['num_blocks']
      if backbone_cfg.num_blocks is None
      else backbone_cfg.num_blocks
  )
  hidden_size = (
      predefined_maxvit_spec['hidden_size']
      if backbone_cfg.hidden_size is None
      else backbone_cfg.hidden_size
  )

  logging.info(
      (
          'Final MaxViT specs: survival_prob=%s, stem_hsize=%s, hidden_size=%s,'
          ' block_type=%s, num_blocks=%s.'
      ),
      survival_prob,
      stem_hsize,
      hidden_size,
      block_type,
      num_blocks,
  )

  return MaxViT(
      block_type=block_type,
      num_blocks=num_blocks,
      hidden_size=hidden_size,
      stem_hsize=stem_hsize,
      head_size=backbone_cfg.head_size,
      dropatt=backbone_cfg.dropatt,
      dropout=backbone_cfg.dropout,
      rel_attn_type=backbone_cfg.rel_attn_type,
      window_size=backbone_cfg.window_size,
      grid_size=backbone_cfg.grid_size,
      scale_ratio=backbone_cfg.scale_ratio,
      ln_epsilon=backbone_cfg.ln_epsilon,
      ln_dtype=backbone_cfg.ln_dtype,
      downsample_loc=backbone_cfg.downsample_loc,
      kernel_size=backbone_cfg.kernel_size,
      se_ratio=backbone_cfg.se_ratio,
      dropcnn=backbone_cfg.dropcnn,
      data_format=backbone_cfg.data_format,
      norm_type=backbone_cfg.norm_type,
      bn_epsilon=norm_activation_config.norm_epsilon,
      bn_momentum=norm_activation_config.norm_momentum,
      add_pos_enc=backbone_cfg.add_pos_enc,
      pool_type=backbone_cfg.pool_type,
      pool_stride=backbone_cfg.pool_stride,
      expansion_rate=backbone_cfg.expansion_rate,
      activation=norm_activation_config.activation,
      survival_prob=survival_prob,
      survival_prob_anneal=backbone_cfg.survival_prob_anneal,
      representation_size=backbone_cfg.representation_size,
      add_gap_layer_norm=backbone_cfg.add_gap_layer_norm,
      kernel_initializer=backbone_cfg.kernel_initializer,
      bias_initializer=backbone_cfg.bias_initializer,
  )


@factory.register_backbone_builder('maxvit')
def build_maxvit(
    input_specs,
    backbone_config,
    norm_activation_config,
    l2_regularizer=None,
):
  """Builds a MaxViT backbone."""
  del l2_regularizer
  backbone_cfg = backbone_config.get()
  maxvit = override_predefined_spec_and_build_maxvit(
      predefined_maxvit_spec=MAXVIT_SPECS[backbone_cfg.model_name],
      backbone_cfg=backbone_cfg,
      norm_activation_config=norm_activation_config,
  )
  # Build the backbone to get a proper `output_specs`.
  dummy_inputs = tf_keras.Input(input_specs.shape[1:])
  _ = maxvit(dummy_inputs, training=False)
  return maxvit
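

# A minimal usage sketch (not executed at import time), assuming a MaxViT
# backbone config from this project's configs with `model_name` set to one of
# the MAXVIT_SPECS keys:
#
#   backbone = build_maxvit(
#       input_specs=tf_keras.layers.InputSpec(shape=[None, 224, 224, 3]),
#       backbone_config=backbone_config,
#       norm_activation_config=norm_activation_config)
#   print(backbone.output_specs)  # {'2': ..., '3': ..., '4': ..., '5': ...}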