# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Yolo Loss function."""
import abc
import collections
import functools

import tensorflow as tf, tf_keras

from official.projects.yolo.ops import box_ops
from official.projects.yolo.ops import loss_utils
from official.projects.yolo.ops import math_ops

class YoloLossBase(object, metaclass=abc.ABCMeta):
  """Parameters for the YOLO loss functions used at each detection generator.

  This base class implements the base functionality required to implement a Yolo
  Loss function. Subclasses supply the per path setup
  (`_build_per_path_attributes`), the loss itself (`_compute_loss`) and the
  replica aggregation policy (`cross_replica_aggregation`).
  """

  def __init__(self,
               classes,
               anchors,
               path_stride=1,
               ignore_thresh=0.7,
               truth_thresh=1.0,
               loss_type='ciou',
               iou_normalizer=1.0,
               cls_normalizer=1.0,
               object_normalizer=1.0,
               label_smoothing=0.0,
               objectness_smooth=True,
               update_on_repeat=False,
               box_type='original',
               scale_x_y=1.0,
               max_delta=10):
    """Loss Function Initialization.

    Args:
      classes: `int` for the number of classes
      anchors: `List[List[int]]` for the anchor boxes that are used in the model
        at this level. For anchor free prediction set the anchor list to be the
        same as the image resolution.
      path_stride: `int` for how much to scale this level to get the original
        input shape.
      ignore_thresh: `float` for the IOU value over which the loss is not
        propagated, and a detection is assumed to have been made.
      truth_thresh: `float` for the IOU value over which the loss is propagated
        despite a detection being made.
      loss_type: `str` for the type of iou loss to use. 'giou' and 'ciou' are
        handled explicitly by `box_loss`; any other value uses plain IOU.
      iou_normalizer: `float` for how much to scale the loss on the IOU or the
        boxes.
      cls_normalizer: `float` for how much to scale the loss on the classes.
      object_normalizer: `float` for how much to scale loss on the detection
        map.
      label_smoothing: `float` for how much to smooth the loss on the classes.
      objectness_smooth: `float` (coerced with `float()`) for how much to
        smooth the loss on the detection map.
      update_on_repeat: `bool` for whether to replace with the newest or the
        best value when an index is consumed by multiple objects.
      box_type: `str` for which box scaling type to use.
      scale_x_y: `float` value indicating how far each pixel can see outside of
        its containment of 1.0. a value of 1.2 indicates there is a 20% extended
        radius around each pixel that this specific pixel can predict values
        for a center at. the center can range from 0 - value/2 to 1 + value/2,
        this value is set in the yolo filter, and reused here.
      max_delta: gradient clipping to apply to the box loss.
    """
    self._loss_type = loss_type
    self._classes = classes
    self._num = tf.cast(len(anchors), dtype=tf.int32)
    self._truth_thresh = truth_thresh
    self._ignore_thresh = ignore_thresh
    self._anchors = anchors

    self._iou_normalizer = iou_normalizer
    self._cls_normalizer = cls_normalizer
    self._object_normalizer = object_normalizer
    self._scale_x_y = scale_x_y
    self._max_delta = max_delta

    self._label_smoothing = tf.cast(label_smoothing, tf.float32)
    self._objectness_smooth = float(objectness_smooth)
    self._update_on_repeat = update_on_repeat
    self._box_type = box_type
    self._path_stride = path_stride

    # Arguments shared by every box decode performed at this FPN level.
    box_kwargs = dict(
        stride=self._path_stride,
        scale_xy=self._scale_x_y,
        box_type=self._box_type,
        max_delta=self._max_delta)
    self._decode_boxes = functools.partial(
        loss_utils.get_predicted_box, **box_kwargs)

    # Disabled search by default; subclasses may replace it inside
    # `_build_per_path_attributes`, so it must be assigned before that call.
    self._search_pairs = lambda *args: (None, None, None, None)
    # Without this call subclasses never create their anchor generator.
    self._build_per_path_attributes()

  def box_loss(self, true_box, pred_box, darknet=False):
    """Call iou function and use it to compute the loss for the box maps.

    Args:
      true_box: ground truth boxes.
      pred_box: predicted boxes.
      darknet: `bool` forwarded to the ciou computation.

    Returns:
      iou: the plain IOU between the boxes.
      liou: the IOU variant selected by `loss_type`.
      loss_box: `1 - liou`.
    """
    if self._loss_type == 'giou':
      iou, liou = box_ops.compute_giou(true_box, pred_box)
    elif self._loss_type == 'ciou':
      iou, liou = box_ops.compute_ciou(true_box, pred_box, darknet=darknet)
    else:
      # Every other loss type falls back to the plain IOU.
      liou = iou = box_ops.compute_iou(true_box, pred_box)
    loss_box = 1 - liou
    return iou, liou, loss_box

  def _tiled_global_box_search(self,
                               pred_boxes,
                               pred_classes,
                               boxes,
                               classes,
                               true_conf,
                               smoothed,
                               scale=None):
    """Search of all groundtruths to associate groundtruths to predictions.

    Args:
      pred_boxes: predicted boxes to match against the ground truths.
      pred_classes: predicted class scores passed to the pair wise search.
      boxes: ground truth boxes in `yxyx` format (converted to `xcycwh` here).
      classes: ground truth classes passed to the pair wise search.
      true_conf: the ground truth confidence map.
      smoothed: `bool`. If False, cells without a ground truth whose best IOU
        is not below the ignore threshold are masked out of the loss. If True,
        cells whose best IOU exceeds the threshold get the smoothed IOU as
        their confidence target and nothing is masked.
      scale: optional multiplier applied to the ground truth boxes.

    Returns:
      true_conf: the (possibly updated) confidence target map.
      obj_mask: mask of the confidence cells that contribute to the loss.
    """

    boxes = box_ops.yxyx_to_xcycwh(boxes)

    if scale is not None:
      boxes = boxes * tf.cast(tf.stop_gradient(scale), boxes.dtype)

    # Search all predictions against ground truths to find matching boxes for
    # each pixel.
    _, _, iou_max, _ = self._search_pairs(pred_boxes, pred_classes, boxes,
                                          classes)

    # The default (disabled) search returns None: nothing to ignore.
    if iou_max is None:
      return true_conf, tf.ones_like(true_conf)

    # Find the exact indexes to ignore and keep.
    ignore_mask = tf.cast(iou_max < self._ignore_thresh, pred_boxes.dtype)
    iou_mask = iou_max > self._ignore_thresh

    if not smoothed:
      # Ignore all pixels where a box was not supposed to be predicted but a
      # high confidence box was predicted.
      obj_mask = true_conf + (1 - true_conf) * ignore_mask
    else:
      # Replace pixels in the true confidence map with the max iou predicted
      # within that cell.
      obj_mask = tf.ones_like(true_conf)
      iou_ = (1 - self._objectness_smooth) + self._objectness_smooth * iou_max
      iou_ = tf.where(iou_max > 0, iou_, tf.zeros_like(iou_))
      true_conf = tf.where(iou_mask, iou_, true_conf)

    # Stop gradient so while loop is not tracked.
    obj_mask = tf.stop_gradient(obj_mask)
    true_conf = tf.stop_gradient(true_conf)
    return true_conf, obj_mask

  def __call__(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Call function to compute the loss and a set of metrics per FPN level.

    Args:
      true_counts: `Tensor` of shape [batchsize, height, width, num_anchors]
        representing how many boxes are in a given pixel [j, i] in the output
        map.
      inds: `Tensor` of shape [batchsize, None, 3] indicating the location [j,
        i] that a given box is associated with in the FPN prediction map.
      y_true: `Tensor` of shape [batchsize, None, 8] indicating the actual box
        associated with each index in the inds tensor list.
      boxes: `Tensor` of shape [batchsize, None, 4] indicating the original
        ground truth boxes for each image as they came from the decoder used for
        bounding box search.
      classes: `Tensor` of shape [batchsize, None, 1] indicating the original
        ground truth classes for each image as they came from the decoder used
        for bounding box search.
      y_pred: `Tensor` of shape [batchsize, height, width, output_depth] holding
        the models output at a specific FPN level.

    Returns:
      loss: `float` for the actual loss.
      box_loss: `float` loss on the boxes used for metrics.
      conf_loss: `float` loss on the confidence used for metrics.
      class_loss: `float` loss on the classes used for metrics.
      mean_loss: `float` loss value reported for metrics (its scaling is
        decided by `_compute_loss`).
      avg_iou: `float` metric for the average iou between predictions and ground
        truths.
      avg_obj: `float` metric for the average sigmoid confidence of the model,
        masked to the ground truth grid cells.
    """
    (loss, box_loss, conf_loss, class_loss, mean_loss, iou, pred_conf, ind_mask,
     grid_mask) = self._compute_loss(true_counts, inds, y_true, boxes, classes,
                                     y_pred)

    # Metric compute using done here to save time and resources.
    sigmoid_conf = tf.stop_gradient(tf.sigmoid(pred_conf))
    iou = tf.stop_gradient(iou)
    avg_iou = loss_utils.average_iou(
        loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), iou))
    avg_obj = loss_utils.average_iou(
        tf.squeeze(sigmoid_conf, axis=-1) * grid_mask)
    return (loss, box_loss, conf_loss, class_loss, mean_loss,
            tf.stop_gradient(avg_iou), tf.stop_gradient(avg_obj))

  @abc.abstractmethod
  def _build_per_path_attributes(self):
    """Additional initialization required for each YOLO loss version."""

  @abc.abstractmethod
  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """The actual logic to apply to the raw model for optimization."""

  def post_path_aggregation(self, loss, box_loss, conf_loss, class_loss,
                            ground_truths, predictions):  # pylint:disable=unused-argument
    """This method allows for post processing of a loss value.

    After the loss has been aggregated across all the FPN levels some post
    processing may need to occur to properly scale the loss. The default
    behavior is to pass the loss through with no alterations. Passing the
    individual losses for each mask will allow for aggregation of loss across
    paths for some losses.

    Args:
      loss: `tf.float` scalar for the actual loss.
      box_loss: `tf.float` for the loss on the boxes only.
      conf_loss: `tf.float` for the loss on the confidences only.
      class_loss: `tf.float` for the loss on the classes only.
      ground_truths: `Dict` holding all the ground truth tensors.
      predictions: `Dict` holding all the predicted values.

    Returns:
      loss: `tf.float` scalar for the scaled loss.
      scale: `tf.float` how much the loss was scaled by.
    """
    del box_loss
    del conf_loss
    del class_loss
    del ground_truths
    del predictions
    return loss, tf.ones_like(loss)

  @abc.abstractmethod
  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """This controls how the loss should be aggregated across replicas."""

@tf.custom_gradient
def grad_sigmoid(values):
  """This function scales the gradient as if a sigmoid was applied.

  This is used in the Darknet Loss when the chosen box type is the scaled
  coordinate type. This function is used to match the propagated gradient to
  match that of the Darknet Yolov4 model. This is an Identity operation that
  allows us to add some extra steps to the back propagation.

  The `tf.custom_gradient` decorator is required: it turns the
  `(values, delta)` pair returned below into an identity forward pass with
  `delta` as the gradient. Without it callers would receive a tuple.

  Args:
    values: A tensor of any shape.

  Returns:
    values: The unaltered input tensor.
    delta: A custom gradient function that adds the sigmoid step to the
      backpropagation.
  """

  def delta(dy):
    # Derivative of the sigmoid, evaluated at the forward input.
    t = tf.math.sigmoid(values)
    return dy * t * (1 - t)

  return values, delta

class DarknetLoss(YoloLossBase):
  """This class implements the full logic for the standard Yolo models."""

  def _build_per_path_attributes(self):
    """Parameterization of pair wise search and grid generators.

    Objects created here are used for box decoding and dynamic ground truth
    association.
    """
    self._anchor_generator = loss_utils.GridGenerator(
        anchors=self._anchors,
        scale_anchors=self._path_stride)

    if self._ignore_thresh > 0.0:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type='iou', any_match=True, min_conf=0.25)

  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss logic used for Yolov3, Yolov4, and Yolo-Tiny.

    Returns:
      A tuple `(loss, box_loss, conf_loss, class_loss, mean_loss, iou,
      pred_conf, ind_mask, grid_mask)`; `mean_loss` is the same value as
      `loss` for this loss type.
    """
    if self._box_type == 'scaled':
      # Darknet Model Propagates a sigmoid once in back prop so we replicate
      # that behaviour
      y_pred = grad_sigmoid(y_pred)

    # Generate and store constants and format output.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Cast all input components to float32 and stop gradient to save memory.
    boxes = tf.stop_gradient(tf.cast(boxes, tf.float32))
    classes = tf.stop_gradient(tf.cast(classes, tf.float32))
    y_true = tf.stop_gradient(tf.cast(y_true, tf.float32))
    true_counts = tf.stop_gradient(tf.cast(true_counts, tf.float32))
    true_conf = tf.stop_gradient(tf.clip_by_value(true_counts, 0.0, 1.0))
    grid_points = tf.stop_gradient(grid_points)
    anchor_grid = tf.stop_gradient(anchor_grid)

    # Split all the ground truths to use as separate items in loss computation.
    (true_box, ind_mask, true_class) = tf.split(y_true, [4, 1, 1], axis=-1)
    true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    grid_mask = true_conf

    # Splits all predictions.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes to be used for loss compute.
    _, _, pred_box = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=True)

    # If the ignore threshold is enabled, search all boxes and ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list.
    if self._ignore_thresh != 0.0:
      (true_conf, obj_mask) = self._tiled_global_box_search(
          pred_box,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=self._objectness_smooth > 0)

    # Build the one hot class list that are used for class loss.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_class = tf.stop_gradient(loss_utils.apply_mask(ind_mask, true_class))

    # Reorganize the one hot class list as a grid.
    true_class_grid = loss_utils.build_grid(
        inds, true_class, pred_class, ind_mask, update=False)
    true_class_grid = tf.stop_gradient(true_class_grid)

    # Use the class mask to find the number of objects located in
    # each predicted grid cell/pixel.
    counts = true_class_grid
    counts = tf.reduce_sum(counts, axis=-1, keepdims=True)
    reps = tf.gather_nd(counts, inds, batch_dims=1)
    reps = tf.squeeze(reps, axis=-1)
    # Avoid dividing by zero for padded (unused) indexes.
    reps = tf.stop_gradient(tf.where(reps == 0.0, tf.ones_like(reps), reps))

    # Compute the loss for only the cells in which the boxes are located.
    pred_box = loss_utils.apply_mask(ind_mask,
                                     tf.gather_nd(pred_box, inds, batch_dims=1))
    iou, _, box_loss = self.box_loss(true_box, pred_box, darknet=True)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    box_loss = math_ops.divide_no_nan(box_loss, reps)
    box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype)

    if self._update_on_repeat:
      # Converts list of ground truths into a grid where repeated values
      # are replaced by the most recent value. So some class identities may
      # get lost but the loss computation will be more stable. Results are
      # more consistent.

      # Compute the sigmoid binary cross entropy for the class maps.
      class_loss = tf.reduce_mean(
          loss_utils.sigmoid_bce(
              tf.expand_dims(true_class_grid, axis=-1),
              tf.expand_dims(pred_class, axis=-1), self._label_smoothing),
          axis=-1)

      # Apply normalization to the class losses.
      if self._cls_normalizer < 1.0:
        # Build a mask based on the true class locations.
        cls_norm_mask = true_class_grid
        # Apply the classes weight to class indexes where one_hot is one.
        class_loss *= ((1 - cls_norm_mask) +
                       cls_norm_mask * self._cls_normalizer)

      # Mask to the class loss and compute the sum over all the objects.
      class_loss = tf.reduce_sum(class_loss, axis=-1)
      class_loss = loss_utils.apply_mask(grid_mask, class_loss)
      class_loss = math_ops.rm_nan_inf(class_loss, val=0.0)
      class_loss = tf.cast(
          tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype)
    else:
      # Computes the loss while keeping the structure as a list in
      # order to ensure all objects are considered. In some cases can
      # make training more unstable but may also return higher APs.
      pred_class = loss_utils.apply_mask(
          ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))
      class_loss = tf_keras.losses.binary_crossentropy(
          tf.expand_dims(true_class, axis=-1),
          tf.expand_dims(pred_class, axis=-1),
          label_smoothing=self._label_smoothing,
          from_logits=True)
      class_loss = loss_utils.apply_mask(ind_mask, class_loss)
      class_loss = math_ops.divide_no_nan(class_loss,
                                          tf.expand_dims(reps, axis=-1))
      class_loss = tf.cast(
          tf.reduce_sum(class_loss, axis=(1, 2)), dtype=y_pred.dtype)
      class_loss *= self._cls_normalizer

    # Compute the sigmoid binary cross entropy for the confidence maps.
    bce = tf.reduce_mean(
        loss_utils.sigmoid_bce(
            tf.expand_dims(true_conf, axis=-1), pred_conf, 0.0),
        axis=-1)

    # Mask the confidence loss and take the sum across all the grid cells.
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
    conf_loss = tf.cast(tf.reduce_sum(bce, axis=(1, 2, 3)), dtype=y_pred.dtype)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    conf_loss *= self._object_normalizer

    # Add all the losses together then take the mean over the batches.
    loss = box_loss + class_loss + conf_loss
    loss = tf.reduce_mean(loss)

    # Reduce the mean of the losses to use as a metric.
    box_loss = tf.reduce_mean(box_loss)
    conf_loss = tf.reduce_mean(conf_loss)
    class_loss = tf.reduce_mean(class_loss)

    return (loss, box_loss, conf_loss, class_loss, loss, iou, pred_conf,
            ind_mask, grid_mask)

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """Averages the loss over replicas by dividing by the replica count."""
    return loss / num_replicas_in_sync

