tensorflow/models
official/projects/panoptic/tasks/panoptic_deeplab.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Panoptic Deeplab task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.panoptic.configs import panoptic_deeplab as exp_cfg
from official.projects.panoptic.dataloaders import panoptic_deeplab_input
from official.projects.panoptic.losses import panoptic_deeplab_losses
from official.projects.panoptic.modeling import factory
from official.vision.dataloaders import input_reader_factory
from official.vision.evaluation import panoptic_quality_evaluator
from official.vision.evaluation import segmentation_metrics


@task_factory.register_task_cls(exp_cfg.PanopticDeeplabTask)
class PanopticDeeplabTask(base_task.Task):
  """A task for Panoptic Deeplab."""

  def build_model(self):
    """Builds panoptic deeplab model."""
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_panoptic_deeplab(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    # Builds the model through warm-up call.
    dummy_images = tf_keras.Input(self.task_config.model.input_size)
    # Note that image_info is always in the shape of [4, 2].
    dummy_image_info = tf_keras.layers.Input([4, 2])
    _ = model(dummy_images, dummy_image_info, training=False)
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
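      # Restore only the requested modules. When the decoders are shared,
      # restoring the semantic decoder also covers the instance decoder.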
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(semantic_decoder=model.semantic_decoder)
        if not self.task_config.model.shared_decoder:
          ckpt_items.update(instance_decoder=model.instance_decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds panoptic deeplab input."""
    decoder_cfg = params.decoder.get()

    if params.decoder.type == 'simple_decoder':
      decoder = panoptic_deeplab_input.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
          panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    parser = panoptic_deeplab_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        ignore_label=params.parser.ignore_label,
        resize_eval_groundtruth=params.parser.resize_eval_groundtruth,
        groundtruth_padded_size=params.parser.groundtruth_padded_size,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_type=params.parser.aug_type,
        sigma=params.parser.sigma,
        dtype=params.parser.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: Mapping[str, tf.Tensor],
                   model_outputs: Mapping[str, tf.Tensor],
                   aux_losses: Optional[Any] = None):
    """Panoptic deeplab losses.

    Args:
      labels: a dictionary of ground-truth tensors, e.g. `category_mask`,
        `instance_centers_heatmap`, `instance_centers_offset` and their
        associated weight masks.
      model_outputs: Output logits from panoptic deeplab.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    loss_config = self._task_config.losses
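    # Panoptic Deeplab trains three heads jointly: semantic segmentation,
    # instance center heatmap prediction, and per-pixel regression of offsets
    # to the nearest instance center.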
    segmentation_loss_fn = (
        panoptic_deeplab_losses.WeightedBootstrappedCrossEntropyLoss(
            loss_config.label_smoothing,
            loss_config.class_weights,
            loss_config.ignore_label,
            top_k_percent_pixels=loss_config.top_k_percent_pixels))
    instance_center_heatmap_loss_fn = (
        panoptic_deeplab_losses.CenterHeatmapLoss())
    instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss()

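    # Cast the per-pixel weights/masks to the output dtype so they can be
    # passed as `sample_weight` to the loss functions.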
    semantic_weights = tf.cast(
        labels['semantic_weights'],
        dtype=model_outputs['instance_centers_heatmap'].dtype)
    things_mask = tf.cast(
        tf.squeeze(labels['things_mask'], axis=3),
        dtype=model_outputs['instance_centers_heatmap'].dtype)
    valid_mask = tf.cast(
        tf.squeeze(labels['valid_mask'], axis=3),
        dtype=model_outputs['instance_centers_heatmap'].dtype)

    segmentation_loss = segmentation_loss_fn(
        model_outputs['segmentation_outputs'],
        labels['category_mask'],
        sample_weight=semantic_weights)
    instance_center_heatmap_loss = instance_center_heatmap_loss_fn(
        model_outputs['instance_centers_heatmap'],
        labels['instance_centers_heatmap'],
        sample_weight=valid_mask)
    instance_center_offset_loss = instance_center_offset_loss_fn(
        model_outputs['instance_centers_offset'],
        labels['instance_centers_offset'],
        sample_weight=things_mask)

    model_loss = (
        loss_config.segmentation_loss_weight * segmentation_loss +
        loss_config.center_heatmap_loss_weight * instance_center_heatmap_loss +
        loss_config.center_offset_loss_weight * instance_center_offset_loss)

    total_loss = model_loss
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    losses = {
        'total_loss': total_loss,
        'model_loss': model_loss,
        'segmentation_loss': segmentation_loss,
        'instance_center_heatmap_loss': instance_center_heatmap_loss,
        'instance_center_offset_loss': instance_center_offset_loss
    }

    return losses

  def build_metrics(self, training: bool = True) -> List[
      tf_keras.metrics.Metric]:
    """Build metrics."""
    eval_config = self.task_config.evaluation
    metrics = []
    if training:
      metric_names = [
          'total_loss',
          'segmentation_loss',
          'instance_center_heatmap_loss',
          'instance_center_offset_loss',
          'model_loss']
      for name in metric_names:
        metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))

      if eval_config.report_train_mean_iou:
        self.train_mean_iou = segmentation_metrics.MeanIoU(
            name='train_mean_iou',
            num_classes=self.task_config.model.num_classes,
            rescale_predictions=False,
            dtype=tf.float32)
    else:
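      # If eval ground truths are not resized to the model output size,
      # rescale predictions back to the original image size before computing
      # IoU.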
      rescale_predictions = (not self.task_config.validation_data.parser
                             .resize_eval_groundtruth)
      self.perclass_iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=rescale_predictions,
          dtype=tf.float32)

      if self.task_config.model.generate_panoptic_masks:
        self.panoptic_quality_metric = (
            panoptic_quality_evaluator.PanopticQualityEvaluator(
                num_categories=self.task_config.model.num_classes,
                ignored_label=eval_config.ignored_label,
                max_instances_per_category=(
                    eval_config.max_instances_per_category),
                offset=eval_config.offset,
                is_thing=eval_config.is_thing,
                rescale_predictions=eval_config.rescale_predictions))

    return metrics

  def train_step(
      self,
      inputs: Tuple[Any, Any],
      model: tf_keras.Model,
      optimizer: tf_keras.optimizers.Optimizer,
      metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Does forward and backward.

    Args:
      inputs: a tuple of (images, labels), where labels is a dictionary of
        ground-truth tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
      outputs = model(
          inputs=images,
          image_info=labels['image_info'],
          training=True)
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          labels=labels,
          model_outputs=outputs,
          aux_losses=model.losses)
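      # Divide by the number of replicas so that summing gradients across
      # replicas yields the gradient of the global-batch loss.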
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])

    if self.task_config.evaluation.report_train_mean_iou:
      segmentation_labels = {
          'masks': labels['category_mask'],
          'valid_masks': labels['valid_mask'],
          'image_info': labels['image_info']
      }
      self.process_metrics(
          metrics=[self.train_mean_iou],
          labels=segmentation_labels,
          model_outputs=outputs['segmentation_outputs'])
      logs.update({
          self.train_mean_iou.name:
              self.train_mean_iou.result()
      })

    return logs

  def validation_step(
      self,
      inputs: Tuple[Any, Any],
      model: tf_keras.Model,
      metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Validation step.

    Args:
      inputs: a tuple of (images, labels), where labels is a dictionary of
        ground-truth tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs

    outputs = model(
        inputs=images,
        image_info=labels['image_info'],
        training=False)

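    # Losses are not computed during evaluation; log a placeholder of 0.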
    logs = {self.loss: 0}
    segmentation_labels = {
        'masks': labels['category_mask'],
        'valid_masks': labels['valid_mask'],
        'image_info': labels['image_info']
    }

    self.perclass_iou_metric.update_state(segmentation_labels,
                                          outputs['segmentation_outputs'])

    if self.task_config.model.generate_panoptic_masks:
      pq_metric_labels = {
          'category_mask': tf.squeeze(labels['category_mask'], axis=3),
          'instance_mask': tf.squeeze(labels['instance_mask'], axis=3),
          'image_info': labels['image_info']
      }
      panoptic_outputs = {
          'category_mask': outputs['category_mask'],
          'instance_mask': outputs['instance_mask'],
      }
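      # Pass labels and predictions through the logs; the panoptic quality
      # metric state is accumulated in `aggregate_logs`.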
      logs.update({
          self.panoptic_quality_metric.name:
              (pq_metric_labels, panoptic_outputs)})
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
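    """Aggregates per-step evaluation outputs into the metric states."""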
    if state is None:
      self.perclass_iou_metric.reset_states()
      state = [self.perclass_iou_metric]
      if self.task_config.model.generate_panoptic_masks:
        state += [self.panoptic_quality_metric]

    if self.task_config.model.generate_panoptic_masks:
      self.panoptic_quality_metric.update_state(
          step_outputs[self.panoptic_quality_metric.name][0],
          step_outputs[self.panoptic_quality_metric.name][1])

    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
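    """Reduces aggregated metric states into a flat dictionary of results."""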
    result = {}
    ious = self.perclass_iou_metric.result()
    if self.task_config.evaluation.report_per_class_iou:
      for i, value in enumerate(ious.numpy()):
        result.update({'segmentation_iou/class_{}'.format(i): value})

    # Computes mean IoU
    result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})

    if self.task_config.model.generate_panoptic_masks:
      panoptic_quality_results = self.panoptic_quality_metric.result()
      for k, value in panoptic_quality_results.items():
        if k.endswith('per_class'):
          if self.task_config.evaluation.report_per_class_pq:
            for i, per_class_value in enumerate(value):
              metric_key = 'panoptic_quality/{}/class_{}'.format(k, i)
              result[metric_key] = per_class_value
        else:
          result['panoptic_quality/{}'.format(k)] = value

    return result
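

# Minimal usage sketch, assuming an experiment config whose fields (model
# input size, data paths, etc.) have been populated; the bare default
# constructor below is a placeholder only.
#
#   config = exp_cfg.PanopticDeeplabTask()
#   task = PanopticDeeplabTask(config)
#   model = task.build_model()
#   task.initialize(model)
#   train_dataset = task.build_inputs(config.train_data)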