tensorflow/models

View on GitHub
official/vision/tasks/image_classification.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image classification task definition."""
from typing import Any, List, Optional, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.configs import image_classification as exp_cfg
from official.vision.dataloaders import classification_input
from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tfds_factory
from official.vision.modeling import factory
from official.vision.ops import augment


_EPSILON = 1e-6


@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
  """A task for image classification.

  Builds the classification model, its input pipeline, losses and metrics from
  an `exp_cfg.ImageClassificationTask` config, and implements the per-replica
  train, validation and inference steps.
  """

  def build_model(self):
    """Builds classification model.

    Returns:
      The `tf_keras.Model` created by `factory.build_classification_model`,
      already called once on a symbolic input so that its variables exist.
    """
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    if self.task_config.freeze_backbone:
      model.backbone.trainable = False

    # Call the model once on a symbolic input so its variables are created.
    dummy_inputs = tf_keras.Input(self.task_config.model.input_size)
    _ = model(dummy_inputs, training=False)
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `init_checkpoint` is unset. When `init_checkpoint` is a
    directory, the latest checkpoint found in it is restored.

    Args:
      model: The model whose weights (all of them, or only the backbone,
        depending on `init_checkpoint_modules`) are restored.

    Raises:
      ValueError: If `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Select which part of the model the checkpoint is restored into.
    init_modules = self.task_config.init_checkpoint_modules
    if init_modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
    elif init_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    # Restoring checkpoint. `expect_partial` tolerates checkpoint values that
    # are not restored; `assert_existing_objects_matched` still requires every
    # object tracked by `ckpt` to be matched.
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds classification input.

    Args:
      params: The data config (training or validation) describing the input
        source, decoding, augmentation and batching.
      input_context: Optional distribution input context forwarded to the
        input reader.

    Returns:
      The `tf.data.Dataset` produced by the configured input reader.
    """
    num_classes = self.task_config.model.num_classes
    input_size = self.task_config.model.input_size
    # NOTE(review): the field keys and `is_multilabel` are always read from
    # `train_data`, even when `params` is the validation config. This keeps
    # `is_multilabel` consistent with `build_losses`/`build_metrics`, which
    # also read `train_data`; confirm it is intended for the field keys too.
    train_data_config = self.task_config.train_data
    image_field_key = train_data_config.image_field_key
    label_field_key = train_data_config.label_field_key
    is_multilabel = train_data_config.is_multilabel

    if params.tfds_name:
      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
    else:
      decoder = classification_input.Decoder(
          image_field_key=image_field_key, label_field_key=label_field_key,
          is_multilabel=is_multilabel)

    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        image_field_key=image_field_key,
        label_field_key=label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_crop=params.aug_crop,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype,
        center_crop_fraction=params.center_crop_fraction,
        tf_resize_method=params.tf_resize_method,
        three_augment=params.three_augment)

    postprocess_fn = None
    if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=num_classes)

    def sample_fn(repeated_augment, dataset):
      # Randomly interleaves `repeated_augment` copies of `dataset` with equal
      # weights; sampling stops as soon as any copy is exhausted.
      weights = [1 / repeated_augment] * repeated_augment
      dataset = tf.data.Dataset.sample_from_datasets(
          datasets=[dataset] * repeated_augment,
          weights=weights,
          seed=None,
          stop_on_empty_dataset=True,
      )
      return dataset

    # Repeated augmentation only applies to training data.
    is_repeated_augment = (
        params.is_training
        and params.repeated_augment is not None
    )
    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        combine_fn=input_reader.create_combine_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn,
        sample_fn=(lambda ds: sample_fn(params.repeated_augment, ds))
        if is_repeated_augment
        else None,
    )

    dataset = reader.read(input_context=input_context)
    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Builds the classification loss selected by the losses config.

    For single-label tasks the loss is, in priority order: sigmoid (binary)
    cross entropy averaged over classes, categorical cross entropy with label
    smoothing (`one_hot`), softmax cross entropy (`soft_labels`), or sparse
    categorical cross entropy on integer labels. Multi-label tasks use binary
    cross entropy summed over classes.

    Args:
      labels: Input groundtruth labels.
      model_outputs: Output logits of the classifier.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf_keras.Model.

    Returns:
      The total loss tensor: the safe mean of the per-example loss plus the
      sum of `aux_losses`, scaled by `losses.loss_weight`.
    """
    losses_config = self.task_config.losses
    is_multilabel = self.task_config.train_data.is_multilabel

    if not is_multilabel:
      if losses_config.use_binary_cross_entropy:
        total_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=model_outputs
        )
        # Average over all object classes inside an image.
        total_loss = tf.reduce_mean(total_loss, axis=-1)
      elif losses_config.one_hot:
        total_loss = tf_keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing)
      elif losses_config.soft_labels:
        total_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels, model_outputs)
      else:
        total_loss = tf_keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True)
    else:
      # Multi-label binary cross entropy loss. This will apply `reduce_mean`.
      total_loss = tf_keras.losses.binary_crossentropy(
          labels,
          model_outputs,
          from_logits=True,
          label_smoothing=losses_config.label_smoothing,
          axis=-1)
      # Multiply by num_classes to behave like `reduce_sum`.
      total_loss = total_loss * self.task_config.model.num_classes

    total_loss = tf_utils.safe_mean(total_loss)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = losses_config.loss_weight * total_loss
    return total_loss

  def build_metrics(self,
                    training: bool = True) -> List[tf_keras.metrics.Metric]:
    """Gets streaming metrics for training/validation.

    Args:
      training: Whether the metrics are for training. The multi-label PR-AUC
        metrics are only built when this is False.

    Returns:
      A list of `tf_keras.metrics.Metric` instances (possibly empty).
    """
    is_multilabel = self.task_config.train_data.is_multilabel
    if not is_multilabel:
      evaluation_config = self.task_config.evaluation
      k = evaluation_config.top_k
      if (self.task_config.losses.one_hot or
          self.task_config.losses.soft_labels):
        metrics = [
            tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf_keras.metrics.TopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
        thresholds = getattr(
            evaluation_config, 'precision_and_recall_thresholds', None)
        if thresholds:
          # pylint:disable=g-complex-comprehension
          metrics += [
              tf_keras.metrics.Precision(
                  thresholds=th,
                  name='precision_at_threshold_{}'.format(th),
                  top_k=1) for th in thresholds
          ]
          metrics += [
              tf_keras.metrics.Recall(
                  thresholds=th,
                  name='recall_at_threshold_{}'.format(th),
                  top_k=1) for th in thresholds
          ]

          # Add per-class precision and recall.
          if getattr(evaluation_config,
                     'report_per_class_precision_and_recall', False):
            for class_id in range(self.task_config.model.num_classes):
              metrics += [
                  tf_keras.metrics.Precision(
                      thresholds=th,
                      class_id=class_id,
                      name=f'precision_at_threshold_{th}/{class_id}',
                      top_k=1) for th in thresholds
              ]
              metrics += [
                  tf_keras.metrics.Recall(
                      thresholds=th,
                      class_id=class_id,
                      name=f'recall_at_threshold_{th}/{class_id}',
                      top_k=1) for th in thresholds
              ]
          # pylint:enable=g-complex-comprehension
      else:
        metrics = [
            tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf_keras.metrics.SparseTopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
    else:
      metrics = []
      # These metrics destabilize the training if included in training. The
      # jobs fail due to OOM.
      # TODO(arashwan): Investigate adding following metric to train.
      if not training:
        metrics = [
            tf_keras.metrics.AUC(
                name='globalPR-AUC',
                curve='PR',
                multi_label=False,
                from_logits=True),
            tf_keras.metrics.AUC(
                name='meanPR-AUC',
                curve='PR',
                multi_label=True,
                num_labels=self.task_config.model.num_classes,
                from_logits=True),
        ]
    return metrics

  def _update_metrics_from_outputs(self, logs, labels, outputs, model,
                                   metrics):
    """Updates step metrics from labels/outputs and returns `logs`.

    Shared tail of `train_step` and `validation_step`.

    Args:
      logs: The step's log dict; updated in place with compiled-metric
        results when `metrics` is empty and the model has compiled metrics.
      labels: Groundtruth labels, already in the format the metrics expect.
      outputs: Float32 model outputs (logits).
      model: The `tf_keras.Model` whose compiled metrics may be updated.
      metrics: Optional list of metric objects from `build_metrics`.

    Returns:
      The (possibly updated) `logs` dict.
    """
    # Convert logits to softmax for metric computation if needed.
    if getattr(self.task_config.model, 'output_softmax', False):
      outputs = tf.nn.softmax(outputs, axis=-1)
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf_keras.Model instance.
      optimizer: The optimizer for this training step.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_classes = self.task_config.model.num_classes

    is_multilabel = self.task_config.train_data.is_multilabel
    if self.task_config.losses.one_hot and not is_multilabel:
      labels = tf.one_hot(labels, num_classes)

    if self.task_config.losses.use_binary_cross_entropy:
      # BCE loss converts the multiclass classification to multilabel. The
      # corresponding label value of objects present in the image would be one.
      mixup_and_cutmix = self.task_config.train_data.mixup_and_cutmix
      if mixup_and_cutmix is not None:
        # label values below off_value_threshold would be mapped to zero and
        # above that would be mapped to one. Negative labels are guaranteed to
        # have value less than or equal value of the off_value from mixup.
        off_value_threshold = mixup_and_cutmix.label_smoothing / num_classes
        labels = tf.where(
            tf.less(labels, off_value_threshold + _EPSILON), 0.0, 1.0)
      elif labels.shape.rank == 1:
        # Use the static rank. `tf.rank` returns a Tensor, which cannot drive
        # this Python `if` inside a `tf.function`; AutoGraph would try to build
        # a `tf.cond`, which fails because this branch changes the dtype and
        # shape of `labels`.
        labels = tf.one_hot(labels, num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)

      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs,
          labels=labels,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    return self._update_metrics_from_outputs(
        logs, labels, outputs, model, metrics)

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Runs validation step.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf_keras.Model instance.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    one_hot = self.task_config.losses.one_hot
    soft_labels = self.task_config.losses.soft_labels
    is_multilabel = self.task_config.train_data.is_multilabel
    # Note: `soft_labels` only applies to the training phase. In the validation
    # phase, labels should still be integer ids and need to be converted to
    # one hot format.
    if (one_hot or soft_labels) and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs,
        labels=labels,
        aux_losses=model.losses)

    logs = {self.loss: loss}
    return self._update_metrics_from_outputs(
        logs, labels, outputs, model, metrics)

  def inference_step(self, inputs: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step in inference mode (`training=False`)."""
    return model(inputs, training=False)