official/projects/deepmac_maskrcnn/modeling/heads/instance_heads.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Instance prediction heads."""
# Import libraries
from absl import logging
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.projects.deepmac_maskrcnn.modeling.heads import hourglass_network
class DeepMaskHead(tf_keras.layers.Layer):
"""Creates a mask head."""
def __init__(self,
num_classes,
upsample_factor=2,
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
class_agnostic=False,
convnet_variant='default',
**kwargs):
"""Initializes a mask head.
Args:
num_classes: An `int` of the number of classes.
upsample_factor: An `int` that indicates the upsample factor to generate
the final predicted masks. It should be >= 1.
num_convs: An `int` number that represents the number of the intermediate
convolution layers before the mask prediction layers.
num_filters: An `int` number that represents the number of filters of the
intermediate convolution layers.
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
class_agnostic: A `bool`. If set, we use a single channel mask head that
is shared between all classes.
convnet_variant: A `str` denoting the architecture of network used in the
head. Supported options are 'default', 'hourglass20', 'hourglass52'
and 'hourglass100'.
**kwargs: Additional keyword arguments to be passed.
"""
super(DeepMaskHead, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'upsample_factor': upsample_factor,
'num_convs': num_convs,
'num_filters': num_filters,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'class_agnostic': class_agnostic,
'convnet_variant': convnet_variant,
}
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def _get_conv_op_and_kwargs(self):
conv_op = (tf_keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
else tf_keras.layers.Conv2D)
conv_kwargs = {
'filters': self._config_dict['num_filters'],
'kernel_size': 3,
'padding': 'same',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
return conv_op, conv_kwargs
def _get_bn_op_and_kwargs(self):
bn_op = (tf_keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf_keras.layers.BatchNormalization)
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
}
return bn_op, bn_kwargs
def build(self, input_shape):
"""Creates the variables of the head."""
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
self._build_convnet_variant()
self._deconv = tf_keras.layers.Conv2DTranspose(
filters=self._config_dict['num_filters'],
kernel_size=self._config_dict['upsample_factor'],
strides=self._config_dict['upsample_factor'],
padding='valid',
kernel_initializer=tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='mask-upsampling')
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)
if self._config_dict['class_agnostic']:
num_filters = 1
else:
num_filters = self._config_dict['num_classes']
conv_kwargs = {
'filters': num_filters,
'kernel_size': 1,
'padding': 'valid',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
self._mask_regressor = conv_op(name='mask-logits', **conv_kwargs)
super(DeepMaskHead, self).build(input_shape)
def call(self, inputs, training=None):
"""Forward pass of mask branch for the Mask-RCNN model.
Args:
inputs: A `list` of two tensors where
inputs[0]: A `tf.Tensor` of shape [batch_size, num_instances,
roi_height, roi_width, roi_channels], representing the ROI features.
inputs[1]: A `tf.Tensor` of shape [batch_size, num_instances],
representing the classes of the ROIs.
training: A `bool` indicating whether it is in `training` mode.
Returns:
mask_outputs: A `tf.Tensor` of shape
[batch_size, num_instances, roi_height * upsample_factor,
roi_width * upsample_factor], representing the mask predictions.
"""
roi_features, roi_classes = inputs
features_shape = tf.shape(roi_features)
num_rois, height, width, filters = (
features_shape[1],
features_shape[2],
features_shape[3],
features_shape[4],
)
x = tf.reshape(roi_features, [-1, height, width, filters])
x = self._call_convnet_variant(x)
x = self._deconv(x)
x = self._deconv_bn(x)
x = self._activation(x)
logits = self._mask_regressor(x)
mask_height = height * self._config_dict['upsample_factor']
mask_width = width * self._config_dict['upsample_factor']
if self._config_dict['class_agnostic']:
return tf.reshape(logits, [-1, num_rois, mask_height, mask_width])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
return tf.gather(
logits, tf.cast(roi_classes, dtype=tf.int32), axis=-1, batch_dims=2
)
def _build_convnet_variant(self):
variant = self._config_dict['convnet_variant']
if variant == 'default':
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._convs = []
self._conv_norms = []
for i in range(self._config_dict['num_convs']):
conv_name = 'mask-conv_{}'.format(i)
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
self._convs.append(conv_op(name=conv_name, **conv_kwargs))
bn_name = 'mask-conv-bn_{}'.format(i)
self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
elif variant == 'hourglass20':
logging.info('Using hourglass 20 network.')
self._hourglass = hourglass_network.hourglass_20(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass52':
logging.info('Using hourglass 52 network.')
self._hourglass = hourglass_network.hourglass_52(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass100':
logging.info('Using hourglass 100 network.')
self._hourglass = hourglass_network.hourglass_100(
self._config_dict['num_filters'], initial_downsample=False)
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def _call_convnet_variant(self, x):
variant = self._config_dict['convnet_variant']
if variant == 'default':
for conv, bn in zip(self._convs, self._conv_norms):
x = conv(x)
x = bn(x)
x = self._activation(x)
return x
elif variant == 'hourglass20':
return self._hourglass(x)[-1]
elif variant == 'hourglass52':
return self._hourglass(x)[-1]
elif variant == 'hourglass100':
return self._hourglass(x)[-1]
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)