# Source: tensorflow/models (View on GitHub)
#   official/projects/centernet/modeling/backbones/hourglass.py
#
# Summary
#   Maintainability: A (2 hrs)
#   Test Coverage: (no value given)
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build Hourglass backbone."""

from typing import Optional

import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.projects.centernet.modeling.layers import cn_nn_blocks
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones import mobilenet
from official.vision.modeling.layers import nn_blocks

# Stage layouts for each supported Hourglass depth, keyed by `model_id`.
# Each entry holds two lists of equal length:
#   'blocks_per_stage': passed unchanged to `HourglassBlock` as its
#     `blocks_per_stage` argument.
#   'channel_dims_per_stage': per-stage channel multipliers; `Hourglass`
#     multiplies each entry by `input_channel_dims` to get the filter counts.
HOURGLASS_SPECS = {
    depth: {
        'blocks_per_stage': block_counts,
        'channel_dims_per_stage': channel_multipliers,
    }
    for depth, block_counts, channel_multipliers in (
        # (model_id, blocks_per_stage, channel_dims_per_stage)
        (10, [1, 1], [2, 2]),
        (20, [1, 2, 2], [2, 2, 3]),
        (32, [2, 2, 2, 2], [2, 2, 3, 3]),
        (52, [2, 2, 2, 2, 2, 4], [2, 2, 3, 3, 3, 4]),
        (100, [4, 4, 4, 4, 4, 8], [2, 2, 3, 3, 3, 4]),
    )
}


class Hourglass(tf_keras.Model):
  """CenterNet Hourglass backbone.

  A stem (a `Conv2DBNBlock` followed by a projection `ResidualBlock`, both
  using stride 2 when `initial_downsample` is True) feeds `num_hourglasses`
  stacked `HourglassBlock`s. The output of each stack, after a 3x3
  `Conv2DBNBlock`, is returned as an endpoint of the functional model.
  """

  def __init__(
      self,
      model_id: int,
      input_channel_dims: int,
      input_specs=tf_keras.layers.InputSpec(shape=[None, None, None, 3]),
      num_hourglasses: int = 1,
      initial_downsample: bool = True,
      activation: str = 'relu',
      use_sync_bn: bool = True,
      norm_momentum=0.1,
      norm_epsilon=1e-5,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initialize Hourglass backbone.

    Args:
      model_id: An `int` of the scale of Hourglass backbone model; must be a
        key of `HOURGLASS_SPECS`.
      input_channel_dims: `int`, number of filters used to downsample the
        input image. The per-stage channel multipliers in `HOURGLASS_SPECS`
        are scaled by this value.
      input_specs: A `tf_keras.layers.InputSpec` of specs of the input tensor.
      num_hourglasses: `int`, number of hourglass blocks in backbone. For
        example, hourglass-104 has two hourglass-52 modules. Must be >= 1.
      initial_downsample: `bool`, whether or not to downsample the input.
        When True the stem uses a 7x7 kernel with stride 2, otherwise a 3x3
        kernel with stride 1.
      activation: A `str` name of the activation function, used by the stem
        and post-hourglass `Conv2DBNBlock`s.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: `float`, momentum for the batch normalization layers.
      norm_epsilon: `float`, epsilon for the batch normalization layers.
      kernel_initializer: A `str` for kernel initializer of conv layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      KeyError: If `model_id` is not a key of `HOURGLASS_SPECS`.
      ValueError: If `num_hourglasses` is smaller than 1.
    """
    # Fail early with readable messages instead of a bare KeyError or a
    # Keras error about a model with no outputs.
    if model_id not in HOURGLASS_SPECS:
      raise KeyError(
          f'Unsupported Hourglass model_id {model_id}; expected one of '
          f'{sorted(HOURGLASS_SPECS)}.')
    if num_hourglasses < 1:
      raise ValueError(
          f'num_hourglasses must be at least 1, got {num_hourglasses}.')

    self._input_channel_dims = input_channel_dims
    self._model_id = model_id
    self._num_hourglasses = num_hourglasses
    self._initial_downsample = initial_downsample
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    specs = HOURGLASS_SPECS[model_id]
    self._blocks_per_stage = specs['blocks_per_stage']
    # Spec entries are multipliers of `input_channel_dims`.
    self._channel_dims_per_stage = [
        multiplier * self._input_channel_dims
        for multiplier in specs['channel_dims_per_stage']]

    inputs = tf_keras.layers.Input(shape=input_specs.shape[1:])

    inp_filters = self._channel_dims_per_stage[0]

    # Stem: optionally downsample the input.
    if initial_downsample:
      prelayer_kernel_size = 7
      prelayer_strides = 2
    else:
      prelayer_kernel_size = 3
      prelayer_strides = 1

    x_downsampled = mobilenet.Conv2DBNBlock(
        filters=self._input_channel_dims,
        kernel_size=prelayer_kernel_size,
        strides=prelayer_strides,
        use_explicit_padding=True,
        activation=self._activation,
        bias_regularizer=self._bias_regularizer,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(inputs)

    # NOTE(review): `activation` is not forwarded to the residual blocks, so
    # they use `ResidualBlock`'s own default activation.
    x_downsampled = nn_blocks.ResidualBlock(
        filters=inp_filters,
        use_projection=True,
        use_explicit_padding=True,
        strides=prelayer_strides,
        bias_regularizer=self._bias_regularizer,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(x_downsampled)

    all_heatmaps = {}
    for i in range(num_hourglasses):
      # Create an hourglass stack. Forward the backbone's initializer,
      # regularizers and BN settings so that, e.g., the L2 regularizer from
      # the builder also applies to the hourglass stacks.
      x_hg = cn_nn_blocks.HourglassBlock(
          channel_dims_per_stage=self._channel_dims_per_stage,
          blocks_per_stage=self._blocks_per_stage,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
      )(x_downsampled)

      x_hg = mobilenet.Conv2DBNBlock(
          filters=inp_filters,
          kernel_size=3,
          strides=1,
          use_explicit_padding=True,
          activation=self._activation,
          bias_regularizer=self._bias_regularizer,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon
      )(x_hg)

      # Given two down-sampling blocks above, the starting level is set to 2.
      # To make it compatible with implementation of remaining backbones, the
      # output of hourglass backbones is organized as
      # '2' -> the last layer of output
      # '2_0' -> the first layer of output
      # ......
      # '2_{num_hourglasses-2}' -> the second to last layer of output
      if i < num_hourglasses - 1:
        all_heatmaps['2_{}'.format(i)] = x_hg
      else:
        all_heatmaps['2'] = x_hg

      # Intermediate conv and residual layers between hourglasses: the next
      # stack's input is ReLU(conv1x1(stack input) + conv1x1(stack output))
      # followed by a residual block.
      if i < num_hourglasses - 1:
        inter_hg_conv1 = mobilenet.Conv2DBNBlock(
            filters=inp_filters,
            kernel_size=1,
            strides=1,
            activation='identity',
            bias_regularizer=self._bias_regularizer,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(x_downsampled)

        inter_hg_conv2 = mobilenet.Conv2DBNBlock(
            filters=inp_filters,
            kernel_size=1,
            strides=1,
            activation='identity',
            bias_regularizer=self._bias_regularizer,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(x_hg)

        x_downsampled = tf_keras.layers.Add()([inter_hg_conv1, inter_hg_conv2])
        # NOTE(review): this merge always uses ReLU, regardless of `activation`.
        x_downsampled = tf_keras.layers.ReLU()(x_downsampled)

        x_downsampled = nn_blocks.ResidualBlock(
            filters=inp_filters,
            use_projection=False,
            use_explicit_padding=True,
            strides=1,
            bias_regularizer=self._bias_regularizer,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(x_downsampled)

    self._output_specs = {
        level: heatmap.get_shape() for level, heatmap in all_heatmaps.items()}

    super().__init__(inputs=inputs, outputs=all_heatmaps, **kwargs)

  def get_config(self):
    """Returns the constructor arguments used to rebuild this backbone.

    Entries from the base-class config are merged in last and take precedence
    on key collisions.
    """
    config = {
        'model_id': self._model_id,
        'input_channel_dims': self._input_channel_dims,
        'num_hourglasses': self._num_hourglasses,
        'initial_downsample': self._initial_downsample,
        # Previously omitted, which silently reset activation to 'relu' on a
        # config round-trip.
        'activation': self._activation,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
    }
    config.update(super().get_config())
    return config

  @property
  def num_hourglasses(self):
    """Number of stacked hourglass blocks."""
    return self._num_hourglasses

  @property
  def output_specs(self):
    """Dict mapping each output level key ('2', '2_0', ...) to its shape."""
    return self._output_specs


@factory.register_backbone_builder('hourglass')
def build_hourglass(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
    ) -> tf_keras.Model:
  """Creates an `Hourglass` backbone from experiment configs.

  Args:
    input_specs: `tf_keras.layers.InputSpec` describing the input tensor.
    backbone_config: Backbone config whose `type` must be 'hourglass'. Its
      `get()` result supplies `model_id`, `input_channel_dims`,
      `num_hourglasses` and `initial_downsample`.
    norm_activation_config: Supplies `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer, used as the conv kernel regularizer.

  Returns:
    The constructed `Hourglass` model.
  """
  backbone_type = backbone_config.type
  hourglass_cfg = backbone_config.get()
  assert backbone_type == 'hourglass', (
      f'Inconsistent backbone type {backbone_type}')

  # Architecture arguments come from the backbone sub-config.
  structure_kwargs = {
      'model_id': hourglass_cfg.model_id,
      'input_channel_dims': hourglass_cfg.input_channel_dims,
      'num_hourglasses': hourglass_cfg.num_hourglasses,
      'input_specs': input_specs,
      'initial_downsample': hourglass_cfg.initial_downsample,
  }
  # Activation and batch-norm settings come from the shared norm config.
  norm_kwargs = {
      'activation': norm_activation_config.activation,
      'use_sync_bn': norm_activation_config.use_sync_bn,
      'norm_momentum': norm_activation_config.norm_momentum,
      'norm_epsilon': norm_activation_config.norm_epsilon,
  }
  return Hourglass(
      **structure_kwargs,
      **norm_kwargs,
      kernel_regularizer=l2_regularizer,
  )