tensorflow/models
official/projects/yt8m/modeling/backbones/dbof.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dbof model definitions."""

import functools
from typing import Any, Optional

import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import nn_layers
from official.projects.yt8m.modeling import yt8m_model_utils
from official.vision.configs import common
from official.vision.modeling.backbones import factory


layers = tf_keras.layers


class Dbof(layers.Layer):
  """A YT8M model class builder.

  Creates a Deep Bag of Frames model.
  The model projects the features for each frame into a higher dimensional
  'clustering' space, pools across frames in that space, and then
  uses a configurable video-level model to classify the now aggregated features.
  The model will randomly sample either frames or sequences of frames during
  training to speed up convergence.
  """

  def __init__(
      self,
      input_specs: layers.InputSpec = layers.InputSpec(
          shape=[None, None, 1152]
      ),
      params: yt8m_cfg.DbofModel = yt8m_cfg.DbofModel(),
      norm_activation: common.NormActivation = common.NormActivation(),
      l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs,
  ):
    """YT8M initialization function.

    Args:
      input_specs: `tf_keras.layers.InputSpec` specs of the input tensor.
        [batch_size x num_frames x num_features].
      params: model configuration parameters.
      norm_activation: Model normalization and activation configs.
      l2_regularizer: An optional kernel weight regularizer.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._input_specs = input_specs
    self._params = params
    self._norm_activation = norm_activation
    self._l2_regularizer = l2_regularizer
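    # Resolve the activation function and pre-bind the batch-norm
    # hyperparameters so that each BatchNormalization layer below only needs a
    # distinct name.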
    self._act_fn = tf_utils.get_activation(self._norm_activation.activation)
    self._norm = functools.partial(
        layers.BatchNormalization,
        momentum=self._norm_activation.norm_momentum,
        epsilon=self._norm_activation.norm_epsilon,
        synchronized=self._norm_activation.use_sync_bn,
    )
    feature_size = input_specs.shape[-1]

    # Configure model batch norm layer.
    if self._params.add_batch_norm:
      self._input_bn = self._norm(name="input_bn")
      self._cluster_bn = self._norm(name="cluster_bn")
      self._hidden_bn = self._norm(name="hidden_bn")
    else:
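      # Without batch norm, learnable bias vectors are added after the cluster
      # and hidden projections instead (see `call`).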
      self._hidden_biases = self.add_weight(
          name="hidden_biases",
          shape=[self._params.hidden_size],
          initializer=tf.random_normal_initializer(stddev=0.01),
      )
      self._cluster_biases = self.add_weight(
          name="cluster_biases",
          shape=[self._params.cluster_size],
          initializer=tf.random_normal_initializer(
              stddev=1.0 / tf.sqrt(tf.cast(feature_size, tf.float32))
          ),
      )

    if self._params.use_context_gate_cluster_layer:
      self._context_gate = nn_layers.ContextGate(
          normalizer_fn=self._norm,
          pooling_method=None,
          hidden_layer_size=self._params.context_gate_cluster_bottleneck_size,
          kernel_regularizer=self._l2_regularizer,
          name="context_gate_cluster",
      )

    self._hidden_dense = layers.Dense(
        self._params.hidden_size,
        kernel_regularizer=self._l2_regularizer,
        kernel_initializer=tf.random_normal_initializer(
            stddev=1.0 / tf.sqrt(tf.cast(self._params.cluster_size, tf.float32))
        ),
        name="hidden_dense",
    )

    if self._params.cluster_size > 0:
      self._cluster_dense = layers.Dense(
          self._params.cluster_size,
          kernel_regularizer=self._l2_regularizer,
          kernel_initializer=tf.random_normal_initializer(
              stddev=1.0 / tf.sqrt(tf.cast(feature_size, tf.float32))
          ),
          name="cluster_dense",
      )

  def call(
      self, inputs: tf.Tensor, num_frames: Any = None,
  ) -> tf.Tensor:
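    """Computes the DBoF representation of a batch of frame features.

    Args:
      inputs: A `[batch_size, num_frames, num_features]` tensor of frame-level
        features.
      num_frames: An optional number of valid frames per example, forwarded to
        the frame pooling op.

    Returns:
      A `[batch_size, hidden_size]` tensor of pooled video-level features.
    """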
    # L2 normalize input features
    activation = tf.nn.l2_normalize(inputs, -1)

    if self._params.add_batch_norm:
      activation = self._input_bn(activation)

    if self._params.cluster_size > 0:
      activation = self._cluster_dense(activation)
      if self._params.add_batch_norm:
        activation = self._cluster_bn(activation)
      else:
        activation += self._cluster_biases

    activation = self._act_fn(activation)

    if self._params.use_context_gate_cluster_layer:
      activation = self._context_gate(activation)

    activation = yt8m_model_utils.frame_pooling(
        activation,
        method=self._params.pooling_method,
        num_frames=num_frames,
    )

    activation = self._hidden_dense(activation)
    if self._params.add_batch_norm:
      activation = self._hidden_bn(activation)
    else:
      activation += self._hidden_biases

    activation = self._act_fn(activation)
    return activation


@factory.register_backbone_builder("dbof")
def build_dbof(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    **kwargs,
) -> tf_keras.Model:
  """Builds a dbof backbone from a config."""
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == "dbof", f"Inconsistent backbone type {backbone_type}"

  dbof = Dbof(
      input_specs=input_specs,
      params=backbone_cfg,
      norm_activation=norm_activation_config,
      l2_regularizer=l2_regularizer,
      **kwargs,
  )

  # Warmup calls to build model variables.
  dbof(tf_keras.Input(input_specs.shape[1:]))
  return dbof