official/projects/basnet/modeling/nn_blocks.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for BasNet model."""
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
@tf_keras.utils.register_keras_serializable(package='Vision')
class ConvBlock(tf_keras.layers.Layer):
"""A (Conv+BN+Activation) block."""
def __init__(self,
filters,
strides,
dilation_rate=1,
kernel_size=3,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_bias=False,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A vgg block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
dilation_rate: `int`, dilation rate for conv layers.
kernel_size: `int`, kernel size of conv layers.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_bias: `bool`, whether or not use bias in conv layers.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ConvBlock, self).__init__(**kwargs)
self._config_dict = {
'filters': filters,
'kernel_size': kernel_size,
'strides': strides,
'dilation_rate': dilation_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf_keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf_keras.layers.BatchNormalization
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
conv_kwargs = {
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
self._conv0 = tf_keras.layers.Conv2D(
filters=self._config_dict['filters'],
kernel_size=self._config_dict['kernel_size'],
strides=self._config_dict['strides'],
dilation_rate=self._config_dict['dilation_rate'],
**conv_kwargs)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ConvBlock, self).build(input_shape)
def get_config(self):
return self._config_dict
def call(self, inputs, training=None):
x = self._conv0(inputs)
x = self._norm0(x)
x = self._activation_fn(x)
return x
@tf_keras.utils.register_keras_serializable(package='Vision')
class ResBlock(tf_keras.layers.Layer):
"""A residual block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
use_bias=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""Initializes a residual block with BN after convolutions.
Args:
filters: An `int` number of filters for the first two convolutions. Note
that the third and final convolution will use 4 times as many filters.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
use_projection: A `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2d.
Default to None.
activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
use_bias: A `bool`. If True, use bias in conv2d.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
super(ResBlock, self).__init__(**kwargs)
self._config_dict = {
'filters': filters,
'strides': strides,
'use_projection': use_projection,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'activation': activation,
'use_sync_bn': use_sync_bn,
'use_bias': use_bias,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon
}
if use_sync_bn:
self._norm = tf_keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf_keras.layers.BatchNormalization
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
conv_kwargs = {
'filters': self._config_dict['filters'],
'padding': 'same',
'use_bias': self._config_dict['use_bias'],
'kernel_initializer': self._config_dict['kernel_initializer'],
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
}
if self._config_dict['use_projection']:
self._shortcut = tf_keras.layers.Conv2D(
filters=self._config_dict['filters'],
kernel_size=1,
strides=self._config_dict['strides'],
use_bias=self._config_dict['use_bias'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv1 = tf_keras.layers.Conv2D(
kernel_size=3,
strides=self._config_dict['strides'],
**conv_kwargs)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
self._conv2 = tf_keras.layers.Conv2D(
kernel_size=3,
strides=1,
**conv_kwargs)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._config_dict['norm_momentum'],
epsilon=self._config_dict['norm_epsilon'])
super(ResBlock, self).build(input_shape)
def get_config(self):
return self._config_dict
def call(self, inputs, training=None):
shortcut = inputs
if self._config_dict['use_projection']:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
return self._activation_fn(x + shortcut)