# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Hourglass[1] network.

[1]: https://arxiv.org/abs/1603.06937
"""


import tensorflow as tf
import tf_keras

BATCH_NORM_EPSILON = 1e-5
BATCH_NORM_MOMENTUM = 0.1
BATCH_NORM_FUSED = True


class IdentityLayer(tf_keras.layers.Layer):
  """A layer which passes through the input as it is."""

  def call(self, inputs):
    return inputs


def _get_padding_for_kernel_size(kernel_size):
  if kernel_size == 7:
    return (3, 3)
  elif kernel_size == 3:
    return (1, 1)
  else:
    raise ValueError('Padding for kernel size {} not known.'.format(
        kernel_size))


def batchnorm():
  try:
    return tf_keras.layers.experimental.SyncBatchNormalization(
        name='batchnorm', epsilon=BATCH_NORM_EPSILON,
        momentum=BATCH_NORM_MOMENTUM)
  except AttributeError:
    return tf_keras.layers.BatchNormalization(
        name='batchnorm', epsilon=BATCH_NORM_EPSILON,
        momentum=BATCH_NORM_MOMENTUM, fused=BATCH_NORM_FUSED)


class ConvolutionalBlock(tf_keras.layers.Layer):
  """Block that aggregates Convolution + Norm layer + ReLU."""

  def __init__(self, kernel_size, out_channels, stride=1, relu=True,
               padding='same'):
    """Initializes the Convolutional block.

    Args:
      kernel_size: int, convolution kernel size.
      out_channels: int, the desired number of output channels.
      stride: Integer, stride used in the convolution.
      relu: bool, whether to use relu at the end of the layer.
      padding: str, the padding scheme to use. Only honored when
        kernel_size <= 1; larger kernels use explicit zero padding followed
        by a 'valid' convolution.
    """
    super(ConvolutionalBlock, self).__init__()

    if kernel_size > 1:
      padding = 'valid'
      padding_size = _get_padding_for_kernel_size(kernel_size)

      # TODO(vighneshb) Explore if removing and using padding option in conv
      # layer works.
      self.pad = tf_keras.layers.ZeroPadding2D(padding_size)
    else:
      self.pad = IdentityLayer()

    self.conv = tf_keras.layers.Conv2D(
        filters=out_channels, kernel_size=kernel_size, use_bias=False,
        strides=stride, padding=padding)

    self.norm = batchnorm()

    if relu:
      self.relu = tf_keras.layers.ReLU()
    else:
      self.relu = IdentityLayer()

  def call(self, inputs):
    net = self.pad(inputs)
    net = self.conv(net)
    net = self.norm(net)
    return self.relu(net)
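

# Illustrative usage of ConvolutionalBlock (a sketch; the input shape below is
# assumed). With kernel_size=3 and stride=1 the explicit zero padding keeps the
# spatial size unchanged:
#
#   block = ConvolutionalBlock(kernel_size=3, out_channels=64)
#   out = block(tf.zeros([1, 32, 32, 3]))  # shape: [1, 32, 32, 64]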


class SkipConvolution(ConvolutionalBlock):
  """The skip connection layer for a ResNet."""

  def __init__(self, out_channels, stride):
    """Initializes the skip convolution layer.

    Args:
      out_channels: int, the desired number of output channels.
      stride: int, the stride for the layer.
    """
    super(SkipConvolution, self).__init__(
        out_channels=out_channels, kernel_size=1, stride=stride, relu=False)


class ResidualBlock(tf_keras.layers.Layer):
  """A Residual block."""

  def __init__(self, out_channels, skip_conv=False, kernel_size=3, stride=1,
               padding='same'):
    """Initializes the Residual block.

    Args:
      out_channels: int, the desired number of output channels.
      skip_conv: bool, whether to use a conv layer for skip connections.
      kernel_size: int, convolution kernel size.
      stride: Integer, stride used in the convolution.
      padding: str, the type of padding to use.
    """

    super(ResidualBlock, self).__init__()
    self.conv_block = ConvolutionalBlock(
        kernel_size=kernel_size, out_channels=out_channels, stride=stride)

    self.conv = tf_keras.layers.Conv2D(
        filters=out_channels, kernel_size=kernel_size, use_bias=False,
        strides=1, padding=padding)
    self.norm = batchnorm()

    if skip_conv:
      self.skip = SkipConvolution(out_channels=out_channels,
                                  stride=stride)
    else:
      self.skip = IdentityLayer()

    self.relu = tf_keras.layers.ReLU()

  def call(self, inputs):
    net = self.conv_block(inputs)
    net = self.conv(net)
    net = self.norm(net)
    net_skip = self.skip(inputs)
    return self.relu(net + net_skip)
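

# Illustrative usage of ResidualBlock (a sketch; shapes assumed). skip_conv=True
# is required whenever the stride or channel count changes, since an identity
# skip could not otherwise be added to the transformed features:
#
#   block = ResidualBlock(out_channels=64, skip_conv=True, stride=2)
#   out = block(tf.zeros([1, 32, 32, 32]))  # shape: [1, 16, 16, 64]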


class InputDownsampleBlock(tf_keras.layers.Layer):
  """Block for the initial feature downsampling."""

  def __init__(self, out_channels_initial_conv, out_channels_residual_block):
    """Initializes the downsample block.

    Args:
      out_channels_initial_conv: int, the desired number of output channels
        in the initial conv layer.
      out_channels_residual_block: int, the desired number of output channels
        in the underlying residual block.
    """

    super(InputDownsampleBlock, self).__init__()
    self.conv_block = ConvolutionalBlock(
        kernel_size=7, out_channels=out_channels_initial_conv, stride=2,
        padding='valid')
    self.residual_block = ResidualBlock(
        out_channels=out_channels_residual_block, stride=2, skip_conv=True)

  def call(self, inputs):
    return self.residual_block(self.conv_block(inputs))
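

# The stride-2 7x7 conv block followed by the stride-2 residual block reduces
# the spatial resolution by a factor of 4. A sketch with an assumed input size:
#
#   block = InputDownsampleBlock(out_channels_initial_conv=128,
#                                out_channels_residual_block=256)
#   out = block(tf.zeros([1, 512, 512, 3]))  # shape: [1, 128, 128, 256]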


class InputConvBlock(tf_keras.layers.Layer):
  """Block for the initial feature convolution.

  This block is used in the hourglass network when we don't want to downsample
  the input.
  """

  def __init__(self, out_channels_initial_conv, out_channels_residual_block):
    """Initializes the downsample block.

    Args:
      out_channels_initial_conv: int, the desired number of output channels
        in the initial conv layer.
      out_channels_residual_block: int, the desired number of output channels
        in the underlying residual block.
    """

    super(InputConvBlock, self).__init__()

    self.conv_block = ConvolutionalBlock(
        kernel_size=3, out_channels=out_channels_initial_conv, stride=1,
        padding='valid')
    self.residual_block = ResidualBlock(
        out_channels=out_channels_residual_block, stride=1, skip_conv=True)

  def call(self, inputs):
    return self.residual_block(self.conv_block(inputs))


def _make_repeated_residual_blocks(out_channels, num_blocks,
                                   initial_stride=1, residual_channels=None,
                                   initial_skip_conv=False):
  """Stack Residual blocks one after the other.

  Args:
    out_channels: int, the desired number of output channels.
    num_blocks: int, the number of residual blocks to be stacked.
    initial_stride: int, the stride of the initial residual block.
    residual_channels: int, the desired number of output channels in the
      intermediate residual blocks. If not specified, we use out_channels.
    initial_skip_conv: bool, if set, the first residual block uses a skip
      convolution. This is useful when the number of channels in the input
      are not the same as residual_channels.

  Returns:
    blocks: A list of residual blocks to be applied in sequence.

  """

  blocks = []

  if residual_channels is None:
    residual_channels = out_channels

  for i in range(num_blocks - 1):
    # Only use the stride at the first block so we don't repeatedly downsample
    # the input
    stride = initial_stride if i == 0 else 1

    # If the stride is more than 1, we cannot use an identity layer for the
    # skip connection and are forced to use a conv for the skip connection.
    skip_conv = stride > 1

    if i == 0 and initial_skip_conv:
      skip_conv = True

    blocks.append(
        ResidualBlock(out_channels=residual_channels, stride=stride,
                      skip_conv=skip_conv)
    )

  if num_blocks == 1:
    # If there is only 1 block, the for loop above is not run,
    # therefore we honor the requested stride in the last residual block
    stride = initial_stride
    # We are forced to use a conv in the skip connection if stride > 1
    skip_conv = stride > 1
  else:
    stride = 1
    skip_conv = residual_channels != out_channels

  blocks.append(ResidualBlock(out_channels=out_channels, skip_conv=skip_conv,
                              stride=stride))

  return blocks


def _apply_blocks(inputs, blocks):
  net = inputs

  for block in blocks:
    net = block(net)

  return net
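

# A sketch of how the two helpers above are typically combined (values
# assumed):
#
#   blocks = _make_repeated_residual_blocks(
#       out_channels=128, num_blocks=2, initial_stride=2)
#   out = _apply_blocks(tf.zeros([1, 64, 64, 128]), blocks)
#   # Only the first block applies the stride: out has shape [1, 32, 32, 128].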


class EncoderDecoderBlock(tf_keras.layers.Layer):
  """An encoder-decoder block which recursively defines the hourglass network."""

  def __init__(self, num_stages, channel_dims, blocks_per_stage,
               stagewise_downsample=True, encoder_decoder_shortcut=True):
    """Initializes the encoder-decoder block.

    Args:
      num_stages: int, number of stages in the network. Each stage has two
        encoder blocks and one decoder block; the second encoder block
        downsamples the input.
      channel_dims: int list, the output channel dimensions of the stages in
        the network. `channel_dims[0]` is used to define the number of
        channels in the first encoder block and `channel_dims[1]` is used to
        define the number of channels in the second encoder block. The channels
        in the recursive inner layers are defined using `channel_dims[1:]`.
      blocks_per_stage: int list, number of residual blocks to use at each
        stage. `blocks_per_stage[0]` defines the number of blocks at the
        current stage and `blocks_per_stage[1:]` is used at further stages.
      stagewise_downsample: bool, whether or not to downsample before passing
        inputs to the next stage.
      encoder_decoder_shortcut: bool, whether or not to use shortcut
        connections between encoder and decoder.
    """

    super(EncoderDecoderBlock, self).__init__()

    out_channels = channel_dims[0]
    out_channels_downsampled = channel_dims[1]

    self.encoder_decoder_shortcut = encoder_decoder_shortcut

    if encoder_decoder_shortcut:
      self.merge_features = tf_keras.layers.Add()
      self.encoder_block1 = _make_repeated_residual_blocks(
          out_channels=out_channels, num_blocks=blocks_per_stage[0],
          initial_stride=1)

    initial_stride = 2 if stagewise_downsample else 1
    self.encoder_block2 = _make_repeated_residual_blocks(
        out_channels=out_channels_downsampled,
        num_blocks=blocks_per_stage[0], initial_stride=initial_stride,
        initial_skip_conv=out_channels != out_channels_downsampled)

    if num_stages > 1:
      self.inner_block = [
          EncoderDecoderBlock(num_stages - 1, channel_dims[1:],
                              blocks_per_stage[1:],
                              stagewise_downsample=stagewise_downsample,
                              encoder_decoder_shortcut=encoder_decoder_shortcut)
      ]
    else:
      self.inner_block = _make_repeated_residual_blocks(
          out_channels=out_channels_downsampled,
          num_blocks=blocks_per_stage[1])

    self.decoder_block = _make_repeated_residual_blocks(
        residual_channels=out_channels_downsampled,
        out_channels=out_channels, num_blocks=blocks_per_stage[0])

    self.upsample = tf_keras.layers.UpSampling2D(initial_stride)

  def call(self, inputs):

    if self.encoder_decoder_shortcut:
      encoded_outputs = _apply_blocks(inputs, self.encoder_block1)
    encoded_downsampled_outputs = _apply_blocks(inputs, self.encoder_block2)
    inner_block_outputs = _apply_blocks(
        encoded_downsampled_outputs, self.inner_block)

    decoded_outputs = _apply_blocks(inner_block_outputs, self.decoder_block)
    upsampled_outputs = self.upsample(decoded_outputs)

    if self.encoder_decoder_shortcut:
      return self.merge_features([encoded_outputs, upsampled_outputs])
    else:
      return upsampled_outputs
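

# A sketch of a two-stage encoder-decoder (values assumed, mirroring the
# hourglass_20 configuration below). The input must already have
# channel_dims[0] channels, and its spatial size must be divisible by
# 2**num_stages when stagewise_downsample is True:
#
#   enc_dec = EncoderDecoderBlock(
#       num_stages=2, channel_dims=[64, 64, 96], blocks_per_stage=[1, 2, 2])
#   out = enc_dec(tf.zeros([1, 64, 64, 64]))  # shape: [1, 64, 64, 64]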


class HourglassNetwork(tf_keras.Model):
  """The hourglass network."""

  def __init__(self, num_stages, input_channel_dims, channel_dims_per_stage,
               blocks_per_stage, num_hourglasses, initial_downsample=True,
               stagewise_downsample=True, encoder_decoder_shortcut=True):
    """Intializes the feature extractor.

    Args:
      num_stages: int, number of stages in the network. Each stage has two
        encoder blocks and one decoder block; the second encoder block
        downsamples the input.
      input_channel_dims: int, the number of channels in the input conv blocks.
      channel_dims_per_stage: int list, the output channel dimensions of each
        stage in the hourglass network.
      blocks_per_stage: int list, number of residual blocks to use at each
        stage in the hourglass network.
      num_hourglasses: int, number of hourglass networks to stack
        sequentially.
      initial_downsample: bool, if set, downsamples the input by a factor of 4
        before applying the rest of the network, using a 7x7 convolution
        kernel; otherwise a stride-1 3x3 kernel is used.
      stagewise_downsample: bool, whether or not to downsample before passing
        inputs to the next stage.
      encoder_decoder_shortcut: bool, whether or not to use shortcut
        connections between encoder and decoder.
    """

    super(HourglassNetwork, self).__init__()

    self.num_hourglasses = num_hourglasses
    self.initial_downsample = initial_downsample
    if initial_downsample:
      self.downsample_input = InputDownsampleBlock(
          out_channels_initial_conv=input_channel_dims,
          out_channels_residual_block=channel_dims_per_stage[0]
      )
    else:
      self.conv_input = InputConvBlock(
          out_channels_initial_conv=input_channel_dims,
          out_channels_residual_block=channel_dims_per_stage[0]
      )

    self.hourglass_network = []
    self.output_conv = []
    for _ in range(self.num_hourglasses):
      self.hourglass_network.append(
          EncoderDecoderBlock(
              num_stages=num_stages, channel_dims=channel_dims_per_stage,
              blocks_per_stage=blocks_per_stage,
              stagewise_downsample=stagewise_downsample,
              encoder_decoder_shortcut=encoder_decoder_shortcut)
      )
      self.output_conv.append(
          ConvolutionalBlock(kernel_size=3,
                             out_channels=channel_dims_per_stage[0])
      )

    self.intermediate_conv1 = []
    self.intermediate_conv2 = []
    self.intermediate_residual = []

    for _ in range(self.num_hourglasses - 1):
      self.intermediate_conv1.append(
          ConvolutionalBlock(
              kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
      )
      self.intermediate_conv2.append(
          ConvolutionalBlock(
              kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
      )
      self.intermediate_residual.append(
          ResidualBlock(out_channels=channel_dims_per_stage[0])
      )

    self.intermediate_relu = tf_keras.layers.ReLU()

  def call(self, inputs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks

    if self.initial_downsample:
      inputs = self.downsample_input(inputs)
    else:
      inputs = self.conv_input(inputs)

    outputs = []

    for i in range(self.num_hourglasses):

      hourglass_output = self.hourglass_network[i](inputs)

      output = self.output_conv[i](hourglass_output)
      outputs.append(output)

      if i < self.num_hourglasses - 1:
        secondary_output = (self.intermediate_conv1[i](inputs) +
                            self.intermediate_conv2[i](output))
        secondary_output = self.intermediate_relu(secondary_output)
        inputs = self.intermediate_residual[i](secondary_output)

    return outputs

  @property
  def out_stride(self):
    """The stride in the output image of the network."""
    return 4

  @property
  def num_feature_outputs(self):
    """Ther number of feature outputs returned by the feature extractor."""
    return self.num_hourglasses
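

# A sketch of building and calling the network directly (parameter values
# assumed; see hourglass_104() below for the published configuration):
#
#   net = HourglassNetwork(
#       num_stages=2, input_channel_dims=32,
#       channel_dims_per_stage=[64, 64, 96],
#       blocks_per_stage=[1, 2, 2], num_hourglasses=1)
#   outputs = net(tf.zeros([1, 256, 256, 3]))
#   # `outputs` is a list with num_hourglasses entries, each of shape
#   # [1, 64, 64, 64] (the input resolution divided by out_stride == 4).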


def _layer_depth(layer):
  """Compute depth of Conv/Residual blocks or lists of them."""

  if isinstance(layer, list):
    return sum([_layer_depth(l) for l in layer])

  elif isinstance(layer, ConvolutionalBlock):
    return 1

  elif isinstance(layer, ResidualBlock):
    return 2

  else:
    raise ValueError('Unknown layer - {}'.format(layer))


def _encoder_decoder_depth(network):
  """Helper function to compute depth of encoder-decoder blocks."""

  encoder_block2_layers = _layer_depth(network.encoder_block2)
  decoder_block_layers = _layer_depth(network.decoder_block)

  if isinstance(network.inner_block[0], EncoderDecoderBlock):

    assert len(network.inner_block) == 1, (
        'Inner block is expected to have length 1.')
    inner_block_layers = _encoder_decoder_depth(network.inner_block[0])

    return inner_block_layers + encoder_block2_layers + decoder_block_layers

  elif isinstance(network.inner_block[0], ResidualBlock):
    return (encoder_block2_layers + decoder_block_layers +
            _layer_depth(network.inner_block))

  else:
    raise ValueError('Unknown inner block type.')


def hourglass_depth(network):
  """Helper function to verify depth of hourglass backbone."""

  input_conv_layers = 3  # 1 ConvBlock (1 conv) + 1 ResidualBlock (2 convs)

  # intermediate_conv1 is a parallel shortcut around each hourglass, so only
  # intermediate_conv2 and intermediate_residual lie on the main path between
  # stages and are counted towards the depth.
  intermediate_layers = (
      _layer_depth(network.intermediate_conv2) +
      _layer_depth(network.intermediate_residual)
  )

  # network.output_conv is applied before sending input to the later stages
  output_layers = _layer_depth(network.output_conv)

  encoder_decoder_layers = sum(_encoder_decoder_depth(net) for net in
                               network.hourglass_network)

  return (input_conv_layers + encoder_decoder_layers + intermediate_layers
          + output_layers)
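

# Worked example for the hourglass_104() configuration defined below: at each
# of the 5 encoder-decoder levels, encoder_block2 and decoder_block each hold 2
# residual blocks (4 convs each), and the innermost block holds 4 residual
# blocks (8 convs), giving 5 * 8 + 8 = 48 convs per hourglass. With 2
# hourglasses: 2 * 48 + 3 (input) + 3 (intermediate) + 2 (output) = 104.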


def hourglass_104():
  """The Hourglass-104 backbone.

  The architecture parameters are taken from [1].

  Returns:
    network: An HourglassNetwork object implementing the Hourglass-104
      backbone.

  [1]: https://arxiv.org/abs/1904.07850
  """

  return HourglassNetwork(
      input_channel_dims=128,
      channel_dims_per_stage=[256, 256, 384, 384, 384, 512],
      num_hourglasses=2,
      num_stages=5,
      blocks_per_stage=[2, 2, 2, 2, 2, 4],
  )
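

# Usage sketch (the input resolution is assumed; any spatial size divisible by
# 128 works, since the network downsamples by 4 and then by 2 at each of the 5
# stages):
#
#   backbone = hourglass_104()
#   outputs = backbone(tf.zeros([1, 512, 512, 3]))
#   # len(outputs) == 2, each output has shape [1, 128, 128, 256], and
#   # hourglass_depth(backbone) == 104.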


def single_stage_hourglass(input_channel_dims, channel_dims_per_stage,
                           blocks_per_stage, initial_downsample=True,
                           stagewise_downsample=True,
                           encoder_decoder_shortcut=True):
  assert len(channel_dims_per_stage) == len(blocks_per_stage)

  return HourglassNetwork(
      input_channel_dims=input_channel_dims,
      channel_dims_per_stage=channel_dims_per_stage,
      num_hourglasses=1,
      num_stages=len(channel_dims_per_stage) - 1,
      blocks_per_stage=blocks_per_stage,
      initial_downsample=initial_downsample,
      stagewise_downsample=stagewise_downsample,
      encoder_decoder_shortcut=encoder_decoder_shortcut
  )


def hourglass_10(num_channels, initial_downsample=True):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      initial_downsample=initial_downsample,
      blocks_per_stage=[1, 1],
      channel_dims_per_stage=[nc * 2, nc * 2])
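

# The single-stage builders below all follow the same pattern; for example
# (num_channels and input shape assumed):
#
#   net = hourglass_10(num_channels=32, initial_downsample=False)
#   out = net(tf.zeros([1, 64, 64, 3]))[0]  # shape: [1, 64, 64, 64]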


def hourglass_20(num_channels, initial_downsample=True):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      initial_downsample=initial_downsample,
      blocks_per_stage=[1, 2, 2],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3])


def hourglass_32(num_channels, initial_downsample=True):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      initial_downsample=initial_downsample,
      blocks_per_stage=[2, 2, 2, 2],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3])


def hourglass_52(num_channels, initial_downsample=True):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      initial_downsample=initial_downsample,
      blocks_per_stage=[2, 2, 2, 2, 2, 4],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])


def hourglass_100(num_channels, initial_downsample=True):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      initial_downsample=initial_downsample,
      blocks_per_stage=[4, 4, 4, 4, 4, 8],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])


def hourglass_20_uniform_size(num_channels):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      blocks_per_stage=[1, 2, 2],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
      initial_downsample=False,
      stagewise_downsample=False)
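

# With both the initial and the stagewise downsampling disabled, the output
# keeps the input resolution (a sketch; sizes assumed):
#
#   net = hourglass_20_uniform_size(num_channels=32)
#   out = net(tf.zeros([1, 32, 32, 3]))[0]  # shape: [1, 32, 32, 64]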


def hourglass_20_no_shortcut(num_channels):
  nc = num_channels
  return single_stage_hourglass(
      input_channel_dims=nc,
      blocks_per_stage=[1, 2, 2],
      channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
      initial_downsample=False,
      encoder_decoder_shortcut=False)