tensorflow/models

View on GitHub
official/projects/maskconver/tasks/multiscale_maskconver.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Panoptic Multi-scale MaskConver task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import task_factory
from official.projects.maskconver.configs import multiscale_maskconver as exp_cfg
from official.projects.maskconver.dataloaders import multiscale_maskconver_input
from official.projects.maskconver.losses import maskconver_losses
from official.projects.maskconver.modeling import factory
from official.projects.maskconver.modeling.layers import copypaste
from official.projects.maskconver.tasks import maskconver
from official.projects.volumetric_models.losses import segmentation_losses as volumeteric_segmentation_losses
from official.vision.dataloaders import input_reader_factory


@task_factory.register_task_cls(exp_cfg.MultiScaleMaskConverTask)
class PanopticMultiScaleMaskConverTask(maskconver.PanopticMaskRCNNTask):
  """Task definition for panoptic multi-scale MaskConver.

  Describes, for this task, how the model is constructed, how the input
  dataset is assembled, how the training losses are computed and how a single
  training step is carried out.
  """

  def build_model(self) -> tf_keras.Model:
    """Creates the multi-scale MaskConver model and calls it once.

    The global random seed is set to 0 and op determinism is enabled before
    the model is constructed. Afterwards the first example of one validation
    batch is passed through the model with `training=False`.

    Returns:
      The model produced by `factory.build_multiscale_maskconver_model`.
    """
    tf_keras.utils.set_random_seed(0)
    tf.config.experimental.enable_op_determinism()

    model_cfg = self.task_config.model
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + model_cfg.input_size)

    # The configured weight decay is halved to match the implementation of
    # tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    weight_decay = self.task_config.losses.l2_weight_decay
    if weight_decay:
      l2_regularizer = tf_keras.regularizers.l2(weight_decay / 2.0)
    else:
      l2_regularizer = None

    model = factory.build_multiscale_maskconver_model(
        input_specs=input_specs,
        model_config=model_cfg,
        l2_regularizer=l2_regularizer)

    def _first_example(tensor):
      # Slice along the leading axis so that a batch of size one remains.
      return tensor[0:1, ...]

    validation_dataset = self.build_inputs(self.task_config.validation_data)
    images, labels = next(iter(validation_dataset))
    images = tf.nest.map_structure(_first_example, images)
    labels = tf.nest.map_structure(_first_example, labels)

    # The outputs of this call are discarded.
    model(images, image_info=labels['image_info'], training=False)
    return model

  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Assembles the input dataset described by `params`.

    Args:
      params: data config supplying the decoder and parser settings, the file
        type, the dtype and the `is_training` flag.
      input_context: optional input context that is forwarded to the reader.

    Returns:
      The dataset produced by the input reader.

    Raises:
      ValueError: if `params.decoder.type` is not 'simple_decoder'.
    """
    decoder_cfg = params.decoder.get()
    if params.decoder.type != 'simple_decoder':
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    decoder = multiscale_maskconver_input.TfExampleDecoder(
        regenerate_source_id=decoder_cfg.regenerate_source_id,
        mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
        include_panoptic_masks=decoder_cfg.include_panoptic_masks,
        panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
        panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)

    model_cfg = self.task_config.model
    parser_cfg = params.parser

    # The copy-paste sampler only exists when the parser config enables it.
    copypaste_layer = None
    copypaste_cfg = parser_cfg.copypaste
    if copypaste_cfg:
      copypaste_layer = copypaste.CopyPaste(
          model_cfg.input_size[:2],
          copypaste_frequency=copypaste_cfg.copypaste_frequency,
          copypaste_aug_scale_max=copypaste_cfg.copypaste_aug_scale_max,
          copypaste_aug_scale_min=copypaste_cfg.copypaste_aug_scale_min,
          aug_scale_min=copypaste_cfg.aug_scale_min,
          aug_scale_max=copypaste_cfg.aug_scale_max,
          random_flip=parser_cfg.aug_rand_hflip,
          num_thing_classes=model_cfg.num_thing_classes)

    # NOTE: the parser config attribute is read as `gaussaian_iou` (sic).
    parser = multiscale_maskconver_input.Parser(
        output_size=model_cfg.input_size[:2],
        min_level=model_cfg.min_level,
        max_level=model_cfg.max_level,
        fpn_low_range=parser_cfg.fpn_low_range,
        fpn_high_range=parser_cfg.fpn_high_range,
        dtype=params.dtype,
        aug_rand_hflip=parser_cfg.aug_rand_hflip,
        aug_scale_min=parser_cfg.aug_scale_min,
        aug_scale_max=parser_cfg.aug_scale_max,
        max_num_instances=parser_cfg.max_num_instances,
        segmentation_resize_eval_groundtruth=(
            parser_cfg.segmentation_resize_eval_groundtruth),
        segmentation_groundtruth_padded_size=(
            parser_cfg.segmentation_groundtruth_padded_size),
        segmentation_ignore_label=parser_cfg.segmentation_ignore_label,
        panoptic_ignore_label=parser_cfg.panoptic_ignore_label,
        num_panoptic_categories=model_cfg.num_classes,
        num_thing_categories=model_cfg.num_thing_classes,
        mask_target_level=parser_cfg.mask_target_level,
        level=model_cfg.level,
        gaussian_iou=parser_cfg.gaussaian_iou,
        aug_type=parser_cfg.aug_type)

    file_dataset_fn = dataset_fn.pick_dataset_fn(params.file_type)
    if copypaste_layer:
      sample_fn = copypaste_layer.copypaste_fn(params.is_training)
    else:
      sample_fn = None

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=file_dataset_fn,
        sample_fn=sample_fn,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   iteration: Any,
                   aux_losses: Optional[Any] = None,
                   step=None) -> Dict[str, tf.Tensor]:
    """Computes the center, mask and dice losses and their weighted total.

    The replica context is fetched and used for an all-reduce, so this is
    presumably meant to be called from within a replica context -- TODO
    confirm against the callers.

    Args:
      outputs: model outputs; the entries 'class_heatmaps' and
        'mask_proposal_logits' are read.
      labels: groundtruth; the entries 'panoptic_heatmaps', 'num_instances',
        'panoptic_masks', 'panoptic_mask_weights' and 'panoptic_padding_mask'
        are read.
      iteration: not used by this method.
      aux_losses: optional collection of additional loss tensors; when
        non-empty their sum is added to the total before the final weighting.
      step: not used by this method.

    Returns:
      A dict with the entries 'total_loss', 'mask_loss', 'center_loss' and
      'dice_loss'.
    """
    # pylint: disable=line-too-long
    loss_cfg = self._task_config.losses

    # --- Center heatmap loss, evaluated in float32. ---
    heatmap_loss_fn = maskconver_losses.PenaltyReducedLogisticFocalLoss(
        alpha=loss_cfg.alpha, beta=loss_cfg.beta)
    heatmap_targets = tf.cast(labels['panoptic_heatmaps'], tf.float32)
    heatmap_predictions = tf.cast(outputs['class_heatmaps'], tf.float32)
    center_loss = heatmap_loss_fn(
        target_tensor=heatmap_targets,
        prediction_tensor=heatmap_predictions,
        weights=1.0)

    # Normalizer: instance count summed across replicas, divided by the number
    # of replicas, plus one.
    summed_instances = tf.distribute.get_replica_context().all_reduce(
        tf.distribute.ReduceOp.SUM, labels['num_instances'])
    replica_count = tf.distribute.get_strategy().num_replicas_in_sync
    instances_per_replica = (
        tf.cast(summed_instances, tf.float32) /
        tf.cast(replica_count, tf.float32))
    num_instances = instances_per_replica + 1.0

    center_loss = tf.reduce_sum(center_loss) / num_instances

    # --- Broadcast the per-mask weights and the padding mask to gt_masks. ---
    gt_masks = labels['panoptic_masks']
    ones = tf.ones_like(gt_masks)
    gt_mask_weights = labels['panoptic_mask_weights'][:, None, None, :] * ones
    panoptic_padding_mask = (
        labels['panoptic_padding_mask'][:, :, :, None] * ones)

    # Static height, width and number of masks of the groundtruth.
    height, width, num_masks = gt_masks.get_shape().as_list()[1:]

    # Bring the predicted logits to the groundtruth resolution.
    mask_logits = tf.cast(outputs['mask_proposal_logits'], tf.float32)
    mask_logits = tf.image.resize(
        mask_logits, tf.shape(gt_masks)[1:3], method='bilinear')

    # --- Per-pixel binary cross-entropy mask loss. ---
    bce_fn = tf_keras.losses.BinaryCrossentropy(
        from_logits=True,
        label_smoothing=0.0,
        axis=-1,
        reduction=tf_keras.losses.Reduction.NONE,
        name='binary_crossentropy')

    not_padded = 1 - panoptic_padding_mask
    non_negative_gt = tf.cast(gt_masks >= 0, tf.float32)
    mask_weights = non_negative_gt * gt_mask_weights * not_padded  # b, h, w, # max inst
    mask_loss = bce_fn(
        tf.expand_dims(gt_masks, -1),
        tf.expand_dims(mask_logits, -1),
        sample_weight=tf.expand_dims(mask_weights, -1))

    # Mean over the pixels of each mask, then sum and normalize.
    mask_loss = tf.reshape(mask_loss, [-1, height * width, num_masks])
    per_mask_loss = tf.reduce_mean(mask_loss, axis=1)
    mask_loss = tf.reduce_sum(per_mask_loss) / num_instances

    # --- Dice loss. ---
    has_weight = tf.cast(gt_mask_weights > 0, tf.float32)
    masked_predictions = tf.sigmoid(mask_logits) * has_weight * not_padded
    masked_gt_masks = gt_masks * has_weight * not_padded

    def _masks_to_leading_axis(tensor):
      # Move the mask axis next to the batch axis and merge the two, leaving a
      # trailing axis of size one.
      channel_first = tf.transpose(tensor, [0, 3, 1, 2])
      return tf.reshape(channel_first, [-1, height, width, 1])

    masked_predictions = _masks_to_leading_axis(masked_predictions)
    masked_gt_masks = _masks_to_leading_axis(masked_gt_masks)

    dice_loss_fn = volumeteric_segmentation_losses.SegmentationLossDiceScore(
        metric_type='adaptive', axis=(2, 3))
    dice_loss = dice_loss_fn(logits=masked_predictions, labels=masked_gt_masks)

    # --- Weighted combination. ---
    weighted_mask_terms = loss_cfg.mask_weight * (mask_loss + dice_loss)
    total_loss = center_loss + weighted_mask_terms
    if aux_losses:
      total_loss = total_loss + tf.add_n(aux_losses)
    total_loss = loss_cfg.loss_weight * total_loss

    return {
        'total_loss': total_loss,
        'mask_loss': mask_loss,
        'center_loss': center_loss,
        'dice_loss': dice_loss,
    }

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Runs one forward and backward pass and applies the gradients.

    Args:
      inputs: an `(images, labels)` pair. `labels` must provide
        'panoptic_box_indices' and 'panoptic_classes' for the forward pass in
        addition to the entries read by `build_losses`.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: optional metric objects; each is updated with the loss entry
        stored under the metric's own `name`.

    Returns:
      A dictionary of logs mapping `self.loss` to the total loss.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    uses_loss_scaling = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)

    def _as_float32(tensor):
      return tf.cast(tensor, tf.float32)

    with tf.GradientTape() as tape:
      outputs = model(
          images,
          box_indices=labels['panoptic_box_indices'],
          classes=labels['panoptic_classes'],
          training=True)
      outputs = tf.nest.map_structure(_as_float32, outputs)

      # Per-replica losses.
      losses = self.build_losses(
          outputs=outputs,
          labels=labels,
          aux_losses=model.losses,
          iteration=optimizer.iterations,
          step=optimizer.iterations)
      scaled_loss = losses['total_loss'] / num_replicas

      # Under a mixed_precision policy with a LossScaleOptimizer the loss is
      # scaled for numerical stability.
      if uses_loss_scaling:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    trainable_vars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, trainable_vars)
    # Undo the loss scaling on the gradients when it was applied above.
    if uses_loss_scaling:
      gradients = optimizer.get_unscaled_gradients(gradients)
    optimizer.apply_gradients(list(zip(gradients, trainable_vars)))

    logs = {self.loss: losses['total_loss']}

    for metric in metrics or []:
      metric.update_state(losses[metric.name])

    return logs