official/projects/basnet/modeling/nn_blocks.py (tensorflow/models)
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for BasNet model."""

import tensorflow as tf
import tf_keras

from official.modeling import tf_utils


@tf_keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf_keras.layers.Layer):
  """A (Conv+BN+Activation) block."""

  def __init__(self,
               filters,
               strides,
               dilation_rate=1,
               kernel_size=3,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_bias=False,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """A vgg block with BN after convolutions.

    Args:
      filters: `int` number of filters for the first two convolutions. Note that
        the third and final convolution will use 4 times as many filters.
      strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
      dilation_rate: `int`, dilation rate for conv layers.
      kernel_size: `int`, kernel size of conv layers.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
                          Default to None.
      bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
                        Default to None.
      activation: `str` name of the activation function.
      use_bias: `bool`, whether or not use bias in conv layers.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization omentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super(ConvBlock, self).__init__(**kwargs)
    self._config_dict = {
        'filters': filters,
        'kernel_size': kernel_size,
        'strides': strides,
        'dilation_rate': dilation_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
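    # SyncBatchNormalization aggregates batch statistics across replicas when
    # training under a distribution strategy, while plain BatchNormalization
    # keeps per-replica statistics.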
    if use_sync_bn:
      self._norm = tf_keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    conv_kwargs = {
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    self._conv0 = tf_keras.layers.Conv2D(
        filters=self._config_dict['filters'],
        kernel_size=self._config_dict['kernel_size'],
        strides=self._config_dict['strides'],
        dilation_rate=self._config_dict['dilation_rate'],
        **conv_kwargs)
    self._norm0 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ConvBlock, self).build(input_shape)

  def get_config(self):
    return self._config_dict

  def call(self, inputs, training=None):
    x = self._conv0(inputs)
    x = self._norm0(x)
    x = self._activation_fn(x)

    return x
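
  # Example usage of ConvBlock (an illustrative sketch; the input shape and
  # filter count below are assumptions, not values taken from this file):
  #
  #   block = ConvBlock(filters=64, strides=2)
  #   y = block(tf.ones([1, 224, 224, 3]))  # -> shape [1, 112, 112, 64]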


@tf_keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf_keras.layers.Layer):
  """A residual block."""

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               use_bias=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a residual block with BN after convolutions.

    Args:
      filters: An `int` number of filters for both convolutions in this block,
        as well as for the projection shortcut if one is used.
      strides: An `int` block stride. If greater than 1, this block will
        ultimately downsample the input.
      use_projection: A `bool` for whether this block should use a projection
        shortcut (versus the default identity shortcut). This is usually `True`
        for the first block of a block group, which may change the number of
        filters and the resolution.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      use_bias: A `bool`. If True, use bias in conv2d.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(ResBlock, self).__init__(**kwargs)
    self._config_dict = {
        'filters': filters,
        'strides': strides,
        'use_projection': use_projection,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'use_bias': use_bias,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon
    }
    if use_sync_bn:
      self._norm = tf_keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    conv_kwargs = {
        'filters': self._config_dict['filters'],
        'padding': 'same',
        'use_bias': self._config_dict['use_bias'],
        'kernel_initializer': self._config_dict['kernel_initializer'],
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }

    if self._config_dict['use_projection']:
      self._shortcut = tf_keras.layers.Conv2D(
          filters=self._config_dict['filters'],
          kernel_size=1,
          strides=self._config_dict['strides'],
          use_bias=self._config_dict['use_bias'],
          kernel_initializer=self._config_dict['kernel_initializer'],
          kernel_regularizer=self._config_dict['kernel_regularizer'],
          bias_regularizer=self._config_dict['bias_regularizer'])
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._config_dict['norm_momentum'],
          epsilon=self._config_dict['norm_epsilon'])

    self._conv1 = tf_keras.layers.Conv2D(
        kernel_size=3,
        strides=self._config_dict['strides'],
        **conv_kwargs)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    self._conv2 = tf_keras.layers.Conv2D(
        kernel_size=3,
        strides=1,
        **conv_kwargs)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._config_dict['norm_momentum'],
        epsilon=self._config_dict['norm_epsilon'])

    super(ResBlock, self).build(input_shape)

  def get_config(self):
    return self._config_dict

  def call(self, inputs, training=None):
    shortcut = inputs
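    # If the block changes the channel count or spatial resolution, a 1x1
    # projection convolution (followed by BN) maps the shortcut to a matching
    # shape; otherwise the identity shortcut is used.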
    if self._config_dict['use_projection']:
      shortcut = self._shortcut(shortcut)
      shortcut = self._norm0(shortcut)

    x = self._conv1(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._conv2(x)
    x = self._norm2(x)

    return self._activation_fn(x + shortcut)
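

# Example usage of ResBlock (an illustrative sketch; shapes and filter counts
# are assumptions, not values taken from this file). A projection shortcut is
# used because the block changes both the channel count and the resolution:
#
#   block = ResBlock(filters=128, strides=2, use_projection=True)
#   y = block(tf.ones([1, 56, 56, 64]))  # -> shape [1, 28, 28, 128]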