tensorflow/models

View on GitHub
official/vision/modeling/layers/nn_blocks_3d.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for 3D networks."""
# Import libraries
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.vision.modeling.layers import nn_layers


@tf_keras.utils.register_keras_serializable(package='Vision')
class SelfGating(tf_keras.layers.Layer):
  """Feature gating as used in S3D-G.

  This implements the S3D-G network from:
  Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu, Kevin Murphy.
  Rethinking Spatiotemporal Feature Learning: Speed-Accuracy Trade-offs in Video
  Classification.
  (https://arxiv.org/pdf/1712.04851.pdf)

  The input is globally averaged over its three spatio-temporal axes, passed
  through a 1x1x1 convolution (with bias, no normalization) and a sigmoid. The
  resulting gate is multiplied element-wise back onto the input.
  """

  def __init__(self, filters, **kwargs):
    """Initializes a self-gating layer.

    Args:
      filters: An `int` number of filters for the 1x1x1 convolution that
        produces the gate. For the gate to broadcast channel-wise in `call` it
        presumably needs to equal the channel count of the inputs (or be 1).
      **kwargs: Additional keyword arguments to be passed.
    """
    super(SelfGating, self).__init__(**kwargs)
    self._filters = filters

  def build(self, input_shape):
    """Creates the pooling and 1x1x1 convolution sublayers."""
    self._spatial_temporal_average = tf_keras.layers.GlobalAveragePooling3D()

    # No BN and activation after conv.
    self._transformer_w = tf_keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=[1, 1, 1],
        use_bias=True,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            mean=0.0, stddev=0.01))

    super(SelfGating, self).build(input_shape)

  def get_config(self):
    """Returns the layer config so the layer can be rebuilt via `from_config`.

    `filters` is a required constructor argument, so it must be part of the
    config for deserialization of this registered layer to succeed.
    """
    config = {'filters': self._filters}
    base_config = super(SelfGating, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    """Applies the self-gating to `inputs`.

    Args:
      inputs: A 5D tensor. The three `expand_dims` calls below re-insert axes
        1, 2 and 3 after global pooling, which assumes a channels-last
        [batch, T, H, W, C] layout — NOTE(review): confirm behaviour under
        `channels_first`.

    Returns:
      A tensor with the same shape as `inputs`, scaled by the sigmoid gate.
    """
    # Global average over the spatio-temporal axes -> [batch, channels].
    x = self._spatial_temporal_average(inputs)

    # Restore singleton spatio-temporal axes so the Conv3D can be applied and
    # the result broadcasts against `inputs`.
    x = tf.expand_dims(x, 1)
    x = tf.expand_dims(x, 2)
    x = tf.expand_dims(x, 3)

    x = self._transformer_w(x)
    x = tf.nn.sigmoid(x)

    return tf.math.multiply(x, inputs)


@tf_keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock3D(tf_keras.layers.Layer):
  """Creates a 3D bottleneck block.

  The main path is a temporal [k, 1, 1] conv, a spatial [1, 3, 3] conv and a
  1x1x1 expansion conv to `4 * filters` channels, each followed by batch
  normalization. Optional self-gating, squeeze-and-excitation and stochastic
  depth are applied to the main path before it is added to the shortcut and
  passed through the activation.
  """

  def __init__(self,
               filters,
               temporal_kernel_size,
               temporal_strides,
               spatial_strides,
               stochastic_depth_drop_rate=0.0,
               se_ratio=None,
               use_self_gating=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Initializes a 3D bottleneck block with BN after convolutions.

    Args:
      filters: An `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
      temporal_kernel_size: An `int` of kernel size for the temporal
        convolutional layer.
      temporal_strides: An `int` of temporal stride for the temporal
        convolutional layer.
      spatial_strides: An `int` of spatial stride for the spatial convolutional
        layer.
      stochastic_depth_drop_rate: A `float` or None. If truthy (non-zero and
        not None), drop rate for the stochastic depth layer; otherwise
        stochastic depth is disabled.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
        The layer is only created when `0 < se_ratio <= 1`.
      use_self_gating: A `bool` of whether to apply self-gating module or not.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv3D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv3D.
        Default to None.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(BottleneckBlock3D, self).__init__(**kwargs)

    self._filters = filters
    self._temporal_kernel_size = temporal_kernel_size
    self._spatial_strides = spatial_strides
    self._temporal_strides = temporal_strides
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._norm = tf_keras.layers.BatchNormalization

    # Channel axis for BN and for the shortcut's channel-count check in call().
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def _make_norm(self):
    """Returns a new batch-norm layer configured from this block's settings."""
    return self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)

  def build(self, input_shape):
    """Creates the shortcut, main-path and optional sublayers."""
    # A 1x1x1 max-pool with the block's strides is a parameter-free strided
    # subsampling used when only the resolution (not channels) must change.
    self._shortcut_maxpool = tf_keras.layers.MaxPool3D(
        pool_size=[1, 1, 1],
        strides=[
            self._temporal_strides, self._spatial_strides, self._spatial_strides
        ])

    # Projection shortcut used when the channel count differs.
    self._shortcut_conv = tf_keras.layers.Conv3D(
        filters=4 * self._filters,
        kernel_size=1,
        strides=[
            self._temporal_strides, self._spatial_strides, self._spatial_strides
        ],
        use_bias=False,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm0 = self._make_norm()

    self._temporal_conv = tf_keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=[self._temporal_kernel_size, 1, 1],
        strides=[self._temporal_strides, 1, 1],
        padding='same',
        use_bias=False,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm1 = self._make_norm()

    self._spatial_conv = tf_keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=[1, 3, 3],
        strides=[1, self._spatial_strides, self._spatial_strides],
        padding='same',
        use_bias=False,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm2 = self._make_norm()

    self._expand_conv = tf_keras.layers.Conv3D(
        filters=4 * self._filters,
        kernel_size=[1, 1, 1],
        strides=[1, 1, 1],
        padding='same',
        use_bias=False,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm3 = self._make_norm()

    if self._se_ratio and 0 < self._se_ratio <= 1:
      self._squeeze_excitation = nn_layers.SqueezeExcitation(
          in_filters=self._filters * 4,
          out_filters=self._filters * 4,
          se_ratio=self._se_ratio,
          use_3d_input=True,
          kernel_initializer=tf_utils.clone_initializer(
              self._kernel_initializer),
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)
    else:
      self._squeeze_excitation = None

    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = nn_layers.StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = None

    if self._use_self_gating:
      self._self_gating = SelfGating(filters=4 * self._filters)
    else:
      self._self_gating = None

    super(BottleneckBlock3D, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments merged with the base layer config."""
    config = {
        'filters': self._filters,
        'temporal_kernel_size': self._temporal_kernel_size,
        'temporal_strides': self._temporal_strides,
        'spatial_strides': self._spatial_strides,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(BottleneckBlock3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Runs the bottleneck block.

    Args:
      inputs: A 5D input tensor whose channel axis matches
        `tf_keras.backend.image_data_format()`.
      training: Optional `bool`, forwarded to the stochastic depth layer.

    Returns:
      The block output with `4 * filters` channels.
    """
    # Read the channel count from the same axis BN uses, so the shortcut
    # choice stays correct for both channels_last and channels_first.
    in_filters = inputs.shape.as_list()[self._bn_axis]
    if in_filters == 4 * self._filters:
      if self._temporal_strides == 1 and self._spatial_strides == 1:
        shortcut = inputs
      else:
        shortcut = self._shortcut_maxpool(inputs)
    else:
      shortcut = self._shortcut_conv(inputs)
      shortcut = self._norm0(shortcut)

    x = self._temporal_conv(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._spatial_conv(x)
    x = self._norm2(x)
    x = self._activation_fn(x)

    # No activation here: it is applied after the residual addition below.
    x = self._expand_conv(x)
    x = self._norm3(x)

    # Apply self-gating, SE, stochastic depth.
    if self._self_gating is not None:
      x = self._self_gating(x)
    if self._squeeze_excitation is not None:
      x = self._squeeze_excitation(x)
    if self._stochastic_depth is not None:
      x = self._stochastic_depth(x, training=training)

    # Add the shortcut, then apply the final activation.
    return self._activation_fn(x + shortcut)