tensorflow/models

View on GitHub
official/vision/modeling/layers/roi_sampler.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of ROI sampler."""
from typing import Optional, Tuple, Union
# Import libraries
import tensorflow as tf, tf_keras

from official.vision.modeling.layers import box_sampler
from official.vision.ops import box_matcher
from official.vision.ops import iou_similarity
from official.vision.ops import target_gather

# Return type of `ROISampler.call`. It is a 4-tuple of tensors (RoIs, matched
# groundtruth boxes, matched groundtruth classes, matched groundtruth indices)
# when `gt_outer_boxes` is not provided, and a 5-tuple when it is, with the
# matched groundtruth outer boxes inserted as the third element.
ROISamplerReturnType = Union[
    Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
    Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]]


@tf_keras.utils.register_keras_serializable(package='Vision')
class ROISampler(tf_keras.layers.Layer):
  """Samples ROIs and assigns targets to the sampled ROIs.

  Each candidate box is matched against the groundtruth boxes by IoU, labelled
  as positive / negative / ignored / invalid, given the class and box of its
  matched groundtruth, and then (unless `skip_subsampling` is set) a fixed
  number of boxes is sampled per image.
  """

  def __init__(self,
               mix_gt_boxes: bool = True,
               num_sampled_rois: int = 512,
               foreground_fraction: float = 0.25,
               foreground_iou_threshold: float = 0.5,
               background_iou_high_threshold: float = 0.5,
               background_iou_low_threshold: float = 0,
               skip_subsampling: bool = False,
               **kwargs):
    """Initializes a ROI sampler.

    Args:
      mix_gt_boxes: A `bool` of whether to mix the groundtruth boxes with
        proposed ROIs.
      num_sampled_rois: An `int` of the number of sampled ROIs per image.
      foreground_fraction: A `float` in [0, 1], what percentage of proposed ROIs
        should be sampled from the foreground boxes.
      foreground_iou_threshold: A `float` that represents the IoU threshold for
        a box to be considered as positive (if >= `foreground_iou_threshold`).
      background_iou_high_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      background_iou_low_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      skip_subsampling: A `bool` that determines if we want to skip the
        sampling procedure that balances the fg/bg classes. Used for upper
        frcnn layers in cascade RCNN.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    # Only the constructor arguments of this class are stored; `get_config`
    # hands out a copy of this dict.
    self._config_dict = {
        'mix_gt_boxes': mix_gt_boxes,
        'num_sampled_rois': num_sampled_rois,
        'foreground_fraction': foreground_fraction,
        'foreground_iou_threshold': foreground_iou_threshold,
        'background_iou_high_threshold': background_iou_high_threshold,
        'background_iou_low_threshold': background_iou_low_threshold,
        'skip_subsampling': skip_subsampling,
    }

    self._sim_calc = iou_similarity.IouSimilarity()
    # `call` interprets the indicators as: -3 invalid, -1 negative
    # (background), -2 ignored and 1 positive (foreground).
    # NOTE(review): presumably the indicators are listed per IoU bucket in
    # increasing IoU order -- confirm against `box_matcher.BoxMatcher`.
    self._box_matcher = box_matcher.BoxMatcher(
        thresholds=[
            background_iou_low_threshold, background_iou_high_threshold,
            foreground_iou_threshold
        ],
        indicators=[-3, -1, -2, 1])
    self._target_gather = target_gather.TargetGather()

    self._sampler = box_sampler.BoxSampler(
        num_sampled_rois, foreground_fraction)
    super().__init__(**kwargs)

  def call(
      self,
      boxes: tf.Tensor,
      gt_boxes: tf.Tensor,
      gt_classes: tf.Tensor,
      gt_outer_boxes: Optional[tf.Tensor] = None) -> ROISamplerReturnType:
    """Assigns the proposals with groundtruth classes and performs subsampling.

    Given `boxes`, `gt_boxes`, and `gt_classes`, the function uses the
    following algorithm to generate the final `num_sampled_rois` RoIs.
      1. Optionally appends `gt_boxes` to `boxes` (if `mix_gt_boxes`).
      2. Calculates the IoU between each candidate box and each gt_boxes.
      3. Assigns each candidate box with a groundtruth class and box by
         matching on the IoU overlap. Boxes matched as negative or invalid are
         treated as background: their class and box targets are set to 0 and
         their matched groundtruth index is set to -1.
      4. Samples `num_sampled_rois` boxes from all candidate boxes, and
         returns box_targets, class_targets, and RoIs. If `skip_subsampling`
         is set, this step is skipped and the targets of every candidate box
         are returned instead.

    Args:
      boxes: A `tf.Tensor` of shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment. The last dimension is the
        box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
        format.
      gt_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES, 4].
        The coordinates of gt_boxes are in the pixel coordinates of the scaled
        image. This tensor might have padding of values -1 indicating the
        invalid box coordinates.
      gt_classes: A `tf.Tensor` with a shape of [batch_size, MAX_NUM_INSTANCES].
        This tensor might have paddings with values of -1 indicating the invalid
        classes.
      gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES,
        4]. The coordinates of gt_outer_boxes are in the pixel coordinates of
        the scaled image. This tensor might have padding of values -1 indicating
        the invalid box coordinates. Ignored if not provided.

    Returns:
      sampled_rois: A `tf.Tensor` of shape of [batch_size, K, 4], representing
        the coordinates of the sampled RoIs, where K is the number of the
        sampled RoIs, i.e. K = num_sampled_rois.
      sampled_gt_boxes: A `tf.Tensor` of shape of [batch_size, K, 4], storing
        the box coordinates of the matched groundtruth boxes of the samples
        RoIs.
      sampled_gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, K, 4],
        storing the box coordinates of the matched groundtruth outer boxes of
        the samples RoIs. This field is missing if gt_outer_boxes is None.
      sampled_gt_classes: A `tf.Tensor` of shape of [batch_size, K], storing the
        classes of the matched groundtruth boxes of the sampled RoIs.
      sampled_gt_indices: A `tf.Tensor` of shape of [batch_size, K], storing the
        indices of the sampled groundtruth boxes in the original `gt_boxes`
        tensor, i.e.,
        gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
    """
    gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
    if self._config_dict['mix_gt_boxes']:
      # Groundtruth boxes become candidates themselves, appended after the
      # proposals along the box axis.
      boxes = tf.concat([boxes, gt_boxes], axis=1)

    # A box whose largest coordinate is negative is treated as invalid
    # (padding is documented as -1).
    boxes_invalid_mask = tf.less(
        tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
    gt_invalid_mask = tf.less(
        tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
    similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
                                       gt_invalid_mask)
    matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
    positive_matches = tf.greater_equal(match_indicators, 0)
    negative_matches = tf.equal(match_indicators, -1)
    ignored_matches = tf.equal(match_indicators, -2)
    invalid_matches = tf.equal(match_indicators, -3)

    # Negative and invalid matches are both background. Ignored matches are
    # not part of this mask, so they keep the targets of their matched
    # groundtruth.
    background_mask = tf.expand_dims(
        tf.logical_or(negative_matches, invalid_matches), -1)
    # The same mask repeated over the 4 box coordinates, shared by the inner
    # and outer box gathers below.
    box_background_mask = tf.tile(background_mask, [1, 1, 4])

    gt_classes = tf.expand_dims(gt_classes, axis=-1)
    matched_gt_classes = self._target_gather(gt_classes, matched_gt_indices,
                                             background_mask)
    # Background boxes get class 0.
    matched_gt_classes = tf.where(background_mask,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)
    matched_gt_boxes = self._target_gather(gt_boxes, matched_gt_indices,
                                           box_background_mask)
    # Background boxes get an all-zero box target.
    matched_gt_boxes = tf.where(background_mask,
                                tf.zeros_like(matched_gt_boxes),
                                matched_gt_boxes)
    if gt_outer_boxes is not None:
      matched_gt_outer_boxes = self._target_gather(
          gt_outer_boxes, matched_gt_indices, box_background_mask)
      matched_gt_outer_boxes = tf.where(background_mask,
                                        tf.zeros_like(matched_gt_outer_boxes),
                                        matched_gt_outer_boxes)
    # Background boxes point at no groundtruth: index -1.
    matched_gt_indices = tf.where(
        tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
        matched_gt_indices)

    if self._config_dict['skip_subsampling']:
      # Return the targets of every candidate box without fg/bg balancing.
      matched_gt_classes = tf.squeeze(matched_gt_classes, axis=-1)
      if gt_outer_boxes is None:
        return (boxes, matched_gt_boxes, matched_gt_classes, matched_gt_indices)
      return (boxes, matched_gt_boxes, matched_gt_outer_boxes,
              matched_gt_classes, matched_gt_indices)

    sampled_indices = self._sampler(
        positive_matches, negative_matches, ignored_matches)

    sampled_rois = self._target_gather(boxes, sampled_indices)
    sampled_gt_boxes = self._target_gather(matched_gt_boxes, sampled_indices)
    sampled_gt_classes = tf.squeeze(self._target_gather(
        matched_gt_classes, sampled_indices), axis=-1)
    # The indices are given a trailing axis for the gather and squeezed back
    # afterwards, mirroring how the classes are handled.
    sampled_gt_indices = tf.squeeze(self._target_gather(
        tf.expand_dims(matched_gt_indices, -1), sampled_indices), axis=-1)
    if gt_outer_boxes is None:
      return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
              sampled_gt_indices)
    sampled_gt_outer_boxes = self._target_gather(matched_gt_outer_boxes,
                                                 sampled_indices)
    return (sampled_rois, sampled_gt_boxes, sampled_gt_outer_boxes,
            sampled_gt_classes, sampled_gt_indices)

  def get_config(self):
    """Returns the constructor arguments of this layer as a new `dict`.

    A shallow copy is returned so that callers mutating the result cannot
    alter the configuration stored on the layer.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Creates a `ROISampler` by passing `config` as keyword arguments."""
    return cls(**config)