tensorflow/models

View on GitHub
official/vision/modeling/decoders/nasfpn.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of NAS-FPN."""

from typing import Any, List, Mapping, Optional, Tuple

# Import libraries

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.decoders import factory
from official.vision.ops import spatial_transform_ops


# The fixed NAS-FPN architecture discovered by NAS.
# Each element represents a specification of a building block:
#   (block_level, combine_fn, (input_offset0, input_offset1), is_output).
NASFPN_BLOCK_SPECS = [
    (4, 'attention', (1, 3), False),
    (4, 'sum', (1, 5), False),
    (3, 'sum', (0, 6), True),
    (4, 'sum', (6, 7), True),
    (5, 'attention', (7, 8), True),
    (7, 'attention', (6, 9), True),
    (6, 'attention', (9, 10), True),
]


class BlockSpec():
  """A container class that specifies the block configuration for NAS-FPN."""

  def __init__(self, level: int, combine_fn: str,
               input_offsets: Tuple[int, int], is_output: bool):
    self.level = level
    self.combine_fn = combine_fn
    self.input_offsets = input_offsets
    self.is_output = is_output


def build_block_specs(
    block_specs: Optional[List[Tuple[Any, ...]]] = None) -> List[BlockSpec]:
  """Builds the list of BlockSpec objects for NAS-FPN."""
  if not block_specs:
    block_specs = NASFPN_BLOCK_SPECS
  logging.info('Building NAS-FPN block specs: %s', block_specs)
  return [BlockSpec(*b) for b in block_specs]


@tf_keras.utils.register_keras_serializable(package='Vision')
class NASFPN(tf_keras.Model):
  """Creates a NAS-FPN model.

  This implements the paper:
  Golnaz Ghiasi, Tsung-Yi Lin, Ruoming Pang, Quoc V. Le.
  NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection.
  (https://arxiv.org/abs/1904.07392)
  """

  def __init__(
      self,
      input_specs: Mapping[str, tf.TensorShape],
      min_level: int = 3,
      max_level: int = 7,
      block_specs: Optional[List[BlockSpec]] = None,
      num_filters: int = 256,
      num_repeats: int = 5,
      use_separable_conv: bool = False,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a NAS-FPN model.

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      min_level: An `int` of minimum level in FPN output feature maps.
      max_level: An `int` of maximum level in FPN output feature maps.
      block_specs: a list of BlockSpec objects that specifies the NAS-FPN
        network topology. By default, the previously discovered architecture is
        used.
      num_filters: An `int` number of filters in FPN layers.
      num_repeats: number of repeats for feature pyramid network.
      use_separable_conv: A `bool`.  If True use separable convolution for
        convolution in FPN layers.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    """
    self._config_dict = {
        'input_specs': input_specs,
        'min_level': min_level,
        'max_level': max_level,
        'num_filters': num_filters,
        'num_repeats': num_repeats,
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._min_level = min_level
    self._max_level = max_level
    self._block_specs = (
        build_block_specs() if block_specs is None else block_specs
    )
    self._num_repeats = num_repeats
    self._conv_op = (tf_keras.layers.SeparableConv2D
                     if self._config_dict['use_separable_conv']
                     else tf_keras.layers.Conv2D)
    self._norm_op = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._norm_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
        'synchronized': self._config_dict['use_sync_bn'],
    }
    self._activation = tf_utils.get_activation(activation)

    # Gets input feature pyramid from backbone.
    inputs = self._build_input_pyramid(input_specs, min_level)

    # Projects the input features.
    feats = []
    for level in range(self._min_level, self._max_level + 1):
      if str(level) in inputs.keys():
        feats.append(self._resample_feature_map(
            inputs[str(level)], level, level, self._config_dict['num_filters']))
      else:
        feats.append(self._resample_feature_map(
            feats[-1], level - 1, level, self._config_dict['num_filters']))

    # Repeatly builds the NAS-FPN modules.
    for _ in range(self._num_repeats):
      output_feats = self._build_feature_pyramid(feats)
      feats = [output_feats[level]
               for level in range(self._min_level, self._max_level + 1)]

    self._output_specs = {
        str(level): output_feats[level].get_shape()
        for level in range(min_level, max_level + 1)
    }
    output_feats = {str(level): output_feats[level]
                    for level in output_feats.keys()}
    super(NASFPN, self).__init__(inputs=inputs, outputs=output_feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
                           min_level: int):
    assert isinstance(input_specs, dict)
    if min(input_specs.keys()) > str(min_level):
      raise ValueError(
          'Backbone min level should be less or equal to FPN min level')

    inputs = {}
    for level, spec in input_specs.items():
      inputs[level] = tf_keras.Input(shape=spec[1:])
    return inputs

  def _resample_feature_map(self,
                            inputs,
                            input_level,
                            target_level,
                            target_num_filters=256):
    x = inputs
    _, _, _, input_num_filters = x.get_shape().as_list()
    if input_num_filters != target_num_filters:
      x = self._conv_op(
          filters=target_num_filters,
          kernel_size=1,
          padding='same',
          **self._conv_kwargs)(x)
      x = self._norm_op(**self._norm_kwargs)(x)

    if input_level < target_level:
      stride = int(2 ** (target_level - input_level))
      return tf_keras.layers.MaxPool2D(
          pool_size=stride, strides=stride, padding='same')(x)
    if input_level > target_level:
      scale = int(2 ** (input_level - target_level))
      return spatial_transform_ops.nearest_upsampling(x, scale=scale)

    # Force output x to be the same dtype as mixed precision policy. This avoids
    # dtype mismatch when one input (by default float32 dtype) does not meet all
    # the above conditions and is output unchanged, while other inputs are
    # processed to have different dtype, e.g., using bfloat16 on TPU.
    compute_dtype = tf_keras.layers.Layer().dtype_policy.compute_dtype
    if (compute_dtype is not None) and (x.dtype != compute_dtype):
      return tf.cast(x, dtype=compute_dtype)
    else:
      return x

  @property
  def _conv_kwargs(self):
    if self._config_dict['use_separable_conv']:
      return {
          'depthwise_initializer': tf_keras.initializers.VarianceScaling(
              scale=2, mode='fan_out', distribution='untruncated_normal'),
          'pointwise_initializer': tf_keras.initializers.VarianceScaling(
              scale=2, mode='fan_out', distribution='untruncated_normal'),
          'bias_initializer': tf.zeros_initializer(),
          'depthwise_regularizer': self._config_dict['kernel_regularizer'],
          'pointwise_regularizer': self._config_dict['kernel_regularizer'],
          'bias_regularizer': self._config_dict['bias_regularizer'],
      }
    else:
      return {
          'kernel_initializer': tf_keras.initializers.VarianceScaling(
              scale=2, mode='fan_out', distribution='untruncated_normal'),
          'bias_initializer': tf.zeros_initializer(),
          'kernel_regularizer': self._config_dict['kernel_regularizer'],
          'bias_regularizer': self._config_dict['bias_regularizer'],
      }

  def _global_attention(self, feat0, feat1):
    m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True)
    m = tf.math.sigmoid(m)
    return feat0 + feat1 * m

  def _build_feature_pyramid(self, feats):
    num_output_connections = [0] * len(feats)
    num_output_levels = self._max_level - self._min_level + 1
    feat_levels = list(range(self._min_level, self._max_level + 1))

    for i, block_spec in enumerate(self._block_specs):
      new_level = block_spec.level

      # Checks the range of input_offsets.
      for input_offset in block_spec.input_offsets:
        if input_offset >= len(feats):
          raise ValueError(
              'input_offset ({}) is larger than num feats({})'.format(
                  input_offset, len(feats)))
      input0 = block_spec.input_offsets[0]
      input1 = block_spec.input_offsets[1]

      # Update graph with inputs.
      node0 = feats[input0]
      node0_level = feat_levels[input0]
      num_output_connections[input0] += 1
      node0 = self._resample_feature_map(node0, node0_level, new_level)
      node1 = feats[input1]
      node1_level = feat_levels[input1]
      num_output_connections[input1] += 1
      node1 = self._resample_feature_map(node1, node1_level, new_level)

      # Combine node0 and node1 to create new feat.
      if block_spec.combine_fn == 'sum':
        new_node = node0 + node1
      elif block_spec.combine_fn == 'attention':
        if node0_level >= node1_level:
          new_node = self._global_attention(node0, node1)
        else:
          new_node = self._global_attention(node1, node0)
      else:
        raise ValueError('unknown combine_fn `{}`.'
                         .format(block_spec.combine_fn))

      # Add intermediate nodes that do not have any connections to output.
      if block_spec.is_output:
        for j, (feat, feat_level, num_output) in enumerate(
            zip(feats, feat_levels, num_output_connections)):
          if num_output == 0 and feat_level == new_level:
            num_output_connections[j] += 1

            feat_ = self._resample_feature_map(feat, feat_level, new_level)
            new_node += feat_

      new_node = self._activation(new_node)
      new_node = self._conv_op(
          filters=self._config_dict['num_filters'],
          kernel_size=(3, 3),
          padding='same',
          **self._conv_kwargs)(new_node)
      new_node = self._norm_op(**self._norm_kwargs)(new_node)

      feats.append(new_node)
      feat_levels.append(new_level)
      num_output_connections.append(0)

    output_feats = {}
    for i in range(len(feats) - num_output_levels, len(feats)):
      level = feat_levels[i]
      output_feats[level] = feats[i]
    logging.info('Output feature pyramid: %s', output_feats)
    return output_feats

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_decoder_builder('nasfpn')
def build_nasfpn_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds NASFPN decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` instance. Default to
      None.

  Returns:
    A `tf_keras.Model` instance of the NASFPN decoder.

  Raises:
    ValueError: If the model_config.decoder.type is not `nasfpn`.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  if decoder_type != 'nasfpn':
    raise ValueError(f'Inconsistent decoder type {decoder_type}. '
                     'Need to be `nasfpn`.')

  norm_activation_config = model_config.norm_activation
  return NASFPN(
      input_specs=input_specs,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_filters=decoder_cfg.num_filters,
      num_repeats=decoder_cfg.num_repeats,
      use_separable_conv=decoder_cfg.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)