official/projects/s3d/modeling/net_utils.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Commonly used TensorFlow 2 network blocks."""
from typing import Any, Text, Sequence, Union
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
# Short aliases for the Keras initializer / regularizer namespaces; they are
# also used in the type annotations of ParameterizedConvLayer.
initializers = tf_keras.initializers
regularizers = tf_keras.regularizers

# Maps a weight-initializer name to a callable that, when called with no
# arguments, builds a new Keras initializer instance.
WEIGHT_INITIALIZER = {
    'Xavier': initializers.GlorotUniform,
    'Gaussian': lambda: initializers.RandomNormal(stddev=0.01),
}
def make_set_from_start_endpoint(start_endpoint: Text,
                                 endpoints: Sequence[Text]):
  """Collects `start_endpoint` and every endpoint listed after it.

  Args:
    start_endpoint: Name of the first endpoint to include.
    endpoints: Ordered endpoint names to select from.

  Returns:
    A set of the names from the first occurrence of `start_endpoint` to the
    end of `endpoints`. If `start_endpoint` is absent, an empty set.
  """
  try:
    first = endpoints.index(start_endpoint)
  except ValueError:
    # Unknown starting endpoint: nothing to select.
    return set()
  return set(endpoints[first:])
def apply_depth_multiplier(d: Union[int, Sequence[Any]],
                           depth_multiplier: float):
  """Scales an int, or every int nested inside sequences, by a multiplier.

  Args:
    d: An int, or an arbitrarily nested sequence whose leaves are ints.
    depth_multiplier: Factor applied to each int; each product is converted
      back with `int()`, which truncates toward zero.

  Returns:
    The scaled int when `d` is an int; otherwise a (nested) list mirroring
    the structure of `d`.
  """
  if not isinstance(d, int):
    # Recurse element-wise; any sequence type is returned as a list.
    return [apply_depth_multiplier(item, depth_multiplier) for item in d]
  return int(d * depth_multiplier)
class ParameterizedConvLayer(tf_keras.layers.Layer):
  """Convolution layer based on the input conv_type.

  Depending on `conv_type`, the layer is a stack of one or more
  `Conv3D -> normalization -> ReLU` units. Kernel shapes are:

    * '3d':     one [k, k, k] convolution.
    * '2d':     one [1, k, k] convolution.
    * '1+2d':   a [k, 1, 1] convolution that keeps the input channel count,
                followed by a [1, k, k] convolution to `filters` channels.
    * '2+1d':   a [1, k, k] convolution followed by a [k, 1, 1] convolution,
                both producing `filters` channels.
    * '1+1+1d': [1, 1, k], then [1, k, 1], then [k, 1, 1] convolutions, all
                producing `filters` channels.

  The [k, 1, 1] convolutions of '1+2d' and '2+1d' use
  `temporal_conv_initializer`; every other convolution uses
  `kernel_initializer`.
  """

  def __init__(
      self,
      conv_type: Text,
      kernel_size: int,
      filters: int,
      strides: Sequence[int],
      rates: Sequence[int],
      use_sync_bn: bool = False,
      norm_momentum: float = 0.999,
      norm_epsilon: float = 0.001,
      temporal_conv_initializer: Union[
          Text, initializers.Initializer] = 'glorot_uniform',
      kernel_initializer: Union[Text,
                                initializers.Initializer] = 'truncated_normal',
      kernel_regularizer: Union[Text, regularizers.Regularizer] = 'l2',
      **kwargs):
    """Initializes the layer.

    Args:
      conv_type: One of '3d', '2d', '1+2d', '2+1d' or '1+1+1d'. Selects how
        the 3D convolution is factorized. Any other value raises
        `ValueError` when the layer is built.
      kernel_size: Kernel extent along each convolved axis.
      filters: Number of output channels.
      strides: Per-axis strides. Passed through unchanged for '3d'. For the
        other conv types, entries [0], [1] and [2] are read.
      rates: Per-axis dilation rates, indexed like `strides`.
      use_sync_bn: If True, use `SyncBatchNormalization`; otherwise use
        `BatchNormalization`.
      norm_momentum: Momentum of the normalization layers.
      norm_epsilon: Epsilon of the normalization layers.
      temporal_conv_initializer: Kernel initializer for the [k, 1, 1]
        convolutions of '1+2d' and '2+1d'.
      kernel_initializer: Kernel initializer for all other convolutions.
      kernel_regularizer: Kernel regularizer applied to every convolution.
      **kwargs: Forwarded to `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self._conv_type = conv_type
    self._kernel_size = kernel_size
    self._filters = filters
    self._strides = strides
    self._rates = rates
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = tf_keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf_keras.layers.BatchNormalization
    # Normalize (and, for '1+2d', read the input width) along the channel
    # axis that matches the backend's image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._channel_axis = -1
    else:
      self._channel_axis = 1
    self._temporal_conv_initializer = temporal_conv_initializer
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer

  def get_config(self):
    """Returns the constructor arguments so the layer can be re-created.

    Initializer and regularizer values are stored exactly as they were
    passed to `__init__` (strings or objects).
    """
    config = {
        'conv_type': self._conv_type,
        'kernel_size': self._kernel_size,
        'filters': self._filters,
        'strides': self._strides,
        'rates': self._rates,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'temporal_conv_initializer': self._temporal_conv_initializer,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _factorized_conv_params(self, axes, filters, initializer):
    """Builds Conv3D kwargs whose kernel spans only the given kernel axes.

    Args:
      axes: Kernel axes (a subset of 0, 1, 2) along which the kernel has
        extent `kernel_size` and uses the configured stride and rate. The
        remaining axes get kernel size, stride and dilation rate 1.
      filters: Number of output channels of this convolution.
      initializer: Kernel initializer. A fresh clone is created for every
        convolution, so no initializer instance is shared between layers.

    Returns:
      A dict of keyword arguments for `tf_keras.layers.Conv3D`.
    """
    return dict(
        filters=filters,
        kernel_size=[self._kernel_size if a in axes else 1 for a in range(3)],
        strides=[self._strides[a] if a in axes else 1 for a in range(3)],
        dilation_rate=[self._rates[a] if a in axes else 1 for a in range(3)],
        kernel_initializer=tf_utils.clone_initializer(initializer),
    )

  def _build_conv_layer_params(self, input_shape):
    """Builds the ordered list of Conv3D kwargs for the configured conv_type.

    Args:
      input_shape: Shape of the layer input. It is only read for '1+2d',
        which takes the input channel count from `self._channel_axis`.

    Returns:
      A list of dicts, one per convolution, in application order.

    Raises:
      ValueError: If `conv_type` is not one of the supported values.
    """
    if self._conv_type == '3d':
      # Full kernel; strides/rates are forwarded as given.
      return [
          dict(
              filters=self._filters,
              kernel_size=[self._kernel_size] * 3,
              strides=self._strides,
              dilation_rate=self._rates,
              kernel_initializer=tf_utils.clone_initializer(
                  self._kernel_initializer),
          )
      ]
    if self._conv_type == '2d':
      return [
          self._factorized_conv_params(
              (1, 2), self._filters, self._kernel_initializer),
      ]
    if self._conv_type == '1+2d':
      # The [k, 1, 1] conv keeps the input channel count; only the following
      # [1, k, k] conv maps to `filters` channels.
      channels_in = input_shape[self._channel_axis]
      return [
          self._factorized_conv_params(
              (0,), channels_in, self._temporal_conv_initializer),
          self._factorized_conv_params(
              (1, 2), self._filters, self._kernel_initializer),
      ]
    if self._conv_type == '2+1d':
      return [
          self._factorized_conv_params(
              (1, 2), self._filters, self._kernel_initializer),
          self._factorized_conv_params(
              (0,), self._filters, self._temporal_conv_initializer),
      ]
    if self._conv_type == '1+1+1d':
      # NOTE(review): unlike '1+2d'/'2+1d', the [k, 1, 1] conv here uses
      # `kernel_initializer` rather than `temporal_conv_initializer`. This is
      # kept as-is to preserve existing behavior; confirm it is intended.
      return [
          self._factorized_conv_params(
              (2,), self._filters, self._kernel_initializer),
          self._factorized_conv_params(
              (1,), self._filters, self._kernel_initializer),
          self._factorized_conv_params(
              (0,), self._filters, self._kernel_initializer),
      ]
    raise ValueError('Unsupported conv_type: {}'.format(self._conv_type))

  def _build_norm_layer_params(self, conv_param):
    """Builds params for the norm layer after one conv layer.

    `conv_param` is not used by this implementation (presumably a hook for
    subclasses that need per-conv settings).
    """
    return dict(
        axis=self._channel_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        scale=False,
        gamma_initializer='ones')

  def _build_activation_layer_params(self, conv_param):
    """Builds params for the activation layer after one conv layer.

    Returns no extra params here. `conv_param` is unused (presumably a hook
    for subclasses).
    """
    return {}

  def _append_conv_layer(self, param):
    """Appends a Conv3D, its normalization layer and a ReLU activation."""
    self._parameterized_conv_layers.append(
        tf_keras.layers.Conv3D(
            padding='same',
            # Bias is omitted because a normalization layer follows.
            use_bias=False,
            kernel_regularizer=self._kernel_regularizer,
            **param,
        ))
    norm_layer_params = self._build_norm_layer_params(param)
    self._parameterized_conv_layers.append(self._norm(**norm_layer_params))
    relu_layer_params = self._build_activation_layer_params(param)
    self._parameterized_conv_layers.append(
        tf_keras.layers.Activation('relu', **relu_layer_params))

  def build(self, input_shape):
    """Creates the conv/norm/ReLU sub-layers for the configured conv_type."""
    self._parameterized_conv_layers = []
    for conv_layer_param in self._build_conv_layer_params(input_shape):
      self._append_conv_layer(conv_layer_param)
    super().build(input_shape)

  def call(self, inputs):
    """Applies the sub-layers sequentially in the order they were built."""
    x = inputs
    for layer in self._parameterized_conv_layers:
      x = layer(x)
    return x