tensorflow/models

View on GitHub
official/projects/videoglue/tasks/multihead_video_classification.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HS Video Classification task."""
from typing import Any, List, Optional, Mapping

import tensorflow as tf, tf_keras

from official.core import task_factory
from official.projects.videoglue.configs import video_classification as exp_cfg
from official.projects.videoglue.datasets import dataset_factory
from official.projects.videoglue.tools import checkpoint_loader
from official.vision.tasks import video_classification


@task_factory.register_task_cls(exp_cfg.MultiHeadVideoClassificationTask)
class MultiHeadVideoClassificationTask(
    video_classification.VideoClassificationTask):
  """Internal video classification task."""

  def _is_multihead(self):
    """Tells whether the task is configured with more than one label head.

    Returns:
      True only when `task_config.train_data.label_names` is a `list` holding
      at least two entries; False for any other value.
    """
    names = self.task_config.train_data.label_names
    if not isinstance(names, list):
      return False
    return len(names) > 1

  def _get_label_names(self):
    """Gets the configured label names as a flat list.

    Reads `task_config.train_data.label_names`. A list value is returned as
    is, regardless of its length, so a single-element list such as `['label']`
    is not wrapped a second time into `[['label']]`. Any other value is
    treated as one label name and wrapped in a one-element list.

    Returns:
      A list of label names, one per classification head.
    """
    label_names = self.task_config.train_data.label_names
    if isinstance(label_names, list):
      return label_names
    return [label_names]

  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds the classification input pipeline.

    Args:
      params: the data config the loader settings are read from.
      input_context: forwarded unchanged to the data loader call.

    Returns:
      Whatever the `dataset_factory.DataLoader` instance returns when called
      with `input_context`.
    """

    def _as_dict_or_none(sub_config):
      # Optional augmentation sub-configs are forwarded as plain dicts, or as
      # None when they are not set.
      return None if sub_config is None else sub_config.as_dict()

    augmentation_type = params.data_augmentation.type
    augmentation_params = params.data_augmentation.get().as_dict()
    randaug_params = _as_dict_or_none(params.randaug)
    autoaug_params = _as_dict_or_none(params.autoaug)
    mixup_cutmix_params = _as_dict_or_none(params.mixup_cutmix)

    # feature_shape[0] is used as the frame count and feature_shape[1] as the
    # crop size.
    frame_count = params.feature_shape[0]
    crop_size = params.feature_shape[1]

    dataset_config = dict(
        is_training=params.is_training,
        num_frames=frame_count,
        temporal_stride=params.temporal_stride,
        sample_from_segments=params.sample_from_segments,
        min_resize=params.min_resize,
        crop_size=crop_size,
        zero_centering_image=params.zero_centering_image,
        random_flip_image=params.random_flip_image,
        num_test_clips=params.num_test_clips,
        augmentation_type=augmentation_type,
        augmentation_params=augmentation_params,
        randaug_params=randaug_params,
        autoaug_params=autoaug_params,
        mixup_cutmix_params=mixup_cutmix_params,
        # TODO(lzyuan): Unify num_test_crops and multi_crop flags.
        multi_crop=params.num_test_crops == 3,
    )
    loader = dataset_factory.DataLoader(
        params=params, dataset_config=dataset_config)
    return loader(input_context=input_context)

  def initialize(self, model: tf_keras.Model):
    """Loads a pretrained checkpoint into the model, if one is configured.

    Does nothing when `task_config.init_checkpoint` is empty. Otherwise the
    work is delegated to `checkpoint_loader.get_checkpoint_loader`, with
    `task_config.init_checkpoint_modules` passed as the checkpoint type.

    Args:
      model: the keras model handed to the checkpoint loader.
    """
    checkpoint_path = self.task_config.init_checkpoint
    if checkpoint_path:
      checkpoint_loader.get_checkpoint_loader(
          model=model,
          init_checkpoint=checkpoint_path,
          init_checkpoint_type=self.task_config.init_checkpoint_modules)

  def build_metrics(self, training: bool = True):
    """Creates the streaming metrics for training/validation.

    Every label name gets three metrics named `<label_name>/<metric>`:
    ROC-AUC, recall at 95% precision and PR-AUC for multilabel tasks, or
    accuracy, top-1 and top-5 accuracy otherwise. Multihead tasks get one
    extra `Mean` metric named `label_joint/accuracy`.

    Args:
      training: not used by this implementation.

    Returns:
      A flat list of keras metric objects.
    """
    collected = []
    for head_name in self._get_label_names():
      multilabel = self._is_multilabel()
      if not multilabel:
        collected.extend([
            tf_keras.metrics.CategoricalAccuracy(
                name=f'{head_name}/accuracy'),
            tf_keras.metrics.TopKCategoricalAccuracy(
                k=1, name=f'{head_name}/top_1_accuracy'),
            tf_keras.metrics.TopKCategoricalAccuracy(
                k=5, name=f'{head_name}/top_5_accuracy'),
        ])
        continue

      collected.extend([
          tf_keras.metrics.AUC(
              curve='ROC',
              multi_label=multilabel,
              name=f'{head_name}/ROC-AUC'),
          tf_keras.metrics.RecallAtPrecision(
              precision=0.95, name=f'{head_name}/RecallAtPrecision95'),
          tf_keras.metrics.AUC(
              curve='PR',
              multi_label=multilabel,
              name=f'{head_name}/PR-AUC'),
      ])

    if self._is_multihead():
      # Tracks the running mean of the per-batch joint accuracy.
      collected.append(tf_keras.metrics.Mean(name='label_joint/accuracy'))
    return collected

  def process_metrics(self, metrics: List[Any],
                      labels: List[tf.Tensor],
                      model_outputs: List[tf.Tensor]):
    """Processes and updates metrics.

    Called when using custom training loop API.

    Per-head metrics are selected by their exact `<label_name>/` name prefix
    rather than by substring, so a head called `label` does not also update
    `label_2/...` or `label_joint/accuracy`. The joint accuracy metric is
    selected by its exact name for the same reason.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a nested structure of tensors contains labels.
      model_outputs: a list of output tensors. Assume the order is aligned with
        the label_names list.
    """
    # Must stay in sync with the metric name created in build_metrics.
    joint_metric_name = 'label_joint/accuracy'
    is_multihead = self._is_multihead()

    for i, label_name in enumerate(self._get_label_names()):
      for metric in metrics:
        if is_multihead and metric.name == joint_metric_name:
          # The joint metric takes a scalar and is handled separately below.
          continue
        # Metric names have the form '<label_name>/<metric>'; compare the part
        # before the last '/' so that only this head's metrics are updated.
        if metric.name.rpartition('/')[0] == label_name:
          metric.update_state(labels[i], model_outputs[i])

    if not is_multihead:
      return

    def joint_accuracy_fn(y_true: List[tf.Tensor], y_pred: List[tf.Tensor]):
      """Calculates the joint accuracy of predictions."""
      hits = []
      for label, pred in zip(y_true, y_pred):
        label_id = tf.argmax(label, axis=-1)
        pred_id = tf.argmax(pred, axis=-1)
        hits.append(tf.equal(label_id, pred_id))
      # An example counts as a hit only if every head is predicted correctly.
      hits = tf.math.reduce_all(tf.stack(hits, axis=-1), axis=-1)
      return tf.reduce_mean(tf.cast(hits, tf.float32))

    joint_metrics = [m for m in metrics if m.name == joint_metric_name]
    if joint_metrics:
      # Computed once and shared across all matching metric objects.
      values = joint_accuracy_fn(labels, model_outputs)
      for metric in joint_metrics:
        metric.update_state(values)

  def build_losses(self,
                   labels: List[tf.Tensor],
                   model_outputs: List[tf.Tensor],
                   aux_losses: Optional[Any] = None):
    """Builds the total loss by summing the per-head losses.

    Each (output, label) pair is passed to the parent `build_losses`, and the
    entry keyed by `self.loss` in its result is added to the running total.

    Args:
      labels: list of label tensors, one per head.
      model_outputs: list of output tensors, paired with `labels` by position.
      aux_losses: forwarded unchanged to the parent for every head.

    Returns:
      The sum of the per-head losses, or the integer 0 when there are no
      heads.
    """
    # NOTE(review): aux_losses is forwarded on every iteration; if the parent
    # adds it to its total, it is counted once per head -- confirm intended.
    per_head_losses = super().build_losses
    total = 0
    for head_output, head_label in zip(model_outputs, labels):
      head_losses = per_head_losses(
          model_outputs=head_output,
          labels=head_label,
          aux_losses=aux_losses)
      total += head_losses[self.loss]
    return total

  def train_step(self,
                 inputs: Mapping[str, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Runs one forward and backward pass and applies the gradients.

    Args:
      inputs: a dictionary of input tensors; the 'image' entry is fed to the
        model and the entries named by the label names are used as labels.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs holding the unscaled total loss under `self.loss`
      and, when metrics are given, each metric result under its name.
    """
    images = inputs['image']
    targets = [inputs[name] for name in self._get_label_names()]

    replica_count = tf.distribute.get_strategy().num_replicas_in_sync
    uses_loss_scaling = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      # A single-output tf_keras.Model returns a bare tensor instead of a
      # list; wrap it so the multihead code path below applies uniformly.
      if isinstance(predictions, tf.Tensor):
        predictions = [predictions]
      # Under mixed_float16 / mixed_bfloat16 the outputs must be cast to
      # float32 before the activation and the loss are computed.
      predictions = tf.nest.map_structure(
          lambda head: tf.cast(head, tf.float32), predictions)

      activation = (
          tf.math.sigmoid if self._is_multilabel() else tf.math.softmax)
      predictions = tf.nest.map_structure(activation, predictions)

      total_loss = self.build_losses(
          model_outputs=predictions,
          labels=targets,
          aux_losses=model.losses)

      # The default gradient allreduce sums across replicas inside the
      # optimizer, so the loss is divided by the replica count first.
      loss_for_grads = total_loss / replica_count

      # A LossScaleOptimizer scales the loss up for numerical stability.
      if uses_loss_scaling:
        loss_for_grads = optimizer.get_scaled_loss(loss_for_grads)

    variables = model.trainable_variables
    gradients = tape.gradient(loss_for_grads, variables)
    # Undo the loss scaling on the gradients before they are applied.
    if uses_loss_scaling:
      gradients = optimizer.get_unscaled_gradients(gradients)
    optimizer.apply_gradients(list(zip(gradients, variables)))

    logs = {self.loss: total_loss}
    if metrics:
      self.process_metrics(metrics, targets, predictions)
      logs.update({metric.name: metric.result() for metric in metrics})
    return logs

  def validation_step(self,
                      inputs: Mapping[str, tf.Tensor],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors; the 'image' entry is fed to the
        model and the entries named by the label names are used as labels.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs holding the total loss under `self.loss` and, when
      metrics are given, each metric result under its name.
    """
    features = inputs['image']
    labels = [inputs[k] for k in self._get_label_names()]

    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      # `features` is already the image tensor (not the input dictionary), so
      # it is split directly instead of being indexed by 'image' again.
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    # NOTE(review): unlike train_step, no sigmoid/softmax is applied here;
    # presumably inference_step (defined outside this class) does it -- confirm.
    outputs = self.inference_step(features, model)
    # tf_keras.Model eliminates the list if the outputs list len is 1.
    # Recover it here to be compatible with multihead settings.
    outputs = [outputs] if isinstance(outputs, tf.Tensor) else outputs
    # Casting output layer as float32 is necessary when mixed_precision is
    # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    all_losses = self.build_losses(model_outputs=outputs, labels=labels,
                                   aux_losses=model.losses)

    logs = {self.loss: all_losses}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    return logs