tensorflow/models
official/projects/volumetric_models/tasks/semantic_segmentation_3d.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image segmentation task definition."""
from typing import Any, Dict, Mapping, Optional, Sequence, Union

from absl import logging
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.projects.volumetric_models.dataloaders import segmentation_input_3d
from official.projects.volumetric_models.evaluation import segmentation_metrics
from official.projects.volumetric_models.losses import segmentation_losses
from official.projects.volumetric_models.modeling import factory


@task_factory.register_task_cls(exp_cfg.SemanticSegmentation3DTask)
class SemanticSegmentation3DTask(base_task.Task):
  """A task for semantic segmentation."""

  def build_model(self) -> tf_keras.Model:
    """Builds segmentation model."""
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size +
        [self.task_config.model.num_channels],
        dtype=self.task_config.train_data.dtype)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf_keras.regularizers.l2(l2_weight_decay /
                                 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    # Create a dummy input and call the model instance to initialize the model.
    # This is needed when launching multiple experiments using the same model
    # directory: if a trained checkpoint already exists there, the forward pass
    # may never run during training and the model would never be built. This is
    # only done when spatial partitioning is not enabled; otherwise it would
    # fail with OOM because the full-resolution input is extremely large.
    if (not self.task_config.train_input_partition_dims) and (
        not self.task_config.eval_input_partition_dims):
      dummy_input = tf.random.uniform(shape=[1] + list(input_specs.shape[1:]))
      _ = model(dummy_input)

    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
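      # Restore only the submodules requested in `init_checkpoint_modules`,
      # e.g. just the backbone when fine-tuning with a freshly initialized
      # decoder and head.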
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
    """Builds classification input."""
    decoder = segmentation_input_3d.Decoder(
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key)
    parser = segmentation_input_3d.Parser(
        input_size=params.input_size,
        num_classes=params.num_classes,
        num_channels=params.num_channels,
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key,
        dtype=params.dtype,
        label_dtype=params.label_dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses=None) -> tf.Tensor:
    """Segmentation loss.

    Args:
      labels: Ground-truth segmentation labels.
      model_outputs: Output logits or probabilities of the segmentation model.
      aux_losses: Auxiliary loss tensors, i.e. `losses` in tf_keras.Model.

    Returns:
      The total loss tensor.
    """
    segmentation_loss_fn = segmentation_losses.SegmentationLossDiceScore(
        metric_type='adaptive')

    total_loss = segmentation_loss_fn(model_outputs, labels)

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self,
                    training: bool = True) -> Sequence[tf_keras.metrics.Metric]:
    """Gets streaming metrics for training/validation."""
    metrics = []
    num_classes = self.task_config.model.num_classes
    if training:
      metrics.extend([
          tf_keras.metrics.CategoricalAccuracy(
              name='train_categorical_accuracy', dtype=tf.float32)
      ])
    else:
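      # Validation metrics are kept on `self.metrics` rather than returned:
      # the dice score is updated on the host in `aggregate_logs` instead of
      # inside the per-replica step (see `validation_step`).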
      self.metrics = [
          segmentation_metrics.DiceScore(
              num_classes=num_classes,
              metric_type='generalized',
              per_class_metric=self.task_config.evaluation
              .report_per_class_metric,
              name='val_generalized_dice',
              dtype=tf.float32)
      ]

    return metrics

  def train_step(
      self,
      inputs,
      model: tf_keras.Model,
      optimizer: tf_keras.optimizers.Optimizer,
      metrics: Optional[Sequence[tf_keras.metrics.Metric]] = None
  ) -> Dict[Any, Any]:
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    input_partition_dims = self.task_config.train_input_partition_dims
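    # With spatial partitioning enabled, split the 3D input volume across
    # logical devices so that no single device runs out of memory on the
    # full-resolution input.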
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Cast model outputs to float32 when a mixed-precision policy
      # (mixed_float16 or mixed_bfloat16) is in use, so the loss and metrics
      # are computed in float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

      outputs = outputs['logits']
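      # If the head emits raw logits, convert them to per-class probabilities
      # so the dice loss below operates on normalized scores.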
      if self.task_config.model.head.output_logits:
        outputs = tf.nn.softmax(outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    # Compute all metrics within strategy scope for training.
    if metrics:
      labels = tf.cast(labels, tf.float32)
      outputs = tf.cast(outputs, tf.float32)
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(
      self,
      inputs,
      model: tf_keras.Model,
      metrics: Optional[Sequence[tf_keras.metrics.Metric]] = None
  ) -> Dict[Any, Any]:
    """Validatation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    outputs = outputs['logits']
    if self.task_config.model.head.output_logits:
      outputs = tf.nn.softmax(outputs)

    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)
    logs = {self.loss: loss}

    # Dice score metrics are computed on CPU: store (labels, logits) pairs
    # here and update the metric state later in `aggregate_logs`.
    for metric in self.metrics:
      labels = tf.cast(labels, tf.float32)
      logits = tf.cast(outputs, tf.float32)
      logs.update({metric.name: (labels, logits)})

    return logs

  def inference_step(self, inputs, model: tf_keras.Model) -> tf.Tensor:
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(
      self,
      state: Optional[Sequence[Union[segmentation_metrics.DiceScore,
                                     tf_keras.metrics.Metric]]] = None,
      step_outputs: Optional[Mapping[str, Any]] = None
  ) -> Sequence[tf_keras.metrics.Metric]:
    """Aggregates statistics to compute metrics over training.

    Args:
      state: A sequence of tf_keras.metrics.Metric objects. Each element records
        a metric.
      step_outputs: A dictionary of [metric_name, (labels, output)] from a step.

    Returns:
      An updated sequence of tf_keras.metrics.Metric objects.
    """
    if state is None:
      for metric in self.metrics:
        metric.reset_states()
      state = self.metrics

    for metric in self.metrics:
      labels = step_outputs[metric.name][0]
      predictions = step_outputs[metric.name][1]

      # Under a distributed strategy, each entry in `step_outputs` is a tuple
      # of per-replica Tensors rather than a single Tensor, so concatenate
      # them along the batch dimension to obtain a single Tensor.
      if isinstance(labels, tuple):
        labels = tf.concat(list(labels), axis=0)
      if isinstance(predictions, tuple):
        predictions = tf.concat(list(predictions), axis=0)

      labels = tf.cast(labels, tf.float32)
      predictions = tf.cast(predictions, tf.float32)
      metric.update_state(labels, predictions)
    return state

  def reduce_aggregated_logs(
      self,
      aggregated_logs: Optional[Mapping[str, Any]] = None,
      global_step: Optional[tf.Tensor] = None) -> Mapping[str, float]:
    """Reduces logs to obtain per-class metrics if needed.

    Args:
      aggregated_logs: An optional dictionary containing aggregated logs.
      global_step: An optional `tf.Tensor` of current global training steps.

    Returns:
      The reduced logs containing per-class metrics and overall metrics.

    Raises:
      ValueError: If `self.metrics` does not contain exactly 1 metric object.
    """
    result = {}
    if len(self.metrics) != 1:
      raise ValueError('Exactly one metric must be present, but {0} are '
                       'present.'.format(len(self.metrics)))

    metric = self.metrics[0].result().numpy()
    if self.task_config.evaluation.report_per_class_metric:
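      # The metric result is a vector: index 0 is the overall dice score and
      # index i (for i >= 1) is the per-class score for class i - 1.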
      for i, metric_val in enumerate(metric):
        metric_name = self.metrics[0].name + '/class_{0}'.format(
            i - 1) if i > 0 else self.metrics[0].name
        result.update({metric_name: metric_val})
    else:
      result.update({self.metrics[0].name: metric})
    return result
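

# A minimal usage sketch (illustrative only; the task is normally constructed
# and driven by the Model Garden training loop via `task_factory`):
#
#   task = SemanticSegmentation3DTask(task_config)
#   model = task.build_model()
#   task.initialize(model)
#   train_dataset = task.build_inputs(task_config.train_data)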