class ScaledLoss(YoloLossBase):
  """This class implements the full logic for the scaled Yolo models."""

  def _build_per_path_attributes(self):
    """Parameterization of pair wise search and grid generators.

    Objects created here are used for box decoding and dynamic ground truth
    association.
    """
    self._anchor_generator = loss_utils.GridGenerator(
        anchors=self._anchors,
        scale_anchors=self._path_stride)

    if self._ignore_thresh > 0.0:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type=self._loss_type, any_match=False, min_conf=0.25)

    # The class normalizer is expressed relative to an 80-class dataset.
    self._cls_normalizer = self._cls_normalizer * self._classes / 80

  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss logic for Yolov4-csp, Yolov4-Large, and Yolov5.

    Returns:
      A tuple `(loss, box_loss, conf_loss, class_loss, mean_loss, iou,
      pred_conf, ind_mask, grid_mask)` where `loss` is `mean_loss` multiplied
      by the batch size.
    """
    # Generate shape constants.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)

    # Cast all input components to float32.
    y_true = tf.cast(y_true, tf.float32)
    true_counts = tf.cast(true_counts, tf.float32)
    true_conf = tf.clip_by_value(true_counts, 0.0, 1.0)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Split the y_true list.
    (true_box, ind_mask, true_class) = tf.split(y_true, [4, 1, 1], axis=-1)
    grid_mask = true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    # Every loss below is float32, so the object count must be float32 too;
    # casting to the raw `y_pred` dtype breaks the divisions under mixed
    # precision.
    num_objs = tf.cast(tf.reduce_sum(ind_mask), dtype=tf.float32)

    # Split up the predictions.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes for loss compute.
    scale, pred_box, pbg = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=False)

    # If the ignore threshold is enabled, search all boxes and ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list.
    if self._ignore_thresh != 0.0:
      (_, obj_mask) = self._tiled_global_box_search(
          pbg,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=False)

    # Scale and shift and select the ground truth boxes
    # and predictions to the prediction domain.
    if self._box_type == 'anchor_free':
      true_box = loss_utils.apply_mask(ind_mask,
                                       (scale * self._path_stride * true_box))
    else:
      offset = tf.cast(
          tf.gather_nd(grid_points, inds, batch_dims=1), true_box.dtype)
      offset = tf.concat([offset, tf.zeros_like(offset)], axis=-1)
      true_box = loss_utils.apply_mask(ind_mask, (scale * true_box) - offset)
    pred_box = loss_utils.apply_mask(ind_mask,
                                     tf.gather_nd(pred_box, inds, batch_dims=1))

    # Select the correct/used prediction classes.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_class = loss_utils.apply_mask(ind_mask, true_class)
    pred_class = loss_utils.apply_mask(
        ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))

    # Compute the box loss.
    _, iou, box_loss = self.box_loss(true_box, pred_box, darknet=False)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    box_loss = math_ops.divide_no_nan(tf.reduce_sum(box_loss), num_objs)

    # Use the box IOU to build the map for confidence loss computation.
    iou = tf.maximum(tf.stop_gradient(iou), 0.0)
    smoothed_iou = ((
        (1 - self._objectness_smooth) * tf.cast(ind_mask, iou.dtype)) +
                    self._objectness_smooth * tf.expand_dims(iou, axis=-1))
    smoothed_iou = loss_utils.apply_mask(ind_mask, smoothed_iou)
    true_conf = loss_utils.build_grid(
        inds, smoothed_iou, pred_conf, ind_mask, update=self._update_on_repeat)
    true_conf = tf.squeeze(true_conf, axis=-1)

    # Compute the cross entropy loss for the confidence map.
    bce = tf_keras.losses.binary_crossentropy(
        tf.expand_dims(true_conf, axis=-1), pred_conf, from_logits=True)
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
      # Average only over the cells that were not ignored; guard against an
      # all-ignored map.
      conf_loss = math_ops.divide_no_nan(
          tf.reduce_sum(bce), tf.reduce_sum(obj_mask))
    else:
      conf_loss = tf.reduce_mean(bce)

    # Compute the cross entropy loss for the class maps.
    class_loss = tf_keras.losses.binary_crossentropy(
        true_class,
        pred_class,
        label_smoothing=self._label_smoothing,
        from_logits=True)
    class_loss = loss_utils.apply_mask(
        tf.squeeze(ind_mask, axis=-1), class_loss)
    class_loss = math_ops.divide_no_nan(tf.reduce_sum(class_loss), num_objs)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    class_loss *= self._cls_normalizer
    conf_loss *= self._object_normalizer

    # Add all the losses together then take the sum over the batches.
    mean_loss = box_loss + class_loss + conf_loss
    loss = mean_loss * tf.cast(batch_size, mean_loss.dtype)

    return (loss, box_loss, conf_loss, class_loss, mean_loss, iou, pred_conf,
            ind_mask, grid_mask)

  def post_path_aggregation(self, loss, box_loss, conf_loss, class_loss,
                            ground_truths, predictions):
    """This method allows for post processing of a loss value.

    By default the model will have about 3 FPN levels {3, 4, 5}, on
    larger models that have more like 4 or 5 FPN levels the loss needs to
    be scaled such that the total update is scaled to the same effective
    magnitude as the model with 3 FPN levels. This helps to prevent gradient
    explosions.

    Args:
      loss: `tf.float` scalar for the actual loss.
      box_loss: `tf.float` for the loss on the boxes only.
      conf_loss: `tf.float` for the loss on the confidences only.
      class_loss: `tf.float` for the loss on the classes only.
      ground_truths: `Dict` holding all the ground truth tensors.
      predictions: `Dict` holding all the predicted values.

    Returns:
      loss: `tf.float` scalar for the scaled loss.
      scale: `tf.float` the inverse of the factor applied to the loss.
    """
    scale = tf.stop_gradient(3 / len(predictions))
    return loss * scale, 1 / scale

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """Returns the loss unchanged; no per replica rescaling is applied."""
    return loss

class YoloLoss:
  """This class implements the aggregated loss across YOLO model FPN levels."""

  def __init__(self,
               keys,
               classes,
               anchors,
               path_strides=None,
               truth_thresholds=None,
               ignore_thresholds=None,
               loss_types=None,
               iou_normalizers=None,
               cls_normalizers=None,
               object_normalizers=None,
               objectness_smooths=None,
               box_types=None,
               scale_xys=None,
               max_deltas=None,
               label_smoothing=0.0,
               use_scaled_loss=False,
               update_on_repeat=True):
    """Loss Function Initialization.

    Args:
      keys: `List[str]` indicating the name of the FPN paths that need to be
        optimized.
      classes: `int` for the number of classes
      anchors: anchor boxes for each FPN path, indexed by the entries of
        `keys` (each entry a `List[List[int]]`). For anchor free prediction set
        the anchor list to be the same as the image resolution.
      path_strides: `Dict[int]` for how much to scale this level to get the
        original input shape for each FPN path.
      truth_thresholds: `Dict[float]` for the IOU value over which the loss is
        propagated despite a detection being made for each FPN path.
      ignore_thresholds: `Dict[float]` for the IOU value over which the loss is
        not propagated, and a detection is assumed to have been made for each
        FPN path.
      loss_types: `Dict[str]` for the type of iou loss to use for each FPN
        path.
      iou_normalizers: `Dict[float]` for how much to scale the loss on the IOU
        or the boxes for each FPN path.
      cls_normalizers: `Dict[float]` for how much to scale the loss on the
        classes for each FPN path.
      object_normalizers: `Dict[float]` for how much to scale loss on the
        detection map for each FPN path.
      objectness_smooths: `Dict[float]` for how much to smooth the loss on the
        detection map for each FPN path.
      box_types: `Dict[str]` for which scaling type to use for each FPN path.
      scale_xys: `Dict[float]` values indicating how far each pixel can see
        outside of its containment of 1.0. a value of 1.2 indicates there is a
        20% extended radius around each pixel that this specific pixel can
        predict values for a center at. the center can range from 0 - value/2 to
        1 + value/2, this value is set in the yolo filter, and reused here.
        there should be one value for scale_xy for each level from min_level to
        max_level. One for each FPN path.
      max_deltas: `Dict[float]` for gradient clipping to apply to the box loss
        for each FPN path.
      label_smoothing: `float` shared by all FPN paths, or a `dict` with one
        value per FPN path, for how much to smooth the loss on the classes.
      use_scaled_loss: `bool` for whether to use the scaled loss or the
        traditional loss.
      update_on_repeat: `bool` for whether to replace with the newest or the
        best value when an index is consumed by multiple objects.

    Any per path argument left as `None` falls back to the default of the
    selected per path loss class.
    """
    losses = {'darknet': DarknetLoss, 'scaled': ScaledLoss}
    loss_type = 'scaled' if use_scaled_loss else 'darknet'
    loss_cls = losses[loss_type]

    # Per FPN path settings keyed by the per path loss argument they feed.
    per_path_settings = {
        'path_stride': path_strides,
        'truth_thresh': truth_thresholds,
        'ignore_thresh': ignore_thresholds,
        'loss_type': loss_types,
        'iou_normalizer': iou_normalizers,
        'cls_normalizer': cls_normalizers,
        'object_normalizer': object_normalizers,
        'objectness_smooth': objectness_smooths,
        'box_type': box_types,
        'scale_x_y': scale_xys,
        'max_delta': max_deltas,
    }

    self._loss_dict = {}
    for key in keys:
      # Skip settings that were not provided so the class defaults apply
      # instead of failing on `None[key]`.
      path_kwargs = {
          name: values[key]
          for name, values in per_path_settings.items()
          if values is not None
      }
      path_label_smoothing = (
          label_smoothing[key]
          if isinstance(label_smoothing, dict) else label_smoothing)
      self._loss_dict[key] = loss_cls(
          classes=classes,
          anchors=anchors[key],
          label_smoothing=path_label_smoothing,
          update_on_repeat=update_on_repeat,
          **path_kwargs)

  def __call__(self, ground_truth, predictions):
    """Computes the aggregated loss and metrics over every predicted FPN path.

    Args:
      ground_truth: `Dict` of ground truth tensors. `true_conf`, `inds` and
        `upds` are indexed by FPN path key; `bbox` and `classes` are shared by
        all paths.
      predictions: `Dict` mapping each FPN path key to that path's output.

    Returns:
      loss_val: the loss to optimize, summed over paths after the per path
        and cross replica aggregation of each path's loss.
      metric_loss: sum over paths of `mean_loss / scale` (no gradient), where
        `scale` is the second value returned by `post_path_aggregation`.
      metric_dict: `Dict` of per path and `net` metrics (no gradient).
    """
    metric_dict = collections.defaultdict(dict)
    metric_dict['net']['box'] = 0
    metric_dict['net']['class'] = 0
    metric_dict['net']['conf'] = 0

    loss_val, metric_loss = 0, 0
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync

    for key, prediction in predictions.items():
      # Positional order matches
      # (true_counts, inds, y_true, boxes, classes, y_pred).
      (loss, loss_box, loss_conf, loss_class, mean_loss, avg_iou,
       avg_obj) = self._loss_dict[key](ground_truth['true_conf'][key],
                                       ground_truth['inds'][key],
                                       ground_truth['upds'][key],
                                       ground_truth['bbox'],
                                       ground_truth['classes'],
                                       prediction)

      # after computing the loss, scale loss as needed for aggregation
      # across FPN levels
      loss, scale = self._loss_dict[key].post_path_aggregation(
          loss, loss_box, loss_conf, loss_class, ground_truth, predictions)

      # after completing the scaling of the loss on each replica, handle
      # scaling the loss for merging the loss across replicas
      loss = self._loss_dict[key].cross_replica_aggregation(
          loss, num_replicas_in_sync)
      loss_val += loss

      # detach all the below gradients: none of them should make a
      # contribution to the gradient from this point forwards
      metric_loss += tf.stop_gradient(mean_loss / scale)
      metric_dict[key]['loss'] = tf.stop_gradient(mean_loss / scale)
      metric_dict[key]['avg_iou'] = tf.stop_gradient(avg_iou)
      metric_dict[key]['avg_obj'] = tf.stop_gradient(avg_obj)

      metric_dict['net']['box'] += tf.stop_gradient(loss_box / scale)
      metric_dict['net']['class'] += tf.stop_gradient(loss_class / scale)
      metric_dict['net']['conf'] += tf.stop_gradient(loss_conf / scale)

    return loss_val, metric_loss, metric_dict