tensorflow/models

View on GitHub
research/slim/nets/s3dg.py

Summary

Maintainability
F
6 days
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for Gated Separable 3D network (S3D-G).

The network architecture is proposed by:
  Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu and Kevin Murphy,
  Rethinking Spatiotemporal Feature Learning For Video Understanding.
  https://arxiv.org/abs/1712.04851.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
import tf_slim as slim

from nets import i3d_utils

# pylint: disable=g-long-lambda
trunc_normal = lambda stddev: tf.truncated_normal_initializer(
    0.0, stddev)
conv3d_spatiotemporal = i3d_utils.conv3d_spatiotemporal
inception_block_v1_3d = i3d_utils.inception_block_v1_3d


arg_scope = slim.arg_scope


def s3dg_arg_scope(weight_decay=1e-7,
                   batch_norm_decay=0.999,
                   batch_norm_epsilon=0.001):
  """Defines default arg_scope for S3D-G.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    sc: An arg_scope to use for the models.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      # Turns off fused batch norm.
      'fused': False,
      # collection containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': ['moving_vars'],
          'moving_variance': ['moving_vars'],
      }
  }

  with arg_scope([slim.conv3d, conv3d_spatiotemporal],
                 weights_regularizer=slim.l2_regularizer(weight_decay),
                 activation_fn=tf.nn.relu,
                 normalizer_fn=slim.batch_norm,
                 normalizer_params=batch_norm_params):
    with arg_scope([conv3d_spatiotemporal], separable=True) as sc:
      return sc


def self_gating(input_tensor, scope, data_format='NDHWC'):
  """Feature gating as used in S3D-G.

  Transforms the input features by aggregating features from all
  spatial and temporal locations, and applying gating conditioned
  on the aggregated features. More details can be found at:
  https://arxiv.org/abs/1712.04851

  Args:
    input_tensor: A 5-D float tensor of size [batch_size, num_frames,
      height, width, channels].
    scope: scope for `variable_scope`.
    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
      The data format of the input and output data. With the default format
      "NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
      in_width, in_channels]. Alternatively, the format could be "NCDHW", the
      data storage order is:
      [batch, in_channels, in_depth, in_height, in_width].

  Returns:
    A tensor with the same shape as input_tensor.
  """

  index_c = data_format.index('C')
  index_d = data_format.index('D')
  index_h = data_format.index('H')
  index_w = data_format.index('W')
  input_shape = input_tensor.get_shape().as_list()
  t = input_shape[index_d]
  w = input_shape[index_w]
  h = input_shape[index_h]
  num_channels = input_shape[index_c]

  spatiotemporal_average = slim.avg_pool3d(
      input_tensor, [t, w, h],
      stride=1,
      data_format=data_format,
      scope=scope + '/self_gating/avg_pool3d')

  weights = slim.conv3d(
      spatiotemporal_average,
      num_channels, [1, 1, 1],
      activation_fn=None,
      normalizer_fn=None,
      biases_initializer=None,
      data_format=data_format,
      weights_initializer=trunc_normal(0.01),
      scope=scope + '/self_gating/transformer_W')

  tile_multiples = [1, t, w, h]
  tile_multiples.insert(index_c, 1)
  weights = tf.tile(weights, tile_multiples)
  weights = tf.nn.sigmoid(weights)

  return tf.multiply(weights, input_tensor)


def s3dg_base(inputs,
              first_temporal_kernel_size=3,
              temporal_conv_startat='Conv2d_2c_3x3',
              gating_startat='Conv2d_2c_3x3',
              final_endpoint='Mixed_5c',
              min_depth=16,
              depth_multiplier=1.0,
              data_format='NDHWC',
              scope='InceptionV1'):
  """Defines the I3D/S3DG base architecture.

  Note that we use the names as defined in Inception V1 to facilitate checkpoint
  conversion from an image-trained Inception V1 checkpoint to I3D checkpoint.

  Args:
    inputs: A 5-D float tensor of size [batch_size, num_frames, height, width,
      channels].
    first_temporal_kernel_size: Specifies the temporal kernel size for the first
      conv3d filter. A larger value slows down the model but provides little
      accuracy improvement. The default is 7 in the original I3D and S3D-G but 3
      gives better performance. Must be set to one of 1, 3, 5 or 7.
    temporal_conv_startat: Specifies the first conv block to use 3D or separable
      3D convs rather than 2D convs (implemented as [1, k, k] 3D conv). This is
      used to construct the inverted pyramid models. 'Conv2d_2c_3x3' is the
      first valid block to use separable 3D convs. If provided block name is
      not present, all valid blocks will use separable 3D convs. Note that
      'Conv2d_1a_7x7' cannot be made into a separable 3D conv, but can be made
      into a 2D or 3D conv using the `first_temporal_kernel_size` option.
    gating_startat: Specifies the first conv block to use self gating.
      'Conv2d_2c_3x3' is the first valid block to use self gating. If provided
      block name is not present, all valid blocks will use separable 3D convs.
    final_endpoint: Specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
      'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
      'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
      The data format of the input and output data. With the default format
      "NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
      in_width, in_channels]. Alternatively, the format could be "NCDHW", the
      data storage order is:
      [batch, in_channels, in_depth, in_height, in_width].
    scope: Optional variable_scope.

  Returns:
    A dictionary from components of the network to the corresponding activation.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values, or
      if depth_multiplier <= 0.
  """

  assert data_format in ['NDHWC', 'NCDHW']
  end_points = {}
  t = 1
  # For inverted pyramid models, we start with gating switched off.
  use_gating = False
  self_gating_fn = None
  def gating_fn(inputs, scope):
    return self_gating(inputs, scope, data_format=data_format)

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')
  depth = lambda d: max(int(d * depth_multiplier), min_depth)

  with tf.variable_scope(scope, 'InceptionV1', [inputs]):
    with arg_scope([slim.conv3d], weights_initializer=trunc_normal(0.01)):
      with arg_scope([slim.conv3d, slim.max_pool3d, conv3d_spatiotemporal],
                     stride=1,
                     data_format=data_format,
                     padding='SAME'):
        # batch_size x 32 x 112 x 112 x 64
        end_point = 'Conv2d_1a_7x7'
        if first_temporal_kernel_size not in [1, 3, 5, 7]:
          raise ValueError(
              'first_temporal_kernel_size can only be 1, 3, 5 or 7.')
        # Separable conv is slow when used at first conv layer.
        net = conv3d_spatiotemporal(
            inputs,
            depth(64), [first_temporal_kernel_size, 7, 7],
            stride=2,
            separable=False,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        # batch_size x 32 x 56 x 56 x 64
        end_point = 'MaxPool_2a_3x3'
        net = slim.max_pool3d(net, [1, 3, 3], stride=[1, 2, 2], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        # batch_size x 32 x 56 x 56 x 64
        end_point = 'Conv2d_2b_1x1'
        net = slim.conv3d(net, depth(64), [1, 1, 1], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        # batch_size x 32 x 56 x 56 x 192
        end_point = 'Conv2d_2c_3x3'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = conv3d_spatiotemporal(net, depth(192), [t, 3, 3], scope=end_point)
        if use_gating:
          net = self_gating(net, scope=end_point, data_format=data_format)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        # batch_size x 32 x 28 x 28 x 192
        end_point = 'MaxPool_3a_3x3'
        net = slim.max_pool3d(net, [1, 3, 3], stride=[1, 2, 2], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 32 x 28 x 28 x 256
        end_point = 'Mixed_3b'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(64),
            num_outputs_1_0a=depth(96),
            num_outputs_1_0b=depth(128),
            num_outputs_2_0a=depth(16),
            num_outputs_2_0b=depth(32),
            num_outputs_3_0b=depth(32),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_3c'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(128),
            num_outputs_1_0a=depth(128),
            num_outputs_1_0b=depth(192),
            num_outputs_2_0a=depth(32),
            num_outputs_2_0b=depth(96),
            num_outputs_3_0b=depth(64),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'MaxPool_4a_3x3'
        net = slim.max_pool3d(net, [3, 3, 3], stride=[2, 2, 2], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 16 x 14 x 14 x 512
        end_point = 'Mixed_4b'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(192),
            num_outputs_1_0a=depth(96),
            num_outputs_1_0b=depth(208),
            num_outputs_2_0a=depth(16),
            num_outputs_2_0b=depth(48),
            num_outputs_3_0b=depth(64),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 16 x 14 x 14 x 512
        end_point = 'Mixed_4c'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(160),
            num_outputs_1_0a=depth(112),
            num_outputs_1_0b=depth(224),
            num_outputs_2_0a=depth(24),
            num_outputs_2_0b=depth(64),
            num_outputs_3_0b=depth(64),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 16 x 14 x 14 x 512
        end_point = 'Mixed_4d'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(128),
            num_outputs_1_0a=depth(128),
            num_outputs_1_0b=depth(256),
            num_outputs_2_0a=depth(24),
            num_outputs_2_0b=depth(64),
            num_outputs_3_0b=depth(64),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 16 x 14 x 14 x 528
        end_point = 'Mixed_4e'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(112),
            num_outputs_1_0a=depth(144),
            num_outputs_1_0b=depth(288),
            num_outputs_2_0a=depth(32),
            num_outputs_2_0b=depth(64),
            num_outputs_3_0b=depth(64),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 16 x 14 x 14 x 832
        end_point = 'Mixed_4f'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(256),
            num_outputs_1_0a=depth(160),
            num_outputs_1_0b=depth(320),
            num_outputs_2_0a=depth(32),
            num_outputs_2_0b=depth(128),
            num_outputs_3_0b=depth(128),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'MaxPool_5a_2x2'
        net = slim.max_pool3d(net, [2, 2, 2], stride=[2, 2, 2], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 8 x 7 x 7 x 832
        end_point = 'Mixed_5b'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(256),
            num_outputs_1_0a=depth(160),
            num_outputs_1_0b=depth(320),
            num_outputs_2_0a=depth(32),
            num_outputs_2_0b=depth(128),
            num_outputs_3_0b=depth(128),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        # batch_size x 8 x 7 x 7 x 1024
        end_point = 'Mixed_5c'
        if temporal_conv_startat == end_point:
          t = 3
        if gating_startat == end_point:
          use_gating = True
          self_gating_fn = gating_fn
        net = inception_block_v1_3d(
            net,
            num_outputs_0_0a=depth(384),
            num_outputs_1_0a=depth(192),
            num_outputs_1_0b=depth(384),
            num_outputs_2_0a=depth(48),
            num_outputs_2_0b=depth(128),
            num_outputs_3_0b=depth(128),
            temporal_kernel_size=t,
            self_gating_fn=self_gating_fn,
            data_format=data_format,
            scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
    raise ValueError('Unknown final endpoint %s' % final_endpoint)


def s3dg(inputs,
         num_classes=1000,
         first_temporal_kernel_size=3,
         temporal_conv_startat='Conv2d_2c_3x3',
         gating_startat='Conv2d_2c_3x3',
         final_endpoint='Mixed_5c',
         min_depth=16,
         depth_multiplier=1.0,
         dropout_keep_prob=0.8,
         is_training=True,
         prediction_fn=slim.softmax,
         spatial_squeeze=True,
         reuse=None,
         data_format='NDHWC',
         scope='InceptionV1'):
  """Defines the S3D-G architecture.

  The default image size used to train this network is 224x224.

  Args:
    inputs: A 5-D float tensor of size [batch_size, num_frames, height, width,
      channels].
    num_classes: number of predicted classes.
    first_temporal_kernel_size: Specifies the temporal kernel size for the first
      conv3d filter. A larger value slows down the model but provides little
      accuracy improvement. Must be set to one of 1, 3, 5 or 7.
    temporal_conv_startat: Specifies the first conv block to use separable 3D
      convs rather than 2D convs (implemented as [1, k, k] 3D conv). This is
      used to construct the inverted pyramid models. 'Conv2d_2c_3x3' is the
      first valid block to use separable 3D convs. If provided block name is
      not present, all valid blocks will use separable 3D convs.
    gating_startat: Specifies the first conv block to use self gating.
      'Conv2d_2c_3x3' is the first valid block to use self gating. If provided
      block name is not present, all valid blocks will use separable 3D convs.
    final_endpoint: Specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
      'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
      'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    dropout_keep_prob: the percentage of activation values that are retained.
    is_training: whether is training or not.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
      The data format of the input and output data. With the default format
      "NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
      in_width, in_channels]. Alternatively, the format could be "NCDHW", the
      data storage order is:
      [batch, in_channels, in_depth, in_height, in_width].
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes]
    end_points: a dictionary from components of the network to the corresponding
      activation.
  """
  assert data_format in ['NDHWC', 'NCDHW']
  # Final pooling and prediction
  with tf.variable_scope(
      scope, 'InceptionV1', [inputs, num_classes], reuse=reuse) as scope:
    with arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
      net, end_points = s3dg_base(
          inputs,
          first_temporal_kernel_size=first_temporal_kernel_size,
          temporal_conv_startat=temporal_conv_startat,
          gating_startat=gating_startat,
          final_endpoint=final_endpoint,
          min_depth=min_depth,
          depth_multiplier=depth_multiplier,
          data_format=data_format,
          scope=scope)
      with tf.variable_scope('Logits'):
        if data_format.startswith('NC'):
          net = tf.transpose(a=net, perm=[0, 2, 3, 4, 1])
        kernel_size = i3d_utils.reduced_kernel_size_3d(net, [2, 7, 7])
        net = slim.avg_pool3d(
            net,
            kernel_size,
            stride=1,
            data_format='NDHWC',
            scope='AvgPool_0a_7x7')
        net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b')
        logits = slim.conv3d(
            net,
            num_classes, [1, 1, 1],
            activation_fn=None,
            normalizer_fn=None,
            data_format='NDHWC',
            scope='Conv2d_0c_1x1')
        # Temporal average pooling.
        logits = tf.reduce_mean(input_tensor=logits, axis=1)
        if spatial_squeeze:
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')

        end_points['Logits'] = logits
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points


s3dg.default_image_size = 224