tensorflow/models

View on GitHub
official/vision/modeling/heads/segmentation_heads.py

Summary

Maintainability
D
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping, Tuple, Any
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.vision.modeling.layers import nn_layers
from official.vision.ops import spatial_transform_ops


class MaskScoring(tf_keras.Model):
  """Creates a mask scoring layer.

  This implements mask scoring layer from the paper:

  Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
  Mask Scoring R-CNN.
  (https://arxiv.org/pdf/1903.00241.pdf)

  The head runs a stack of convolution blocks over the (gradient-stopped)
  inputs, resizes the result to `fc_input_size`, flattens it, applies a stack
  of fully connected blocks and finally a dense layer with `num_classes`
  outputs.
  """

  def __init__(
      self,
      num_classes: int,
      fc_input_size: List[int],
      num_convs: int = 3,
      num_filters: int = 256,
      use_depthwise_convolution: bool = False,
      fc_dims: int = 1024,
      num_fcs: int = 2,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes mask scoring layer.

    Args:
      num_classes: An `int` for number of classes.
      fc_input_size: A List of `int` for the input size of the
        fully connected layers.
      num_convs: An `int` for number of conv layers.
      num_filters: An `int` for the number of filters for conv layers.
      use_depthwise_convolution: A `bool`, whether or not using depthwise convs.
      fc_dims: An `int` number of filters for each fully connected layers.
      num_fcs: An `int` for number of fully connected layers.
      activation: A `str` name of the activation function.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99.
      norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default is None.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    # Constructor arguments are kept verbatim; `build()` reads the layer
    # hyper-parameters from here and `get_config()` reports them.
    self._config_dict = {
        'num_classes': num_classes,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'fc_input_size': fc_input_size,
        'fc_dims': fc_dims,
        'num_fcs': num_fcs,
        'use_sync_bn': use_sync_bn,
        'use_depthwise_convolution': use_depthwise_convolution,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'activation': activation,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }

    # Batch normalization is applied along the channel axis of the active
    # Keras image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the variables of the mask scoring head.

    Args:
      input_shape: The shape(s) of the inputs; only forwarded to the parent
        `build`.
    """
    use_depthwise = self._config_dict['use_depthwise_convolution']
    kernel_regularizer = self._config_dict['kernel_regularizer']
    bias_regularizer = self._config_dict['bias_regularizer']

    conv_op = tf_keras.layers.Conv2D
    conv_kwargs = {
        'filters': self._config_dict['num_filters'],
        # When a depthwise 3x3 conv precedes it, the regular conv is 1x1.
        'kernel_size': 1 if use_depthwise else 3,
        'padding': 'same',
        'kernel_initializer': tf_keras.initializers.VarianceScaling(
            scale=2, mode='fan_out', distribution='untruncated_normal'),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    bn_op = tf_keras.layers.BatchNormalization
    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
        'synchronized': self._config_dict['use_sync_bn'],
    }

    # `_convs` and `_conv_norms` are kept the same length so `call()` can zip
    # them: every conv (depthwise or regular) gets its own batch norm.
    self._convs = []
    self._conv_norms = []
    for i in range(self._config_dict['num_convs']):
      if use_depthwise:
        self._convs.append(
            tf_keras.layers.DepthwiseConv2D(
                name='mask-scoring-depthwise-conv-{}'.format(i),
                kernel_size=3,
                padding='same',
                use_bias=False,
                depthwise_initializer=tf_keras.initializers.RandomNormal(
                    stddev=0.01),
                depthwise_regularizer=kernel_regularizer,
                depth_multiplier=1))
        norm_name = 'mask-scoring-depthwise-bn-{}'.format(i)
        self._conv_norms.append(bn_op(name=norm_name, **bn_kwargs))
      conv_name = 'mask-scoring-conv-{}'.format(i)
      # Presumably gives each conv layer its own initializer instance rather
      # than sharing one object across layers -- see tf_utils.
      conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
          conv_kwargs['kernel_initializer'])
      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
      bn_name = 'mask-scoring-bn-{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._fcs = []
    self._fc_norms = []
    for i in range(self._config_dict['num_fcs']):
      fc_name = 'mask-scoring-fc-{}'.format(i)
      self._fcs.append(
          tf_keras.layers.Dense(
              units=self._config_dict['fc_dims'],
              kernel_initializer=tf_keras.initializers.VarianceScaling(
                  scale=1 / 3.0, mode='fan_out', distribution='uniform'),
              kernel_regularizer=kernel_regularizer,
              bias_regularizer=bias_regularizer,
              name=fc_name))
      bn_name = 'mask-scoring-fc-bn-{}'.format(i)
      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))

    # Final projection to one score per class.
    self._classifier = tf_keras.layers.Dense(
        units=self._config_dict['num_classes'],
        kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name='iou-scores')

    super().build(input_shape)

  def call(self, inputs: tf.Tensor, training: Optional[bool] = None):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    """Forward pass mask scoring head.

    Args:
      inputs: A `tf.Tensor` of the shape
        [batch_size, height, width, num_classes], representing the
        segmentation logits.
      training: a `bool` indicating whether it is in `training` mode.

    Returns:
      mask_scores: A `tf.Tensor` of predicted mask scores
        [batch_size, num_classes].
    """
    # NOTE(review): `training` is not forwarded explicitly to the sub-layers
    # below; this presumably relies on Keras propagating the outer call's
    # training mode to nested layers -- confirm.

    # Gradients of the scoring branch do not flow back into `inputs`.
    x = tf.stop_gradient(inputs)
    for conv, bn in zip(self._convs, self._conv_norms):
      x = conv(x)
      x = bn(x)
      x = self._activation(x)

    # Casts feat to float32 so the resize op can be run on TPU.
    x = tf.cast(x, tf.float32)
    x = tf.image.resize(x, size=self._config_dict['fc_input_size'],
                        method=tf.image.ResizeMethod.BILINEAR)
    # Casts it back to be compatible with the rest of the operations.
    x = tf.cast(x, inputs.dtype)

    # Flatten everything but the batch dimension for the dense layers. The
    # non-batch dimensions must be statically known here.
    _, h, w, filters = x.get_shape().as_list()
    x = tf.reshape(x, [-1, h * w * filters])

    for fc, bn in zip(self._fcs, self._fc_norms):
      x = fc(x)
      x = bn(x)
      x = self._activation(x)

    ious = self._classifier(x)
    return ious

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments of this layer as a new dict."""
    # Hand out a shallow copy: `build()` reads `_config_dict` lazily, so a
    # caller mutating the returned mapping must not be able to alter the layer.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a `MaskScoring` from a config produced by `get_config`.

    Args:
      config: A mapping of constructor keyword arguments.
      custom_objects: Unused; accepted for signature compatibility.

    Returns:
      A new `MaskScoring` instance.
    """
    return cls(**config)


@tf_keras.utils.register_keras_serializable(package='Vision')
class SegmentationHead(tf_keras.layers.Layer):
  """Creates a segmentation head."""

  def __init__(
      self,
      num_classes: int,
      level: Union[int, str],
      num_convs: int = 2,
      num_filters: int = 256,
      use_depthwise_convolution: bool = False,
      prediction_kernel_size: int = 1,
      upsample_factor: int = 1,
      feature_fusion: Optional[str] = None,
      decoder_min_level: Optional[int] = None,
      decoder_max_level: Optional[int] = None,
      low_level: int = 2,
      low_level_num_filters: int = 48,
      num_decoder_filters: int = 256,
      activation: str = 'relu',
      logit_activation: Optional[str] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a segmentation head.

    Args:
      num_classes: An `int` number of mask classification categories. The number
        of classes does not include background class.
      level: An `int` or `str`, level to use to build segmentation head.
      num_convs: An `int` number of stacked convolution before the last
        prediction layer.
      num_filters: An `int` number to specify the number of filters used.
        Default is 256.
      use_depthwise_convolution: A bool to specify if use depthwise separable
        convolutions.
      prediction_kernel_size: An `int` number to specify the kernel size of the
        prediction layer.
      upsample_factor: An `int` number to specify the upsampling factor to
        generate finer mask. Default 1 means no upsampling is applied.
      feature_fusion: One of the constants in nn_layers.FeatureFusion, namely
        `deeplabv3plus`, `pyramid_fusion`, `panoptic_fpn_fusion`,
        `deeplabv3plus_sum_to_merge`, or None. If `deeplabv3plus`, features from
        decoder_features[level] will be fused with low level feature maps from
        backbone. If `pyramid_fusion`, multiscale features will be resized and
        fused at the target level.
      decoder_min_level: An `int` of minimum level from decoder to use in
        feature fusion. It is only used when feature_fusion is set to
        `panoptic_fpn_fusion`.
      decoder_max_level: An `int` of maximum level from decoder to use in
        feature fusion. It is only used when feature_fusion is set to
        `panoptic_fpn_fusion`.
      low_level: An `int` of backbone level to be used for feature fusion. It is
        used when feature_fusion is set to `deeplabv3plus` or
        `deeplabv3plus_sum_to_merge`.
      low_level_num_filters: An `int` of reduced number of filters for the low
        level features before fusing it with higher level features. It is only
        used when feature_fusion is set to `deeplabv3plus` or
        `deeplabv3plus_sum_to_merge`.
      num_decoder_filters: An `int` of number of filters in the decoder outputs.
        It is only used when feature_fusion is set to `panoptic_fpn_fusion`.
      activation: A `str` that indicates which activation is used, e.g. 'relu',
        'swish', etc.
      logit_activation: Activation applied to the final classifier layer logits,
        e.g. 'sigmoid', 'softmax'. Can be useful in cases when the task does not
        use only cross entropy loss.
      use_sync_bn: A `bool` that indicates whether to use synchronized batch
        normalization across different replicas.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default is None.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    # Constructor arguments are kept verbatim; `build()` and `call()` read the
    # hyper-parameters from here and `get_config()` reports them.
    self._config_dict = {
        'num_classes': num_classes,
        'level': level,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'use_depthwise_convolution': use_depthwise_convolution,
        'prediction_kernel_size': prediction_kernel_size,
        'upsample_factor': upsample_factor,
        'feature_fusion': feature_fusion,
        'decoder_min_level': decoder_min_level,
        'decoder_max_level': decoder_max_level,
        'low_level': low_level,
        'low_level_num_filters': low_level_num_filters,
        'num_decoder_filters': num_decoder_filters,
        'activation': activation,
        'logit_activation': logit_activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer
    }
    # Channel axis of the active Keras image data format; used for batch
    # normalization and for concatenating fused features.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  @staticmethod
  def _get_level_feature(features, level):
    """Returns `features[str(level)]` for dict inputs, else `features` itself."""
    if isinstance(features, dict):
      return features[str(level)]
    return features

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the variables of the segmentation head.

    Args:
      input_shape: The shape(s) of the inputs; only forwarded to the parent
        `build`.
    """
    use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
    feature_fusion = self._config_dict['feature_fusion']
    kernel_regularizer = self._config_dict['kernel_regularizer']
    bias_regularizer = self._config_dict['bias_regularizer']

    conv_op = tf_keras.layers.Conv2D
    bn_op = tf_keras.layers.BatchNormalization
    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
        'synchronized': self._config_dict['use_sync_bn'],
    }

    if feature_fusion in {'deeplabv3plus', 'deeplabv3plus_sum_to_merge'}:
      # Deeplabv3+ feature fusion layers: a 1x1 conv + batch norm applied to
      # the low level backbone feature before it is fused.
      self._dlv3p_conv = conv_op(
          kernel_size=1,
          padding='same',
          use_bias=False,
          kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
          kernel_regularizer=kernel_regularizer,
          name='segmentation_head_deeplabv3p_fusion_conv',
          filters=self._config_dict['low_level_num_filters'])

      self._dlv3p_norm = bn_op(
          name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)

    elif feature_fusion == 'panoptic_fpn_fusion':
      self._panoptic_fpn_fusion = nn_layers.PanopticFPNFusion(
          min_level=self._config_dict['decoder_min_level'],
          max_level=self._config_dict['decoder_max_level'],
          target_level=self._config_dict['level'],
          num_filters=self._config_dict['num_filters'],
          num_fpn_filters=self._config_dict['num_decoder_filters'],
          activation=self._config_dict['activation'],
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)

    # Segmentation head layers. `_convs` and `_norms` are kept the same length
    # so `call()` can zip them: every conv gets its own batch norm.
    self._convs = []
    self._norms = []
    for i in range(self._config_dict['num_convs']):
      if use_depthwise_convolution:
        self._convs.append(
            tf_keras.layers.DepthwiseConv2D(
                name='segmentation_head_depthwise_conv_{}'.format(i),
                kernel_size=3,
                padding='same',
                use_bias=False,
                depthwise_initializer=tf_keras.initializers.RandomNormal(
                    stddev=0.01),
                depthwise_regularizer=kernel_regularizer,
                depth_multiplier=1))
        norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
        self._norms.append(bn_op(name=norm_name, **bn_kwargs))
      conv_name = 'segmentation_head_conv_{}'.format(i)
      self._convs.append(
          conv_op(
              name=conv_name,
              filters=self._config_dict['num_filters'],
              # When a depthwise 3x3 conv precedes it, the regular conv is 1x1.
              kernel_size=3 if not use_depthwise_convolution else 1,
              padding='same',
              use_bias=False,
              kernel_initializer=tf_keras.initializers.RandomNormal(
                  stddev=0.01),
              kernel_regularizer=kernel_regularizer))
      norm_name = 'segmentation_head_norm_{}'.format(i)
      self._norms.append(bn_op(name=norm_name, **bn_kwargs))

    # Final prediction layer producing `num_classes` output channels.
    self._classifier = conv_op(
        name='segmentation_output',
        filters=self._config_dict['num_classes'],
        kernel_size=self._config_dict['prediction_kernel_size'],
        padding='same',
        activation=self._config_dict['logit_activation'],
        bias_initializer=tf.zeros_initializer(),
        kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer)

    super().build(input_shape)

  def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
                               Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
    """Forward pass of the segmentation head.

    It supports both a tuple of 2 tensors or 2 dictionaries. The first is
    backbone endpoints, and the second is decoder endpoints. When inputs are
    tensors, they are from a single level of feature maps. When inputs are
    dictionaries, they contain multiple levels of feature maps, where the key
    is the index of feature map.

    Args:
      inputs: A tuple of 2 feature map tensors of shape
        [batch, height_l, width_l, channels] or 2 dictionaries of tensors:
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].
        The first is backbone endpoints, and the second is decoder endpoints.

    Returns:
      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
        scores predicted from input features.

    Raises:
      ValueError: If `feature_fusion` is `pyramid_fusion` and the decoder
        endpoints are not a dictionary.
    """

    backbone_output = inputs[0]
    decoder_output = inputs[1]
    feature_fusion = self._config_dict['feature_fusion']
    level = self._config_dict['level']

    if feature_fusion in {'deeplabv3plus', 'deeplabv3plus_sum_to_merge'}:
      # deeplabv3+ feature fusion
      x = self._get_level_feature(decoder_output, level)
      y = self._get_level_feature(backbone_output,
                                  self._config_dict['low_level'])
      y = self._dlv3p_norm(self._dlv3p_conv(y))
      y = self._activation(y)

      # Resize the decoder feature to dims 1 and 2 of the low level feature
      # (height and width for the documented channels-last layout) and match
      # its dtype before fusing.
      x = tf.image.resize(
          x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
      x = tf.cast(x, dtype=y.dtype)
      if feature_fusion == 'deeplabv3plus':
        x = tf.concat([x, y], axis=self._bn_axis)
      else:
        # `deeplabv3plus_sum_to_merge`: element-wise sum instead of concat.
        x = tf_keras.layers.Add()([x, y])
    elif feature_fusion == 'pyramid_fusion':
      if not isinstance(decoder_output, dict):
        raise ValueError('Only support dictionary decoder_output.')
      x = nn_layers.pyramid_feature_fusion(decoder_output, level)
    elif feature_fusion == 'panoptic_fpn_fusion':
      x = self._panoptic_fpn_fusion(decoder_output)
    else:
      # Any other value (including None): use the decoder feature as is.
      x = self._get_level_feature(decoder_output, level)

    for conv, norm in zip(self._convs, self._norms):
      x = conv(x)
      x = norm(x)
      x = self._activation(x)
    if self._config_dict['upsample_factor'] > 1:
      x = spatial_transform_ops.nearest_upsampling(
          x, scale=self._config_dict['upsample_factor'])

    return self._classifier(x)

  def get_config(self):
    """Returns the base layer config merged with this head's arguments."""
    base_config = super().get_config()
    # Entries of `_config_dict` take precedence over same-named base entries.
    return {**base_config, **self._config_dict}

  @classmethod
  def from_config(cls, config):
    """Creates a `SegmentationHead` from a config produced by `get_config`."""
    return cls(**config)