tensorflow/models

View on GitHub
official/projects/pointpillars/tasks/pointpillars.py

Summary

Maintainability
B
6 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PointPillars task definition."""

import functools
from typing import Any, List, Mapping, Optional, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import task_factory
from official.projects.pointpillars.configs import pointpillars as cfg
from official.projects.pointpillars.dataloaders import decoders
from official.projects.pointpillars.dataloaders import parsers
from official.projects.pointpillars.modeling import factory
from official.projects.pointpillars.utils import utils
from official.vision.dataloaders import input_reader_factory
from official.vision.losses import focal_loss
from official.vision.losses import loss_utils


def pick_dataset_fn(file_type: str) -> Any:
  if file_type == 'tfrecord':
    return tf.data.TFRecordDataset
  if file_type == 'tfrecord_compressed':
    return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
  raise ValueError('Unrecognized file_type: {}'.format(file_type))


def get_batch_size_per_replica(global_batch_size: int) -> int:
  """Get batch size per accelerator replica."""
  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
  if global_batch_size < num_replicas:
    logging.warning('Global batch size is smaller than num replicas. '
                    'Set batch size per replica to 1.')
    return 1
  if global_batch_size % num_replicas != 0:
    raise ValueError(
        'global_batch_size {} is not a multiple of num_replicas {}'
        .format(global_batch_size, num_replicas))
  batch_size = int(global_batch_size / num_replicas)
  return batch_size


@task_factory.register_task_cls(cfg.PointPillarsTask)
class PointPillarsTask(base_task.Task):
  """A single-replica view of training procedure."""

  def __init__(self,
               params: cfg.PointPillarsTask,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None):
    super().__init__(params, logging_dir, name)
    self._model = None
    self._attribute_heads = self.task_config.model.head.attribute_heads

  def build_model(self) -> tf_keras.Model:
    # Create only one model instance if this function is called multiple times.
    if self._model is not None:
      return self._model

    pillars_config = self.task_config.model.pillars
    input_specs = {
        'pillars':
            tf_keras.layers.InputSpec(
                shape=(None, pillars_config.num_pillars,
                       pillars_config.num_points_per_pillar,
                       pillars_config.num_features_per_point)),
        'indices':
            tf_keras.layers.InputSpec(
                shape=(None, pillars_config.num_pillars, 2), dtype='int32'),
    }

    train_batch_size = get_batch_size_per_replica(
        self.task_config.train_data.global_batch_size)
    eval_batch_size = get_batch_size_per_replica(
        self.task_config.validation_data.global_batch_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    self._model = factory.build_pointpillars(
        input_specs=input_specs,
        model_config=self.task_config.model,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        l2_regularizer=l2_regularizer)
    return self._model

  def initialize(self, model: tf_keras.Model):
    """Loading pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(
      self,
      params: cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Build input dataset."""
    model_config = self.task_config.model
    if (model_config.classes != 'all' and
        model_config.num_classes != 2):
      raise ValueError('Model num_classes must be 2 when not for all classes.')

    decoder = decoders.ExampleDecoder(model_config.image, model_config.pillars)

    image_size = [model_config.image.height, model_config.image.width]
    anchor_sizes = [(a.length, a.width) for a in model_config.anchors]
    anchor_labeler_config = model_config.anchor_labeler
    parser = parsers.Parser(
        classes=model_config.classes,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        image_size=image_size,
        anchor_sizes=anchor_sizes,
        match_threshold=anchor_labeler_config.match_threshold,
        unmatched_threshold=anchor_labeler_config.unmatched_threshold,
        max_num_detections=model_config.detection_generator
        .max_num_detections,
        dtype=params.dtype,
    )
    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset

  def compute_attribute_losses(
      self,
      outputs: Mapping[str, Any],
      labels: Mapping[str, Any],
      box_sample_weight: tf.Tensor) -> Mapping[str, float]:
    """Computes attribute loss."""
    att_loss_fn = tf_keras.losses.Huber(
        self.task_config.losses.huber_loss_delta,
        reduction=tf_keras.losses.Reduction.SUM)

    losses = {}
    total_loss = 0.0
    for head in self._attribute_heads:
      if head.type != 'regression':
        raise ValueError(f'Attribute type {head.type} not supported.')

      y_true_att = loss_utils.multi_level_flatten(
          labels['attribute_targets'][head.name], last_dim=head.size)
      y_pred_att = loss_utils.multi_level_flatten(
          outputs['attribute_outputs'][head.name], last_dim=head.size)
      if head.name == 'heading':
        # Direction aware loss, wrap the delta angle to [-pi, pi].
        # Otherwise for a loss that is symmetric to direction (i.e., heading 0
        # and pi are the same), we use a tf.sin transform.
        delta = utils.wrap_angle_rad(y_pred_att - y_true_att)
        loss = att_loss_fn(
            y_true=tf.zeros_like(delta),
            y_pred=delta,
            sample_weight=box_sample_weight)
      else:
        loss = att_loss_fn(
            y_true=y_true_att,
            y_pred=y_pred_att,
            sample_weight=box_sample_weight)
      total_loss += loss
      losses[head.name] = loss
    losses['total'] = total_loss
    return losses

  def compute_losses(
      self,
      outputs: Mapping[str, Any],
      labels: Mapping[str, Any],
      aux_losses: Optional[Any] = None) -> Mapping[str, float]:
    """Build losses."""
    params = self.task_config

    cls_loss_fn = focal_loss.FocalLoss(
        alpha=params.losses.focal_loss_alpha,
        gamma=params.losses.focal_loss_gamma,
        reduction=tf_keras.losses.Reduction.SUM)
    box_loss_fn = tf_keras.losses.Huber(
        params.losses.huber_loss_delta,
        reduction=tf_keras.losses.Reduction.SUM)

    # Sums all positives in a batch for normalization and avoids zero
    # num_positives_sum, which would lead to inf loss during training
    cls_sample_weight = labels['cls_weights']
    box_sample_weight = labels['box_weights']
    num_positives = tf.reduce_sum(box_sample_weight) + 1.0
    cls_sample_weight = cls_sample_weight / num_positives
    box_sample_weight = box_sample_weight / num_positives

    y_true_cls = loss_utils.multi_level_flatten(
        labels['cls_targets'], last_dim=None)
    y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
    y_pred_cls = loss_utils.multi_level_flatten(
        outputs['cls_outputs'], last_dim=params.model.num_classes)
    y_true_box = loss_utils.multi_level_flatten(
        labels['box_targets'], last_dim=4)
    y_pred_box = loss_utils.multi_level_flatten(
        outputs['box_outputs'], last_dim=4)

    cls_loss = cls_loss_fn(
        y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight)
    box_loss = box_loss_fn(
        y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight)
    attribute_losses = self.compute_attribute_losses(outputs, labels,
                                                     box_sample_weight)
    model_loss = (
        cls_loss + box_loss * params.losses.box_loss_weight +
        attribute_losses['total'] * params.losses.attribute_loss_weight)

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss += reg_loss
    total_loss = params.losses.loss_weight * total_loss

    losses = {
        'class_loss': cls_loss,
        'box_loss': box_loss,
        'attribute_loss': attribute_losses['total'],
        'model_loss': model_loss,
        'total_loss': total_loss,
    }
    for head in self._attribute_heads:
      losses[head.name + '_loss'] = attribute_losses[head.name]
    return losses

  def build_metrics(self, training: bool = True) -> List[tf.metrics.Metric]:
    """Define metrics and how to calculate them."""
    # train/validation loss metrics
    loss_names = [
        'class_loss', 'box_loss', 'attribute_loss', 'model_loss', 'total_loss'
    ]
    for head in self._attribute_heads:
      loss_names.append(head.name + '_loss')
    metrics = []
    for name in loss_names:
      metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32))

    # Use a separate metric for WOD validation.
    if not training:
      if self.task_config.use_wod_metrics:
        # To use Waymo open dataset metrics, please install one of the pip
        # package `waymo-open-dataset-tf-*` from
        # https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
        # Note that the package is built with specific tensorflow version and
        # will produce error if it does not match the tf version that is
        # currently used.
        try:
          from official.projects.pointpillars.utils import wod_detection_evaluator  # pylint: disable=g-import-not-at-top
        except ModuleNotFoundError:
          logging.error('waymo-open-dataset should be installed to enable Waymo'
                        ' evaluator.')
          raise
        self._wod_metric = wod_detection_evaluator.create_evaluator(
            self.task_config.model)
    return metrics

  def train_step(
      self,
      inputs: Tuple[Any, Any],
      model: tf_keras.Model,
      optimizer: tf_keras.optimizers.Optimizer,
      metrics: Optional[List[tf.metrics.Metric]] = None) -> Mapping[str, Any]:
    """Does forward and backward."""
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(pillars=features['pillars'],
                      indices=features['indices'],
                      training=True)
      losses = self.compute_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses)

      # Computes per-replica loss.
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # For updating trainer.train_loss
    logs = {self.loss: losses['total_loss']}
    # For updating trainer.train_metrics
    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
    return logs

  def validation_step(
      self,
      inputs: Tuple[Any, Any],
      model: tf_keras.Model,
      metrics: Optional[List[tf.metrics.Metric]] = None) -> Mapping[str, Any]:
    """Validatation step."""
    features, labels = inputs
    outputs = model(pillars=features['pillars'],
                    indices=features['indices'],
                    image_shape=labels['image_shape'],
                    anchor_boxes=labels['anchor_boxes'],
                    training=False)
    losses = self.compute_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)

    # For updating trainer.validation_loss
    logs = {self.loss: losses['total_loss']}
    # For updating trainer.validation_metrics
    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
    if self.task_config.use_wod_metrics:
      logs.update(
          {self._wod_metric.name: (labels['groundtruths'], outputs)})
    return logs

  def aggregate_logs(self,
                     state: Any = None,
                     step_outputs: Any = None) -> Any:
    """Called after each validation_step to update metrics."""
    logging.log_every_n(logging.INFO,
                        'Aggregating metrics after one evaluation step.', 1000)
    if self.task_config.use_wod_metrics:
      if state is None:
        self._wod_metric.reset_states()
      self._wod_metric.update_state(step_outputs[self._wod_metric.name][0],
                                    step_outputs[self._wod_metric.name][1])
    if state is None:
      state = True
    return state

  def reduce_aggregated_logs(self,
                             aggregated_logs: Any,
                             global_step: Optional[tf.Tensor] = None) -> Any:
    """Called after eval_end to calculate metrics."""
    logging.info('Reducing aggregated metrics after one evaluation cycle.')
    logs = {}
    if self.task_config.use_wod_metrics:
      logs.update(self._wod_metric.result())
    return logs