# official/projects/mosaic/modeling/mosaic_head.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of segmentation head of the MOSAIC model."""
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.projects.mosaic.modeling import mosaic_blocks
@tf_keras.utils.register_keras_serializable(package='Vision')
class MosaicDecoderHead(tf_keras.layers.Layer):
  """Creates a MOSAIC decoder in segmentation head.

  Reference:
    [MOSAIC: Mobile Segmentation via decoding Aggregated Information and
    encoded Context](https://arxiv.org/pdf/2112.11623.pdf)
  """

  def __init__(
      self,
      num_classes: int,
      decoder_input_levels: Optional[List[str]] = None,
      decoder_stage_merge_styles: Optional[List[str]] = None,
      decoder_filters: Optional[List[int]] = None,
      decoder_projected_filters: Optional[List[int]] = None,
      encoder_end_level: Optional[int] = 4,
      use_additional_classifier_layer: bool = False,
      classifier_kernel_size: int = 1,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      batchnorm_momentum: float = 0.99,
      batchnorm_epsilon: float = 0.001,
      kernel_initializer: str = 'GlorotUniform',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a MOSAIC segmentation head.

    Args:
      num_classes: An `int` number of mask classification categories. The
        number of classes does not include background class.
      decoder_input_levels: A list of `str` specifying additional input levels
        from the backbone outputs for mask refinement in decoder.
      decoder_stage_merge_styles: A list of `str` specifying the merge style at
        each stage of the decoder; merge styles can be 'concat_merge' or
        'sum_merge'.
      decoder_filters: A list of integers specifying the number of channels
        used at each decoder stage. Note: this only has an effect if the
        decoder merge style is 'concat_merge'.
      decoder_projected_filters: A list of integers specifying the number of
        projected channels at the end of each decoder stage.
      encoder_end_level: An optional integer specifying the output level of
        the encoder stage, which is used if the input from the encoder to the
        decoder head is a dictionary.
      use_additional_classifier_layer: A `bool` specifying whether to use an
        additional classifier layer or not. It must be True if the final
        decoder projected filters does not match the `num_classes`.
      classifier_kernel_size: An `int` number to specify the kernel size of
        the classifier layer.
      activation: A `str` that indicates which activation is used, e.g.
        'relu', 'swish', etc.
      use_sync_bn: A `bool` that indicates whether to use synchronized batch
        normalization across different replicas.
      batchnorm_momentum: A `float` of normalization momentum for the moving
        average.
      batchnorm_epsilon: A `float` added to variance to avoid dividing by
        zero.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Reference MOSAIC defaults: two refinement stages fed by backbone levels
    # listed in descending order, with per-stage settings aligned positionally
    # to 'decoder_input_levels'.
    if decoder_input_levels is None:
      decoder_input_levels = ['3', '2']
    if decoder_stage_merge_styles is None:
      decoder_stage_merge_styles = ['concat_merge', 'sum_merge']
    if decoder_filters is None:
      decoder_filters = [64, 64]
    if decoder_projected_filters is None:
      decoder_projected_filters = [32, 32]

    num_stages = len(decoder_input_levels)
    if not (num_stages == len(decoder_stage_merge_styles) ==
            len(decoder_filters) == len(decoder_projected_filters)):
      raise ValueError('The number of Decoder inputs and settings must match.')

    self._decoder_input_levels = decoder_input_levels
    self._decoder_stage_merge_styles = decoder_stage_merge_styles
    self._decoder_filters = decoder_filters
    self._decoder_projected_filters = decoder_projected_filters

    # Options shared by both merge-block flavors; output_size (0, 0) defers
    # the actual spatial size to the blocks themselves.
    shared_block_kwargs = dict(
        output_size=(0, 0),
        use_sync_bn=use_sync_bn,
        batchnorm_momentum=batchnorm_momentum,
        batchnorm_epsilon=batchnorm_epsilon,
        activation=activation,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        interpolation=interpolation)

    self._merge_stages = []
    for merge_style, internal_depth, projected_depth in zip(
        decoder_stage_merge_styles, decoder_filters,
        decoder_projected_filters):
      if merge_style == 'concat_merge':
        stage = mosaic_blocks.DecoderConcatMergeBlock(
            decoder_internal_depth=internal_depth,
            decoder_projected_depth=projected_depth,
            **shared_block_kwargs)
      elif merge_style == 'sum_merge':
        stage = mosaic_blocks.DecoderSumMergeBlock(
            decoder_projected_depth=projected_depth,
            **shared_block_kwargs)
      else:
        raise ValueError(
            'A stage merge style in MOSAIC Decoder can only be concat_merge '
            'or sum_merge.')
      self._merge_stages.append(stage)

    # The merge blocks already project to 'decoder_projected_filters', so an
    # extra classifier is only required when the final projection width does
    # not equal the number of classes.
    if (decoder_projected_filters[-1] != num_classes and
        not use_additional_classifier_layer):
      raise ValueError('Additional classifier layer is needed if final decoder '
                       'projected filters does not match num_classes!')

    self._use_additional_classifier_layer = use_additional_classifier_layer
    if use_additional_classifier_layer:
      # The classifier uses its own kernel/bias initialization, distinct from
      # the earlier decoder blocks.
      self._pixelwise_classifier = tf_keras.layers.Conv2D(
          name='pixelwise_classifier',
          filters=num_classes,
          kernel_size=classifier_kernel_size,
          padding='same',
          bias_initializer=tf.zeros_initializer(),
          kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
          use_bias=True)
    self._activation_fn = tf_utils.get_activation(activation)

    # Stored verbatim so get_config()/from_config() round-trips the layer.
    self._config_dict = {
        'num_classes': num_classes,
        'decoder_input_levels': decoder_input_levels,
        'decoder_stage_merge_styles': decoder_stage_merge_styles,
        'decoder_filters': decoder_filters,
        'decoder_projected_filters': decoder_projected_filters,
        'encoder_end_level': encoder_end_level,
        'use_additional_classifier_layer': use_additional_classifier_layer,
        'classifier_kernel_size': classifier_kernel_size,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'batchnorm_momentum': batchnorm_momentum,
        'batchnorm_epsilon': batchnorm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'interpolation': interpolation,
        'bias_regularizer': bias_regularizer
    }

  def call(self,
           inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
                         Union[tf.Tensor, Mapping[str, tf.Tensor]]],
           training: Optional[bool] = None) -> tf.Tensor:
    """Forward pass of the segmentation head.

    It supports a tuple of 2 elements. Each element is a tensor or a tensor
    dictionary. The first one is the final (low-resolution) encoder endpoints,
    and the second one is higher-resolution backbone endpoints.
    When inputs are tensors, they are from a single level of feature maps.
    When inputs are dictionaries, they contain multiple levels of feature
    maps, where the key is the level/index of feature map.
    Note: 'level' denotes the number of 2x downsampling, defined in backbone.

    Args:
      inputs: A tuple of 2 elements, each element can either be a tensor
        representing feature maps or 1 dictionary of tensors:
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors.
        The first is encoder endpoints, and the second is backbone endpoints.
      training: a `Boolean` indicating whether it is in `training` mode.

    Returns:
      segmentation mask prediction logits: A `tf.Tensor` representing the
        output logits before the final segmentation mask.
    """
    encoder_outputs = inputs[0]
    backbone_outputs = inputs[1]

    # Seed the decoder with the encoder's final feature map; a dict input is
    # keyed by level, so pick the configured encoder end level.
    if isinstance(encoder_outputs, dict):
      y = encoder_outputs[str(self._config_dict['encoder_end_level'])]
    else:
      y = encoder_outputs

    # Refine progressively with higher-resolution backbone features. A plain
    # tensor input supports only the first merge stage.
    if isinstance(backbone_outputs, dict):
      for level, merge_stage in zip(
          self._decoder_input_levels, self._merge_stages):
        y = merge_stage([y, backbone_outputs[str(level)]], training=training)
    else:
      y = self._merge_stages[0]([y, backbone_outputs], training=training)

    if self._use_additional_classifier_layer:
      y = self._activation_fn(self._pixelwise_classifier(y))
    return y

  def get_config(self) -> Dict[str, Any]:
    """Returns a config dictionary for initialization from serialization."""
    config = super().get_config()
    config.update(self._config_dict)
    return config

  @classmethod
  def from_config(cls, config: Dict[str, Any]):
    """Builds the layer from a `get_config`-style dictionary."""
    return cls(**config)