tensorflow/models

View on GitHub
official/projects/yolo/losses/yolov7_loss.py

Summary

Maintainability
D
2 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""YOLOv7 loss function."""

import tensorflow as tf, tf_keras

from official.projects.yolo.ops import box_ops
from official.vision.losses import focal_loss

# Initial per-layer balance factors for the objectness loss. The table is keyed
# by the number of output layers written as a string (the losses below look it
# up with `str(len(strides))`), and entry `i` multiplies the objectness loss of
# layer `i`.
_LAYER_BALANCE = {
    '3': [4.0, 1.0, 0.4],
    '5': [4.0, 1.0, 0.25, 0.06, 0.02],
}


def smooth_bce_targets(eps=0.1):
  """Returns the label-smoothed BCE targets as a (positive, negative) pair.

  See equation 3 of https://arxiv.org/pdf/1902.04103.pdf.

  Args:
    eps: a float number from [0, 1] representing label smoothing factor.

  Returns:
    A tuple `(positive_target, negative_target)`. Half of `eps` is taken away
    from the positive target and assigned to the negative target.
  """
  negative_target = 0.5 * eps
  positive_target = 1.0 - negative_target
  return positive_target, negative_target


def merge_labels(labels):
  """Packs the ground-truth classes and boxes into a single target tensor.

  Args:
    labels: a dictionary that provides at least the keys `bbox` and `classes`.

  Returns:
    A tensor whose last axis holds the class (cast to the dtype of the
    converted boxes) followed by the result of
    `box_ops.yxyx_to_xcycwh(labels['bbox'])`.
  """
  xcycwh = box_ops.yxyx_to_xcycwh(labels['bbox'])
  class_column = tf.expand_dims(
      tf.cast(labels['classes'], xcycwh.dtype), axis=-1)
  return tf.concat([class_column, xcycwh], axis=-1)


class YoloV7Loss(tf_keras.losses.Loss):
  """YOLOv7 loss function."""

  def __init__(
      self,
      anchors,
      strides,
      input_size,
      alpha=0.25,
      gamma=1.5,
      box_weight=0.05,
      obj_weight=0.7,
      cls_weight=0.3,
      label_smoothing=0.0,
      anchor_threshold=4.0,
      iou_mix_ratio=1.0,
      num_classes=80,
      auto_balance=False,
      reduction=tf_keras.losses.Reduction.NONE,
      name=None,
  ):
    """Constructor for YOLOv7 loss.

    Follows the implementation here:
      https://github.com/WongKinYiu/yolov7/blob/main/utils/loss.py#L422

    Args:
      anchors: a 2D array represents different anchors used at each level. The
        argument itself is never modified; a scaled copy is stored.
      strides: a 1D array represents the strides. Note that all numbers should
        be a power of 2, and they usually start with level 3 and end at level
        5 or 7. Therefore, the list should usually be [8, 16, 32] or
        [8, 16, 32, 64, 128].
      input_size: a list containing the height and width of the input image.
      alpha: alpha for focal loss.
      gamma: gamma for focal loss. If set to 0, focal loss will be disabled.
      box_weight: float weight scalar applied to bounding box loss.
      obj_weight: float weight scalar applied to objectness loss.
      cls_weight: float weight scalar applied to class loss.
      label_smoothing: small float number used to compute positive and negative
        targets. If set to 0, the positive targets will be 1 and negative
        targets will be 0.
      anchor_threshold: threshold for the anchor matching. Larger number allows
        more displacements between anchors and targets.
      iou_mix_ratio: float ratio to mix the IoU score with the positive target,
        which is 1.
      num_classes: number of classes.
      auto_balance: a boolean flag that indicates whether auto balance should be
        used. If used, the default balance factors will automatically update
        for each batch.
      reduction: Reduction method. Should be set to None at all time as this
        loss module always output a loss scalar.
      name: Optional name for the loss.
    """
    # Loss required fields.
    self._num_classes = num_classes
    self._num_layers = len(strides)
    self._num_anchors = len(anchors[0])
    self._strides = strides
    self._input_size = input_size
    self._iou_mix_ratio = iou_mix_ratio

    # Focal loss hyper-parameters are kept because `get_config` reports them.
    self._alpha = alpha
    self._gamma = gamma

    # Scale down anchors by the strides to match the feature map. A new list is
    # built so that the caller's `anchors` object is never modified in place.
    self._anchors = [
        tf.constant(anchors[i], tf.float32) / stride
        for i, stride in enumerate(strides)
    ]

    self._anchor_threshold = anchor_threshold

    self._pos_targets, self._neg_targets = smooth_bce_targets(label_smoothing)
    if gamma > 0:
      self._cls_loss_fn = focal_loss.FocalLoss(
          alpha=alpha, gamma=gamma, reduction=reduction, name='cls_loss')
      self._obj_loss_fn = focal_loss.FocalLoss(
          alpha=alpha, gamma=gamma, reduction=reduction, name='obj_loss')
    else:
      self._cls_loss_fn = tf.nn.sigmoid_cross_entropy_with_logits
      self._obj_loss_fn = tf.nn.sigmoid_cross_entropy_with_logits

    # Weight to combine losses
    self._box_weight = box_weight
    self._obj_weight = obj_weight * input_size[0] / 640 * input_size[1] / 640
    self._cls_weight = cls_weight * num_classes / 80

    # Layer balance scalar
    self._balance = [
        tf.constant(bal, tf.float32)
        for bal in _LAYER_BALANCE[str(self._num_layers)]
    ]
    self._auto_balance = auto_balance
    # `strides` is wrapped in a 1-tuple so that the message also formats
    # correctly when `strides` itself is a tuple.
    assert 16 in strides, (
        'Expect level 4 (stride of 16) always exist in the strides, received %s'
        % (strides,)
    )
    self._ssi = list(strides).index(16) if auto_balance else 0  # stride 16 idx

    super().__init__(reduction=reduction, name=name)

  def call(self, labels, predictions):
    """Computes the YOLOv7 loss.

    Args:
      labels: a dictionary that provides at least the keys `classes` and
        `bbox` (both are consumed by `merge_labels`). Entries whose class is
        -1 are treated as padding.
      predictions: a dictionary of model outputs keyed by the level written as
        a string; levels are assumed to start at '3'. Each value is transposed
        with `[0, 3, 1, 2, 4]` into
        [batch_size, num_anchors, height, width, num_classes + boxes + obj].

    Returns:
      A tensor of shape [1]: the weighted sum of box, objectness and class
      losses, multiplied by the batch size.
    """
    labels = merge_labels(labels)
    p = {}
    for key in predictions:
      # [batch_size, num_anchors, height, width, num_classes + boxes + obj]
      p[key] = tf.transpose(predictions[key], [0, 3, 1, 2, 4])
    cls_loss, box_loss, obj_loss, iou_metric = [tf.zeros(1) for _ in range(4)]
    total_num_matchings = tf.zeros(1)
    total_num_gts = tf.reduce_sum(tf.cast(labels[..., 0] != -1, tf.float32))

    masks, indices, anchors, cls_targets, box_targets = self._build_targets(
        labels, p)

    batch_size = tf.shape(indices)[0]
    layer_shape = [batch_size, self._num_layers, -1]
    # [anchor_indices, grid_js, grid_is]
    masks = tf.reshape(masks, layer_shape)
    indices = tf.reshape(indices, [*layer_shape, 3])
    anchors = tf.reshape(anchors, [*layer_shape, 2])
    cls_targets = tf.reshape(cls_targets, layer_shape)
    box_targets = tf.reshape(box_targets, [*layer_shape, 4])

    # Losses
    for layer_key, layer_pred in p.items():
      # Always assume the output level starts with 3.
      i = int(layer_key) - 3

      obj_targets = tf.zeros_like(layer_pred[..., 0])

      layer_masks = masks[:, i]
      num_matchings = tf.reduce_sum(tf.cast(layer_masks, tf.int32))
      total_num_matchings += tf.cast(num_matchings, tf.float32)

      if num_matchings > 0:
        layer_indices = indices[:, i]
        # Prepend the batch index so that each row becomes
        # [batch, anchor, gj, gi] and can address `layer_pred` directly.
        batch_indices = tf.tile(
            tf.range(batch_size)[:, None], [1, tf.shape(layer_indices)[1]]
        )[..., None]
        layer_indices = tf.concat([batch_indices, layer_indices], axis=-1)
        layer_indices = tf.boolean_mask(layer_indices, layer_masks)
        layer_anchors = tf.boolean_mask(anchors[:, i], layer_masks)

        layer_cls_targets = tf.boolean_mask(cls_targets[:, i], layer_masks)
        layer_box_targets = tf.boolean_mask(box_targets[:, i], layer_masks)

        # In the same shape of layer_target.
        matched_pred = tf.gather_nd(layer_pred, layer_indices)

        # Decode the raw outputs into a center offset and a width/height that
        # is expressed relative to the matched anchor.
        pred_xcyc = tf.sigmoid(matched_pred[..., :2]) * 2 - 0.5
        pred_wh = (
            tf.square(tf.sigmoid(matched_pred[..., 2:4]) * 2) * layer_anchors)
        pred_xcycwh = tf.concat([pred_xcyc, pred_wh], axis=-1)
        _, ciou = box_ops.compute_ciou(pred_xcycwh, layer_box_targets)

        box_loss += tf.reduce_mean(1.0 - ciou)
        iou_metric += tf.reduce_mean(ciou)

        # Compute classification loss.
        if self._num_classes > 1:  # cls loss (only if multiple classes)
          t = tf.one_hot(
              layer_cls_targets,
              self._num_classes,
              on_value=self._pos_targets,
              off_value=self._neg_targets,
          )
          cls_loss += tf.reduce_mean(
              self._cls_loss_fn(t, matched_pred[..., 5:]))

        # Compute objectness loss.
        iou_ratio = tf.cast(
            (1.0 - self._iou_mix_ratio)
            + (self._iou_mix_ratio * tf.maximum(tf.stop_gradient(ciou), 0)),
            obj_targets.dtype,
        )
        # When several matches land on the same cell, keep the largest target.
        obj_targets = tf.tensor_scatter_nd_max(
            obj_targets, layer_indices, iou_ratio
        )
      layer_obj_loss = tf.reduce_mean(
          self._obj_loss_fn(obj_targets, layer_pred[..., 4])
      )
      obj_loss += layer_obj_loss * self._balance[i]
      # Updates the balance factor, which is a moving average of previous
      # factor at the same level.
      if self._auto_balance:
        self._balance[i] = self._balance[
            i
        ] * 0.9999 + 0.0001 / tf.stop_gradient(layer_obj_loss)

    # Re-balance the factors so that stride at self._ssi always receives 1.
    if self._auto_balance:
      self._balance = [x / self._balance[self._ssi] for x in self._balance]

    box_loss *= self._box_weight
    obj_loss *= self._obj_weight
    cls_loss *= self._cls_weight

    # Keep separate losses and statistics for summary purpose.
    self._box_loss = tf.stop_gradient(box_loss)
    self._obj_loss = tf.stop_gradient(obj_loss)
    self._cls_loss = tf.stop_gradient(cls_loss)
    self._iou = tf.stop_gradient(iou_metric) / self._num_layers
    self._num_matchings = tf.stop_gradient(
        total_num_matchings) / tf.cast(batch_size, tf.float32)
    self._num_gts = tf.stop_gradient(
        total_num_gts) / tf.cast(batch_size, tf.float32)

    loss = box_loss + obj_loss + cls_loss
    # Scale up the loss by batch size.
    return loss * tf.cast(batch_size, loss.dtype)

  def _build_targets(self, labels, predictions):
    """Finds three matching anchors for each ground-truth.

    Args:
      labels: the tensor produced by `merge_labels`; its last axis is read as
        [class, xc, yc, w, h] and a class of -1 marks an invalid entry.
      predictions: a dictionary keyed by the level string ('3', '4', ...) whose
        values have a static shape of [batch, anchors, height, width, channels].

    Returns:
      A tuple `(masks, indices, anchors, cls_targets, box_targets)`. Each item
      stacks the per-layer results along axis 1. `indices` holds
      [anchor_idx, gj, gi] triplets and `box_targets` holds
      [x - gi, y - gj, w, h] in feature-map units.
    """
    label_shape = tf.shape(labels)
    batch_size, max_boxes = label_shape[0], label_shape[1]
    masks, indices, anch = [], [], []
    cls_targets, box_targets = [], []
    anchor_indices = tf.tile(
        tf.range(self._num_anchors, dtype=tf.float32)[None, None],
        [batch_size, max_boxes, 1],
    )
    # Append anchor indices to labels.
    labels = tf.tile(labels[:, :, None], [1, 1, self._num_anchors, 1])
    labels = tf.concat([labels, anchor_indices[..., None]], axis=-1)

    # Bias is used to determine the matching. 0.5 means matching anchors that
    # fall in the 0.5 differences in the feature map. For instance, a box
    # coordinates of (15.6, 35.4) will match the anchors at [15, 35], [16, 35],
    # and [15, 34].
    bias = 0.5  # bias
    off = (
        tf.constant(
            [
                [0, 0],
                [1, 0], [0, 1], [-1, 0], [0, -1],  # j, k, l, m
                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
            ],
            tf.float32,
        )
        * bias
    )  # offsets

    for i in range(self._num_layers):
      anchors = self._anchors[i]
      _, _, h, w, _ = predictions[str(i + 3)].get_shape().as_list()
      # Scales [class, xc, yc, w, h, anchor_idx] into feature-map units.
      gain = tf.constant([1, w, h, w, h, 1], dtype=tf.float32)

      t = labels * gain

      # Filter out targets that do not match the current anchors.
      wh_ratio = t[..., 3:5] / tf.cast(anchors[None, None], tf.float32)
      labels_mask = tf.less(
          tf.reduce_max(tf.maximum(wh_ratio, 1.0 / wh_ratio), axis=-1),
          self._anchor_threshold,
      )[..., None]
      # Compute valid mask for ground-truths.
      labels_mask = tf.logical_and(t[..., :1] != -1, labels_mask)

      labels_mask = tf.reshape(labels_mask, [batch_size, -1])
      t = tf.reshape(t, [batch_size, -1, 6])

      # Find the matching offsets for valid labels.
      gxy = t[..., 1:3]  # grid xy
      gxi = gain[1:3] - gxy  # inverse
      j, k = tf.split((gxy % 1.0 < bias) & (gxy >= 1.0), 2, axis=-1)
      l, m = tf.split((gxi % 1.0 < bias) & (gxi >= 1.0), 2, axis=-1)

      j, k, l, m = j[..., 0], k[..., 0], l[..., 0], m[..., 0]

      # Note that j and l, k and m are conjugate to each other, so at most one
      # of them will be True during running. Therefore, we can reduce memory
      # usage by gathering the selected index.
      x_map = tf.cast(tf.stack([j, l], axis=-1), tf.int8)
      y_map = tf.cast(tf.stack([k, m], axis=-1), tf.int8)

      # Add the indices offsets.
      x_indices = tf.argmax(x_map, axis=-1) * 2 + 1
      y_indices = tf.argmax(y_map, axis=-1) * 2 + 2
      three_targets_indices = tf.stack(
          [tf.zeros_like(x_indices), x_indices, y_indices], axis=-1
      )[..., None]

      # Gather the selected 3 targets from the 5-target map.
      j = tf.stack([tf.ones_like(j), j, k, l, m], axis=-1)
      three_targets_mask = tf.gather_nd(j, three_targets_indices, batch_dims=2)

      labels_mask = tf.tile(labels_mask[:, :, None], [1, 1, 5])
      t = tf.tile(t[:, :, None], [1, 1, 5, 1])

      labels_mask = tf.gather_nd(
          labels_mask, three_targets_indices, batch_dims=2
      )
      t = tf.gather_nd(t, three_targets_indices, batch_dims=2)

      offsets = tf.zeros_like(gxy)[:, :, None] + off[None, None]
      offsets = tf.gather_nd(offsets, three_targets_indices, batch_dims=2)

      cls_target = tf.cast(t[..., 0], tf.int32)
      gxy, gwh = t[..., 1:3], t[..., 3:5]
      # Find the actual grid locations.
      gij = tf.cast(gxy - offsets * 2, tf.int32)
      gi, gj = tf.split(gij, 2, axis=-1)
      gi, gj = gi[..., 0], gj[..., 0]

      # Append the result.
      anchor_idx = tf.cast(t[..., 5], tf.int32)
      gain = tf.cast(gain, tf.int32)
      # `gi` is the x (column) index and `gj` is the y (row) index, so they
      # are bounded by the width (gain[1]) and the height (gain[2]).
      gi = tf.clip_by_value(gi, 0, gain[1] - 1)
      gj = tf.clip_by_value(gj, 0, gain[2] - 1)
      gij = tf.stack([gi, gj], axis=-1)

      labels_mask = tf.logical_and(labels_mask, three_targets_mask)
      masks.append(labels_mask)
      indices.append(tf.stack([anchor_idx, gj, gi], axis=-1))
      anch.append(tf.gather(anchors, anchor_idx))
      cls_targets.append(cls_target)
      box_targets.append(
          tf.concat([gxy - tf.cast(gij, tf.float32), gwh], axis=-1))  # box

    # [batch_size, num_layers, num_anchors * max_boxes, num_targets]
    masks = tf.stack(masks, axis=1)
    indices = tf.stack(indices, axis=1)
    anch = tf.stack(anch, axis=1)
    cls_targets = tf.stack(cls_targets, axis=1)
    box_targets = tf.stack(box_targets, axis=1)
    return masks, indices, anch, cls_targets, box_targets

  def report_separate_losses(self):
    """Returns the weighted loss terms and IoU recorded by the last `call`."""
    return {
        'box_loss': self._box_loss,
        'obj_loss': self._obj_loss,
        'cls_loss': self._cls_loss,
        'iou': self._iou,
    }

  def report_stats(self):
    """Returns the matching statistics recorded by the last `call`."""
    return {
        'num_gts': self._num_gts,
        'num_matchings': self._num_matchings,
        # No duplicates.
        'num_duplicates': tf.constant(0),
    }

  def get_config(self):
    """Returns the base config extended with the values stored on this loss.

    The values are the stored attributes, so `obj_weight`, `cls_weight`,
    `anchors` and `balance` are reported after the scaling and conversion
    applied in the constructor.
    """
    config = {
        'alpha': self._alpha,
        'gamma': self._gamma,
        'box_weight': self._box_weight,
        'obj_weight': self._obj_weight,
        'cls_weight': self._cls_weight,
        'pos_targets': self._pos_targets,
        'neg_targets': self._neg_targets,
        'num_classes': self._num_classes,
        'num_layers': self._num_layers,
        'num_anchors': self._num_anchors,
        'auto_balance': self._auto_balance,
        'balance': self._balance,
        'strides': self._strides,
        'anchors': self._anchors,
        'input_size': self._input_size,
        'anchor_threshold': self._anchor_threshold,
    }
    base_config = super().get_config()
    return {**base_config, **config}


class YoloV7LossOTA(tf_keras.losses.Loss):
  """YOLOv7 loss function with OTA.

  OTA (Optimal Transport Assignment) uses Sinkhorn-Knopp algorithm to compute
  a matching between anchors and ground-truth labels.

  Paper: https://arxiv.org/pdf/2103.14259.pdf
  """

  def __init__(
      self,
      anchors,
      strides,
      input_size,
      alpha=0.25,
      gamma=1.5,
      box_weight=0.05,
      obj_weight=0.7,
      cls_weight=0.3,
      iou_weight=3.0,
      label_smoothing=0.0,
      anchor_threshold=4.0,
      iou_mix_ratio=1.0,
      num_classes=80,
      auto_balance=False,
      reduction=tf_keras.losses.Reduction.NONE,
      name=None,
  ):
    """Constructor for YOLOv7 loss OTA.

    Follows the implementation here:
      https://github.com/WongKinYiu/yolov7/blob/main/utils/loss.py#L556

    Args:
      anchors: a 2D array represents different anchors used at each level.
      strides: a 1D array represents the strides. Note that all numbers should
        be a power of 2, and they usually start with level 3 and end at level 5
        or 7. Therefore, the list should usually be [8, 16, 32] or [8, 16, 32,
        64, 128].
      input_size: a list containing the height and width of the input image.
      alpha: alpha for focal loss.
      gamma: gamma for focal loss. If set to 0, focal loss will be disabled.
      box_weight: float weight scalar applied to bounding box loss.
      obj_weight: float weight scalar applied to objectness loss.
      cls_weight: float weight scalar applied to class loss.
      iou_weight: float weight scalar to mix class loss and IoU class to
        construct the cost matrix.
      label_smoothing: small float number used to compute positive and negative
        targets. If set to 0, the positive targets will be 1 and negative
        targets will be 0.
      anchor_threshold: threshold for the anchor matching. Larger number allows
        more displacements between anchors and targets.
      iou_mix_ratio: float ratio to mix the IoU score with the positive target,
        which is 1.
      num_classes: number of classes.
      auto_balance: a boolean flag that indicates whether auto balance should be
        used. If used, the default balance factors will automatically update for
        each batch.
      reduction: Reduction method. Should be set to None at all time as this
        loss module always output a loss scalar.
      name: Optional name for the loss.
    """
    # Loss required fields.
    self._num_classes = num_classes
    self._num_layers = len(strides)
    self._num_anchors = len(anchors[0])
    self._strides = strides
    self._input_size = input_size
    self._iou_mix_ratio = iou_mix_ratio

    # Focal loss hyper-parameters are kept on the instance under the same
    # attribute names that `YoloV7Loss.get_config` in this file reads.
    self._alpha = alpha
    self._gamma = gamma

    # Scale down anchors by the strides to match the feature map. The result
    # is a new list, so the `anchors` argument is left untouched.
    self._anchors = [
        tf.constant(anchors[i], tf.float32) / stride
        for i, stride in enumerate(strides)
    ]

    self._anchor_threshold = anchor_threshold

    self._pos_targets, self._neg_targets = smooth_bce_targets(label_smoothing)
    if gamma > 0:
      self._cls_loss_fn = focal_loss.FocalLoss(
          alpha=alpha, gamma=gamma, reduction=reduction, name='cls_loss')
      self._obj_loss_fn = focal_loss.FocalLoss(
          alpha=alpha, gamma=gamma, reduction=reduction, name='obj_loss')
    else:
      self._cls_loss_fn = tf.nn.sigmoid_cross_entropy_with_logits
      self._obj_loss_fn = tf.nn.sigmoid_cross_entropy_with_logits

    # Weight to combine losses
    self._box_weight = box_weight
    self._obj_weight = obj_weight * input_size[0] / 640 * input_size[1] / 640
    self._cls_weight = cls_weight * num_classes / 80

    # Weight to construct cost matrix
    self._iou_weight = iou_weight

    # Layer balance scalar
    self._balance = [
        tf.constant(bal, tf.float32)
        for bal in _LAYER_BALANCE[str(self._num_layers)]
    ]
    self._auto_balance = auto_balance
    # `strides` is wrapped in a 1-tuple so that the message also formats
    # correctly when `strides` itself is a tuple.
    assert 16 in strides, (
        'Expect level 4 (stride of 16) always exist in the strides, received %s'
        % (strides,)
    )
    self._ssi = list(strides).index(16) if auto_balance else 0  # stride 16 idx

    super().__init__(reduction=reduction, name=name)

  def call(self, labels, predictions):
    """Computes the OTA loss.

    Args:
      labels: a dictionary contains the following required keys:
        - classes: class indices in shape [batch_size, max_num_instances].
        - bbox: bounding boxes in shape [batch_size, max_num_instances, 4].
        - image_info: image info in shape [batch_size, 4, 2].
      predictions: a dictionary contains model outputs at different layers.
        They are in shape of [batch_size, h_at_level, w_at_level, num_anchors,
        num_classes + 4 (box coordinates) + 1 (objectness)].

    Returns:
      The scaled loss (up by batch size) from OTA.
    """
    image_info = labels['image_info']
    # Convert labels dictionary into tensors.
    labels = merge_labels(labels)
    p = {}
    for key in predictions:
      # [batch_size, num_anchors, height, width, num_classes + boxes + obj]
      p[key] = tf.transpose(predictions[key], [0, 3, 1, 2, 4])

    cls_loss, box_loss, obj_loss, iou_metric = [tf.zeros(1) for _ in range(4)]
    total_num_matchings = tf.zeros(1)
    # A class of -1 marks an entry that is not a real ground-truth.
    total_num_gts = tf.reduce_sum(tf.cast(labels[..., 0] != -1, tf.float32))
    (matched_indices, matched_anchors, matched_mask, matched_targets,
     num_duplicates) = self._build_targets(labels, p, image_info)
    # Get height and width for each layers.
    # Each entry is [width, height, width, height] of the transposed prediction.
    pre_gen_gains = [
        tf.gather(tf.shape(p[str(i + 3)]), [3, 2, 3, 2])
        for i in range(self._num_layers)
    ]

    batch_size = tf.shape(matched_indices)[0]
    layer_shape = [batch_size, self._num_layers, -1]
    # [anchor_indices, grid_js, grid_is]
    masks = tf.reshape(matched_mask, layer_shape)
    indices = tf.reshape(matched_indices, [*layer_shape, 3])
    anchors = tf.reshape(matched_anchors, [*layer_shape, 2])
    targets = tf.reshape(matched_targets, [*layer_shape, 5])

    # Losses
    for layer_idx, layer_pred in p.items():
      # Always assume the output level starts with 3.
      i = int(layer_idx) - 3

      obj_targets = tf.zeros_like(layer_pred[..., 0])

      # Get layer inputs
      layer_masks = masks[:, i]
      num_matchings = tf.reduce_sum(tf.cast(layer_masks, tf.int32))
      total_num_matchings += tf.cast(num_matchings, tf.float32)

      if num_matchings > 0:
        layer_indices = indices[:, i]
        # Prepend the batch index so that each row becomes
        # [batch, anchor, gj, gi] and can address `layer_pred` directly.
        batch_indices = tf.tile(
            tf.range(batch_size)[:, None], [1, tf.shape(layer_indices)[1]]
        )[..., None]
        layer_indices = tf.concat([batch_indices, layer_indices], axis=-1)
        layer_indices = tf.boolean_mask(layer_indices, layer_masks)
        layer_anchors = tf.boolean_mask(anchors[:, i], layer_masks)

        layer_targets = tf.boolean_mask(targets[:, i], layer_masks)
        layer_cls_targets = tf.cast(layer_targets[:, 0], tf.int32)
        layer_box_targets = layer_targets[:, 1:]

        # In the same shape of layer_target.
        matched_pred = tf.gather_nd(layer_pred, layer_indices)

        # Decode the raw outputs into a center offset and a width/height that
        # is expressed relative to the matched anchor.
        pred_xcyc = tf.sigmoid(matched_pred[..., :2]) * 2 - 0.5
        pred_wh = (
            tf.square(tf.sigmoid(matched_pred[..., 2:4]) * 2) * layer_anchors)
        pred_xcycwh = tf.concat([pred_xcyc, pred_wh], axis=-1)

        # Only the center is shifted by the grid cell; width and height get a
        # zero offset.
        grid = tf.cast(
            tf.stack(
                [
                    layer_indices[:, 3],  # gi
                    layer_indices[:, 2],  # gj
                    tf.zeros_like(layer_indices[:, 0]),
                    tf.zeros_like(layer_indices[:, 0]),
                ],
                axis=-1,
            ),
            tf.float32,
        )
        # Scale the box targets by the feature-map size, then make the center
        # relative to the matched grid cell.
        target_xcycwh = layer_box_targets * tf.cast(
            pre_gen_gains[i], layer_targets.dtype
        )
        target_xcycwh -= grid
        _, ciou = box_ops.compute_ciou(target_xcycwh, pred_xcycwh)

        box_loss += tf.reduce_mean(1.0 - ciou)
        iou_metric += tf.reduce_mean(ciou)

        # Compute classification loss.
        if self._num_classes > 1:  # cls loss (only if multiple classes)
          t = tf.one_hot(
              layer_cls_targets,
              self._num_classes,
              on_value=self._pos_targets,
              off_value=self._neg_targets,
          )
          cls_loss += tf.reduce_mean(
              self._cls_loss_fn(t, matched_pred[..., 5:]))

        # Compute objectness loss.
        iou_ratio = tf.cast(
            (1.0 - self._iou_mix_ratio)
            + (self._iou_mix_ratio * tf.maximum(tf.stop_gradient(ciou), 0)),
            obj_targets.dtype,
        )
        # When several matches land on the same cell, keep the largest target.
        obj_targets = tf.tensor_scatter_nd_max(
            obj_targets, layer_indices, iou_ratio
        )
      layer_obj_loss = tf.reduce_mean(
          self._obj_loss_fn(obj_targets, layer_pred[..., 4])
      )
      obj_loss += layer_obj_loss * self._balance[i]
      # Updates the balance factor, which is a moving average of previous
      # factor at the same level.
      if self._auto_balance:
        self._balance[i] = self._balance[
            i
        ] * 0.9999 + 0.0001 / tf.stop_gradient(layer_obj_loss)

    # Re-balance the factors so that stride at self._ssi always receives 1.
    if self._auto_balance:
      self._balance = [x / self._balance[self._ssi] for x in self._balance]

    # Keep separate losses for summary purpose.
    box_loss *= self._box_weight
    obj_loss *= self._obj_weight
    cls_loss *= self._cls_weight

    self._iou = tf.stop_gradient(iou_metric) / self._num_layers
    self._num_matchings = tf.stop_gradient(
        total_num_matchings) / tf.cast(batch_size, tf.float32)
    self._num_gts = total_num_gts / tf.cast(batch_size, tf.float32)
    self._num_duplicates = tf.stop_gradient(
        num_duplicates) / tf.cast(batch_size, tf.float32)
    self._box_loss = tf.stop_gradient(box_loss)
    self._obj_loss = tf.stop_gradient(obj_loss)
    self._cls_loss = tf.stop_gradient(cls_loss)

    loss = box_loss + obj_loss + cls_loss

    # Scale up the loss by batch size.
    return loss * tf.cast(batch_size, loss.dtype)

  def _build_targets(self, labels, predictions, image_info):
    """Assigns ground-truths to candidate predictions with a cost matrix.

    Candidates come from `_find_three_positives`. A pair-wise cost (class loss
    plus `iou_weight` times the IoU loss) is built between every ground-truth
    and every candidate, each ground-truth picks its `k` cheapest candidates
    (`k` is derived from its top-10 IoUs), and a candidate picked by several
    ground-truths is given to the one with the lowest cost.

    Args:
      labels: the tensor produced by `merge_labels`; its last axis is read as
        [class, box (4 values)] and a class of -1 marks an invalid entry.
      predictions: a dictionary keyed by the level string ('3', '4', ...) of
        predictions in [batch, anchors, height, width, channels] layout.
      image_info: image info tensor. Only `image_info[0, 1]` is read, and it is
        used to scale the label boxes; presumably this is the input image size
        shared by the whole batch -- TODO confirm against the data pipeline.

    Returns:
      A tuple `(indices, anchors, matched_mask, matched_targets,
      num_duplicates)`:
        - indices: [batch_size, num_candidates, 3] as [anchor_idx, gj, gi].
        - anchors: [batch_size, num_candidates, 2].
        - matched_mask: boolean, True where a candidate got a ground-truth.
        - matched_targets: rows of `labels` gathered for every candidate.
        - num_duplicates: number of candidates that were selected by more than
          one ground-truth before the conflicts were resolved.
    """
    # Find the three positives matching first for predictions.
    masks, indices, anchors = self._find_three_positives(labels, predictions)

    batch_size = tf.shape(masks)[0]

    # Collect the predictions.
    p_box, p_cls, p_obj = [], [], []
    for layer_key, layer_p in predictions.items():
      # Always assume level starts from 3.
      i = int(layer_key) - 3
      layer_indices = tf.reshape(indices[:, i], [batch_size, -1, 3])
      anchor = tf.reshape(anchors[:, i], [batch_size, -1, 2])

      fg_pred = tf.gather_nd(layer_p, layer_indices, batch_dims=1)

      # [gi, gj], i.e. the (x, y) location of the cell.
      grid = tf.stack([layer_indices[..., 2], layer_indices[..., 1]], axis=-1)
      grid = tf.cast(grid, fg_pred.dtype)

      # Decode the boxes and scale them up by the stride of the layer.
      pxy = (tf.sigmoid(fg_pred[..., :2]) * 2 - 0.5 + grid) * self._strides[i]
      pwh = (
          tf.square(tf.sigmoid(fg_pred[..., 2:4]) * 2)
          * anchor
          * self._strides[i]
      )
      pxywh = tf.concat([pxy, pwh], axis=-1)

      p_box.append(pxywh)
      p_obj.append(fg_pred[..., 4:5])
      p_cls.append(fg_pred[..., 5:])

    p_box = tf.concat(p_box, axis=1)
    p_cls = tf.concat(p_cls, axis=1)
    p_obj = tf.concat(p_obj, axis=1)

    # Compute valid masks for both targets and predictions.
    t_mask = labels[..., 0] != -1
    p_mask = tf.reshape(masks, [batch_size, -1])
    # [anchor_idx, gj, gi]
    indices = tf.reshape(indices, [batch_size, -1, 3])
    anchors = tf.reshape(anchors, [batch_size, -1, 2])

    num_preds = tf.shape(p_box)[1]
    num_gts = tf.shape(labels)[1]

    # Computes pair-wise IoU.
    t_box = labels[..., 1:5] * tf.tile(image_info[0, 1], [2])

    pair_wise_iou = box_ops.compute_iou(t_box[:, :, None], p_box[:, None])
    pair_wise_iou_loss = -tf.math.log(pair_wise_iou + 1e-8)

    # Computes pair-wise class loss.
    y = tf.sqrt(tf.sigmoid(p_cls) * tf.sigmoid(p_obj))
    # Add 1e-9 to avoid nan.
    logits = tf.math.log(y / (1 - y + 1e-9) + 1e-9)
    logits = tf.tile(logits[:, None], [1, num_gts, 1, 1])

    t_cls = tf.cast(labels[..., 0], tf.int32)
    class_labels = tf.one_hot(t_cls, self._num_classes, dtype=tf.float32)
    class_labels = tf.tile(class_labels[:, :, None], [1, 1, num_preds, 1])

    pair_wise_cls_loss = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(class_labels, logits), axis=-1
    )

    # Compute the cost matrix and its corresponding valid mask.
    cost_mask = tf.logical_and(t_mask[..., None], p_mask[:, None])
    # The IoU term is weighted by the configurable `iou_weight` (default 3.0).
    cost = tf.stop_gradient(
        pair_wise_cls_loss + self._iou_weight * pair_wise_iou_loss)
    largest_cost = tf.reduce_max(cost)

    # Set invalid IoU to 0.0 for top_k.
    valid_iou = tf.where(cost_mask, pair_wise_iou, tf.zeros_like(pair_wise_iou))

    # Compute top-10 IoUs from valid IoUs for each target.
    # When matched predictions is smaller than 10, we only want the top-k where
    # k is the total size of the matched predictions (k < 10).
    top_k_mask = tf.less(
        tf.range(10)[None],
        tf.minimum(10, tf.reduce_sum(tf.cast(p_mask, tf.int32), axis=-1))[
            :, None
        ],
    )
    top_k_mask = tf.logical_and(top_k_mask[:, None], t_mask[..., None])
    top_k, _ = tf.nn.top_k(valid_iou, k=10)
    top_k = tf.where(top_k_mask, top_k, tf.zeros_like(top_k))

    # Use top_k to compute the dynamic ks for target matching. Each target_i can
    # match to k_i predictions, and k_i is computed based on the pair-wise
    # valid IoU.
    dynamic_ks = tf.maximum(tf.cast(tf.reduce_sum(top_k, axis=-1), tf.int32), 1)
    dynamic_ks = tf.where(t_mask, dynamic_ks, tf.zeros_like(dynamic_ks))
    dynamic_ks = tf.stop_gradient(dynamic_ks)
    dynamic_mask = tf.range(10)[None, None] < dynamic_ks[..., None]

    # Set the invalid field to maximum cost so that they won't be selected
    # during matching.
    cost = tf.where(cost_mask, cost, tf.ones_like(cost) * (largest_cost + 1))

    matching_matrix = tf.zeros_like(cost, dtype=tf.int32)
    # Negate the cost so that top_k returns the 10 cheapest predictions.
    _, pred_idx = tf.nn.top_k(-cost, k=10)

    # Update matching matrix.
    # [batch_size, num_gts, 10]
    batch_idx = tf.tile(tf.range(batch_size)[:, None, None], [1, num_gts, 10])
    gt_idx = tf.tile(tf.range(num_gts)[None, :, None], [batch_size, 1, 10])
    matched_indices = tf.stack([batch_idx, gt_idx, pred_idx], axis=-1)
    matching_matrix = tf.tensor_scatter_nd_add(
        matching_matrix,
        matched_indices,
        tf.cast(dynamic_mask, matching_matrix.dtype),
    )

    # Detect if there is a detection matches to multiple targets, if so, we
    # assign it to the target with minimum cost.
    duplicate_mask = tf.reduce_sum(matching_matrix, axis=1) > 1
    num_duplicates = tf.reduce_sum(tf.cast(duplicate_mask, tf.float32))
    cost_argmin = tf.argmin(cost, axis=1, output_type=tf.int32)

    # Clear every match of a duplicated prediction first...
    remove_mask = tf.tile(duplicate_mask[:, None], [1, num_gts, 1])
    matching_matrix = tf.where(
        remove_mask, tf.zeros_like(matching_matrix), matching_matrix)

    # ...then set only the match to its minimum-cost ground-truth.
    min_mask = tf.equal(
        tf.tile(tf.range(num_gts)[None, :, None], [batch_size, 1, num_preds]),
        cost_argmin[:, None],
    )
    update_mask = tf.logical_and(min_mask, duplicate_mask[:, None])
    matching_matrix = tf.where(
        update_mask, tf.ones_like(matching_matrix), matching_matrix)

    # Find the final matching and collect the matched targets.
    matched_gt_indices = tf.argmax(
        matching_matrix, axis=1, output_type=tf.int32
    )
    matched_mask = tf.reduce_sum(matching_matrix, axis=1) > 0
    matched_targets = tf.gather_nd(
        labels, matched_gt_indices[..., None], batch_dims=1
    )
    return indices, anchors, matched_mask, matched_targets, num_duplicates

  def _find_three_positives(self, labels, predictions):
    """Finds up to three candidate grid cells for each ground-truth.

    For every feature level, each ground-truth is paired with every anchor of
    that level. Each pair is assigned to the grid cell containing the box
    center, plus at most one horizontal and one vertical neighboring cell,
    chosen by which half of the cell the center falls in.

    Args:
      labels: a float32 tensor of shape [batch_size, max_boxes, 5]. Channel 0
        is compared against -1 to detect invalid entries, channels 1:3 are used
        as the box center (x, y) and channels 3:5 as the box (w, h). The
        coordinates are presumably normalized to [0, 1], since they are scaled
        by the feature map size here -- TODO confirm against the caller.
      predictions: a dict keyed by level name (`str(i + 3)` for the i-th
        layer). Each value must have a static 5-D shape whose dimensions 2 and
        3 are read as the feature map height and width.

    Returns:
      masks: a bool tensor of shape
        [batch_size, num_layers, max_boxes * num_anchors, 3] that is True for
        candidates that are valid and pass the anchor ratio threshold.
      indices: an int32 tensor of shape
        [batch_size, num_layers, max_boxes * num_anchors, 3, 3] whose last
        dimension holds (anchor index, grid y, grid x).
      anch: the per-layer anchors gathered with the anchor index of every
        candidate, stacked along axis 1.
    """
    label_shape = tf.shape(labels)
    batch_size, max_boxes = label_shape[0], label_shape[1]
    masks, indices, anch = [], [], []
    anchor_indices = tf.tile(
        tf.range(self._num_anchors, dtype=tf.float32)[None, None],
        [batch_size, max_boxes, 1],
    )
    # Append anchor indices to labels, so that every ground-truth is repeated
    # once per anchor: [batch_size, max_boxes, num_anchors, 6].
    labels = tf.tile(labels[:, :, None], [1, 1, self._num_anchors, 1])
    labels = tf.concat([labels, anchor_indices[..., None]], axis=-1)

    # Bias is used to determine the matching. 0.5 means matching anchors that
    # fall in the 0.5 differences in the feature map. For instance, a box
    # coordinates of (15.6, 35.4) will match the anchors at [15, 35], [16, 35],
    # and [15, 34].
    bias = 0.5  # bias
    off = (
        tf.constant(
            [
                [0, 0],
                [1, 0], [0, 1], [-1, 0], [0, -1],  # j, k, l, m
                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
            ],
            tf.float32,
        )
        * bias
    )  # offsets

    for i in range(self._num_layers):
      anchors = self._anchors[i]
      _, _, h, w, _ = predictions[str(i + 3)].get_shape().as_list()
      # Per-channel scale for [class, x, y, w, h, anchor_idx].
      gain = tf.constant([1, w, h, w, h, 1], dtype=tf.float32)

      t = labels * gain

      # Filter out targets that do not match the current anchors.
      wh_ratio = t[..., 3:5] / tf.cast(anchors[None, None], tf.float32)
      labels_mask = tf.less(
          tf.reduce_max(tf.maximum(wh_ratio, 1.0 / wh_ratio), axis=-1),
          self._anchor_threshold,
      )[..., None]
      # Compute valid mask for ground-truths.
      labels_mask = tf.logical_and(t[..., :1] != -1, labels_mask)

      labels_mask = tf.reshape(labels_mask, [batch_size, -1])
      t = tf.reshape(t, [batch_size, -1, 6])

      # Find the matching offsets for valid labels.
      gxy = t[..., 1:3]  # grid xy
      gxi = gain[1:3] - gxy  # inverse
      j, k = tf.split((gxy % 1.0 < bias) & (gxy >= 1.0), 2, axis=-1)
      l, m = tf.split((gxi % 1.0 < bias) & (gxi >= 1.0), 2, axis=-1)

      j, k, l, m = j[..., 0], k[..., 0], l[..., 0], m[..., 0]

      # Note that j and l, k and m are conjugate to each other, so at most one
      # of them will be True during running. Therefore, we can reduce memory
      # usage by gathering the selected index.
      x_map = tf.cast(tf.stack([j, l], axis=-1), tf.int8)
      y_map = tf.cast(tf.stack([k, m], axis=-1), tf.int8)

      # Add the indices offsets. The results index into the 5-entry offset
      # table: 1 (j) or 3 (l) for x, and 2 (k) or 4 (m) for y.
      x_indices = tf.argmax(x_map, axis=-1) * 2 + 1
      y_indices = tf.argmax(y_map, axis=-1) * 2 + 2
      three_targets_indices = tf.stack(
          [tf.zeros_like(x_indices), x_indices, y_indices], axis=-1
      )[..., None]

      # Gather the selected 3 targets from the 5-target map. The center cell
      # (entry 0) is always selected.
      five_targets_mask = tf.stack([tf.ones_like(j), j, k, l, m], axis=-1)
      three_targets_mask = tf.gather_nd(
          five_targets_mask, three_targets_indices, batch_dims=2
      )

      labels_mask = tf.tile(labels_mask[:, :, None], [1, 1, 5])
      t = tf.tile(t[:, :, None], [1, 1, 5, 1])

      labels_mask = tf.gather_nd(
          labels_mask, three_targets_indices, batch_dims=2
      )
      t = tf.gather_nd(t, three_targets_indices, batch_dims=2)

      offsets = tf.zeros_like(gxy)[:, :, None] + off[None, None]
      offsets = tf.gather_nd(offsets, three_targets_indices, batch_dims=2)

      gxy = t[..., 1:3]
      # Find the actual grid locations.
      gij = tf.cast(gxy - offsets * 2, tf.int32)
      gi, gj = tf.split(gij, 2, axis=-1)
      gi, gj = gi[..., 0], gj[..., 0]

      # Append the result.
      anchor_idx = tf.cast(t[..., 5], tf.int32)
      gain = tf.cast(gain, tf.int32)
      # gi is the column (x) index and gj is the row (y) index, so they are
      # bounded by the feature map width (gain[1]) and height (gain[2])
      # respectively. This matters for non-square feature maps.
      gi = tf.clip_by_value(gi, 0, gain[1] - 1)
      gj = tf.clip_by_value(gj, 0, gain[2] - 1)

      labels_mask = tf.logical_and(labels_mask, three_targets_mask)
      masks.append(labels_mask)
      indices.append(tf.stack([anchor_idx, gj, gi], axis=-1))
      anch.append(tf.gather(anchors, anchor_idx))

    # [batch_size, num_layers, num_anchors * max_boxes, num_targets]
    masks = tf.stack(masks, axis=1)
    indices = tf.stack(indices, axis=1)
    anch = tf.stack(anch, axis=1)
    return masks, indices, anch

  def report_stats(self):
    """Returns the matching statistics stored on this loss object."""
    stat_names = ('num_gts', 'num_matchings', 'num_duplicates')
    # Every statistic is stored on a private attribute named after its key.
    return {key: getattr(self, '_' + key) for key in stat_names}

  def report_separate_losses(self):
    """Returns the stored IoU and the box, objectness and class losses."""
    separate_losses = {}
    separate_losses['iou'] = self._iou
    separate_losses['box_loss'] = self._box_loss
    separate_losses['obj_loss'] = self._obj_loss
    separate_losses['cls_loss'] = self._cls_loss
    return separate_losses

  def get_config(self):
    """Returns the base loss config extended with this loss's settings."""
    # Collect the loss-specific entries first, then layer them over the base
    # config so that they take precedence on any key collision.
    loss_config = dict(
        alpha=self._alpha,
        gamma=self._gamma,
        box_weight=self._box_weight,
        obj_weight=self._obj_weight,
        cls_weight=self._cls_weight,
        iou_weight=self._iou_weight,
        iou_mix_ratio=self._iou_mix_ratio,
        pos_targets=self._pos_targets,
        neg_targets=self._neg_targets,
        num_classes=self._num_classes,
        num_layers=self._num_layers,
        num_anchors=self._num_anchors,
        auto_balance=self._auto_balance,
        balance=self._balance,
        strides=self._strides,
        anchors=self._anchors,
        input_size=self._input_size,
        anchor_threshold=self._anchor_threshold,
    )
    merged_config = dict(super().get_config().items())
    merged_config.update(loss_config)
    return merged_config