# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Post process the results output from the ML model.

Given the output from the 2 mask RCNN models. The 3 main tasks are done by three
functions mentioned below -
1. reframing_masks : Reframe the masks according to the size of an image and
their respective positions within an image.
2. find_similar_masks : Given masks from the output of 2 models. Find masks 
which belong to the same object and combine all of their attributes like 
confidence score, bounding boxes, label names, etc. The masks are mapped to each
other if their score is above a threshold limit. Two outputs are combined into 
a single output.
3. filter_bounding_boxes : The combined output may have nested bounding boxes of 
the same object. The parent bounding boxes are removed in this step so that any 
object should not have more than a single bounding box. 
"""
import copy
from typing import Any, Optional, TypedDict, Dict, Tuple, List
import numpy as np
import tensorflow as tf
import tf_keras  # Imported for side effects; not referenced directly here.


class DetectionResult(TypedDict):
  num_detections: np.ndarray
  detection_classes: np.ndarray
  detection_scores: np.ndarray
  detection_boxes: np.ndarray
  detection_classes_names: np.ndarray
  detection_masks_reframed: np.ndarray


class ItemDict(TypedDict):
  id: int
  name: str
  supercategory: str


def reframe_image_corners_relative_to_boxes(boxes: tf.Tensor) -> tf.Tensor:
  """Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.

  The local coordinate frame of each box is assumed to be relative to
  its own for corners.

  Args:
    boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)
      coordinates in relative coordinate space of each bounding box.

  Returns:
    reframed_boxes: Reframed boxes with the same shape as the input.
  """
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1)

  height = tf.maximum(ymax - ymin, 1e-4)
  width = tf.maximum(xmax - xmin, 1e-4)

  ymin_out = (0 - ymin) / height
  xmin_out = (0 - xmin) / width
  ymax_out = (1 - ymin) / height
  xmax_out = (1 - xmin) / width
  return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)
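
# Worked example (an illustrative sketch, not part of the original module):
# for a box covering the central half of the image, the image corners land
# outside the box's local [0, 1] frame:
#
#   reframe_image_corners_relative_to_boxes(
#       tf.constant([[0.25, 0.25, 0.75, 0.75]]))
#   # -> [[-0.5, -0.5, 1.5, 1.5]]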


def reframe_box_masks_to_image_masks(
    box_masks: tf.Tensor,
    boxes: tf.Tensor,
    image_height: int,
    image_width: int,
    resize_method: str = 'bilinear',
) -> tf.Tensor:
  """Transforms the box masks back to full image masks.

  Embeds masks in bounding boxes of larger masks whose shapes correspond to
  image shape.

  Args:
    box_masks: A tensor of size [num_masks, mask_height, mask_width].
    boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
      corners. Row i contains [ymin, xmin, ymax, xmax] of the box corresponding
      to mask i. Note that the box corners are in normalized coordinates.
    image_height: Image height. The output mask will have the same height as the
      image height.
    image_width: Image width. The output mask will have the same width as the
      image width.
    resize_method: The resize method, either 'bilinear' or 'nearest'. Note that
      'bilinear' is only respected if box_masks is a float.

  Returns:
    A tensor of size [num_masks, image_height, image_width] with the same dtype
    as `box_masks`.
  """
  resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method

  def reframe_box_masks_to_image_masks_default():
    """The default function when there are more than 0 box masks."""

    num_boxes = tf.shape(box_masks)[0]
    box_masks_expanded = tf.expand_dims(box_masks, axis=3)

    resized_crops = tf.image.crop_and_resize(
        image=box_masks_expanded,
        boxes=reframe_image_corners_relative_to_boxes(boxes),
        box_indices=tf.range(num_boxes),
        crop_size=[image_height, image_width],
        method=resize_method,
        extrapolation_value=0,
    )
    return tf.cast(resized_crops, box_masks.dtype)

  image_masks = tf.cond(
      tf.shape(box_masks)[0] > 0,
      reframe_box_masks_to_image_masks_default,
      lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype),
  )
  return tf.squeeze(image_masks, axis=3)
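
# Illustrative usage (a minimal sketch with synthetic data): paste a 2x2 box
# mask into the top-left quarter of a 4x4 image canvas.
#
#   masks = tf.ones([1, 2, 2], dtype=tf.float32)
#   boxes = tf.constant([[0.0, 0.0, 0.5, 0.5]])
#   image_masks = reframe_box_masks_to_image_masks(masks, boxes, 4, 4)
#   # image_masks has shape [1, 4, 4]; the top-left 2x2 block is 1.0 and
#   # values fade to 0 outside the box under bilinear resampling.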


def reframing_masks(
    results: Dict[str, np.ndarray], height: int, width: int
) -> Dict[str, np.ndarray]:
  """Processes the output from Mask RCNN model to create a full size mask.

  Args:
    results: list of dictionaries containing the output of Mask RCNN.
    height: The height of the image.
    width: The width of the image.

  Returns:
    A processed list of dictionaries.
  """
  result = copy.deepcopy(results)
  # Normalize box coordinates from absolute pixel values to the [0, 1] range.
  result['detection_boxes'][0][:, [0, 2]] /= height
  result['detection_boxes'][0][:, [1, 3]] /= width

  detection_masks = tf.convert_to_tensor(result['detection_masks'][0])
  detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])
  detection_masks_reframed = reframe_box_masks_to_image_masks(
      detection_masks, detection_boxes, height, width
  )
  # Binarize the soft masks at a fixed 0.8 threshold.
  detection_masks_reframed = tf.cast(detection_masks_reframed > 0.8, tf.uint8)
  result['detection_masks_reframed'] = detection_masks_reframed.numpy()
  return result
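
# Illustrative usage (a sketch; `detection_fn` and the image size are
# hypothetical, and boxes are assumed to arrive in absolute pixel
# coordinates):
#
#   results = {k: v.numpy() for k, v in detection_fn(image_tensor).items()}
#   results = reframing_masks(results, height=1080, width=1920)
#   # results['detection_masks_reframed'] holds binary full-image masks of
#   # shape [num_detections, 1080, 1920].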


def find_id_by_name(
    dictionary: Dict[int, ItemDict], name: str
) -> Optional[int]:
  """Finds the id of a dictionary given its value.

  Args:
    dictionary: The dictionary containing the data.
    name: The value to find.

  Returns:
    The id, or None if its not found.
  """

  # Iterate over the dictionary, and check if the name of the user
  # matches the name that was passed in.
  for value in dictionary.values():
    if value['name'] == name:
      # If the name matches, return the id of the user.
      return value['id']

  return None
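
# Illustrative example (a hypothetical combined category index, mirroring the
# format used by find_similar_masks below):
#
#   category_index = {
#       1: {'id': 1, 'name': 'Fiber_Na_Bag', 'supercategory': 'objects'},
#   }
#   find_id_by_name(category_index, 'Fiber_Na_Bag')  # -> 1
#   find_id_by_name(category_index, 'Unknown')       # -> None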


def combine_bounding_boxes(
    box1: List[float],
    box2: List[float],
) -> List[float]:
  """Combines two bounding boxes.

  Args:
    box1: A list of four numbers representing the coordinates of the first
      bounding box.
    box2: A list of four numbers representing the coordinates of the second
      bounding box.

  Returns:
    A list of four numbers representing the coordinates of the combined
    bounding box.
  """

  ymin = min(box1[0], box2[0])
  xmin = min(box1[1], box2[1])
  ymax = max(box1[2], box2[2])
  xmax = max(box1[3], box2[3])

  return [ymin, xmin, ymax, xmax]
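
# Illustrative example: the combined box is the smallest box enclosing both
# inputs.
#
#   combine_bounding_boxes([0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6])
#   # -> [0.1, 0.1, 0.6, 0.6]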


def calculate_combined_scores_boxes_classes(
    i: int,
    j: int,
    results_1: DetectionResult,
    results_2: DetectionResult,
    category_indices: List[List[Any]],
    category_index_combined: Dict[int, ItemDict],
) -> Tuple[Any, List[float], Any, Optional[int]]:
  """Calculate combined scores, boxes, and classes for matched masks.

  Args:
      i: Index of the mask from the results_1.
      j: Index of the mask from the results_2.
      results_1: A dictionary which contains the results from the first model.
      results_2: A dictionary which contains the results from the second model.
      category_indices: A list of two sub-lists containing the labels of the
        first and second ML models.
      category_index_combined: Combined category index.

  Returns:
      tuple: A tuple containing:
          - avg_score: Average score of the matched masks.
          - combined_box: Combined bounding box for the matched masks.
          - combined_label: Combined label of the matched masks.
          - result_id: ID associated with the combined label.
  """
  score_1 = results_1['detection_scores'][0][i]
  score_2 = results_2['detection_scores'][0][j]
  avg_score = (score_1 + score_2) / 2

  box_1 = results_1['detection_boxes'][0][i]
  box_2 = results_2['detection_boxes'][0][j]
  combined_box = combine_bounding_boxes(box_1, box_2)

  class_1 = results_1['detection_classes'][0][i]
  class_2 = results_2['detection_classes'][0][j]
  combined_label = (
      category_indices[0][class_1] + '_' + category_indices[1][class_2]
  )
  result_id = find_id_by_name(category_index_combined, combined_label)

  return avg_score, combined_box, combined_label, result_id
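
# Illustrative sketch (synthetic, minimal inputs; real DetectionResult dicts
# carry more keys, and the label lists below are hypothetical):
#
#   r1 = {'detection_scores': [np.array([0.9])],
#         'detection_classes': [np.array([1])],
#         'detection_boxes': [np.array([[0.1, 0.1, 0.5, 0.5]])]}
#   r2 = {'detection_scores': [np.array([0.7])],
#         'detection_classes': [np.array([1])],
#         'detection_boxes': [np.array([[0.2, 0.2, 0.6, 0.6]])]}
#   labels = [['Na', 'Fiber'], ['Na', 'Bag']]
#   index = {1: {'id': 1, 'name': 'Fiber_Bag', 'supercategory': 'objects'}}
#   calculate_combined_scores_boxes_classes(0, 0, r1, r2, labels, index)
#   # -> (~0.8, [0.1, 0.1, 0.6, 0.6], 'Fiber_Bag', 1)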


def calculate_single_result(
    index: int,
    result: DetectionResult,
    category_indices: List[Any],
    flag: str,
) -> Tuple[float, Tuple[float, float, float, float], str]:
  """Calculate scores, boxes, and classes for non-matched masks.

  Args:
      index: Index of the mask in the results.
      result: A dictionary containing detection results (either from results_1
        or results_2).
      category_indices: The list of labels for the corresponding model.
      flag: Identifies which model missed the object: 'after' appends a
        '_Na' suffix (the second model missed it), 'before' prepends 'Na_'
        (the first model missed it).

  Returns:
    score: Score of the mask.
    box: Bounding box of the mask.
    combined_label: Label of the mask with the added suffix.
  """
  combined_label = 'Default Value'
  score = result['detection_scores'][0][index]
  box = result['detection_boxes'][0][index]
  class_idx = result['detection_classes'][0][index]
  if flag == 'after':
    combined_label = category_indices[class_idx] + '_Na'
  elif flag == 'before':
    combined_label = 'Na_' + category_indices[class_idx]
  return score, box, combined_label
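
# Illustrative sketch (synthetic input; the label list is hypothetical):
#
#   r = {'detection_scores': [np.array([0.9])],
#        'detection_classes': [np.array([1])],
#        'detection_boxes': [np.array([[0.1, 0.1, 0.5, 0.5]])]}
#   calculate_single_result(0, r, ['Na', 'Fiber'], 'after')
#   # -> (0.9, array([0.1, 0.1, 0.5, 0.5]), 'Fiber_Na')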


def calculate_iou(
    mask1: np.ndarray, mask2: np.ndarray
) -> Tuple[float, np.ndarray]:
  """Calculates the intersection over union (IoU) score for two masks.

  Args:
    mask1: The first mask.
    mask2: The second mask.

  Returns:
    The IoU score and the union of the two masks.
  """

  # Check if the masks have the same dimensions.
  if mask1.shape != mask2.shape:
    raise ValueError('The masks must have the same dimensions.')

  intersection = np.logical_and(mask1, mask2)
  union = np.logical_or(mask1, mask2)
  union_area = np.sum(union)
  # Guard against two empty masks, whose union area would be zero.
  if union_area == 0:
    return 0.0, union
  iou_score = np.sum(intersection) / union_area
  return iou_score, union
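
# Illustrative example (tiny binary masks):
#
#   a = np.array([[1, 1], [0, 0]], dtype=np.uint8)
#   b = np.array([[1, 0], [1, 0]], dtype=np.uint8)
#   iou, union = calculate_iou(a, b)
#   # intersection = 1 pixel, union = 3 pixels, so iou == 1/3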


def find_similar_masks(
    results_1: DetectionResult,
    results_2: DetectionResult,
    num_detections: int,
    min_score_thresh: float,
    category_indices: List[List[Any]],
    category_index_combined: Dict[int, ItemDict],
    area_threshold: float,
    iou_threshold: float = 0.8,
) -> Dict[str, np.ndarray]:
  """Aligns the masks of the detections in `results_1` and `results_2`.

  Args:
    results_1: A dictionary which contains the results from the first model.
    results_2: A dictionary which contains the results from the second model.
    num_detections: The number of detections to consider.
    min_score_thresh: The minimum score threshold for a detection.
    category_indices: A list of two sub-lists containing the labels of the
      first and second ML models.
    category_index_combined: A dictionary with an object ID and nested
      dictionary with name. e.g. {1: {'id': 1, 'name': 'Fiber_Na_Bag',
      'supercategory': 'objects'}}
    area_threshold: Threshold for mask area consideration.
    iou_threshold: IOU threshold to compare masks.

  Returns:
     A dictionary containing the following keys:
       - num_detections: The number of aligned detections.
       - detection_classes: A NumPy array of shape (num_detections,) containing
       the classes for the aligned detections.
       - detection_scores: A NumPy array of shape (num_detections,) containing
       the scores for the aligned detections.
       - detection_boxes: A NumPy array of shape (num_detections, 4) containing
       the bounding boxes for the aligned detections.
       - detection_classes_names: A list of strings containing the names of the
       classes for the aligned detections.
       - detection_masks_reframed: A NumPy array of shape (num_detections,
       height, width) containing the full masks for the aligned detections.
  """
  detection_masks_reframed = []
  detection_scores = []
  detection_boxes = []
  detection_classes = []
  detection_classes_names = []

  aligned_masks = 0
  masks_list1 = results_1['detection_masks_reframed'][:num_detections]
  masks_list2 = results_2['detection_masks_reframed'][:num_detections]
  scores_list1 = results_1['detection_scores'][0]
  scores_list2 = results_2['detection_scores'][0]
  matched_masks_list2 = [False] * len(masks_list2)
  matched_masks_list1 = [False] * len(masks_list1)

  for i, mask1 in enumerate(masks_list1):
    if (scores_list1[i] > min_score_thresh) and (
        np.sum(mask1) < area_threshold
    ):
      is_similar = False

      for j, mask2 in enumerate(masks_list2):
        if scores_list2[j] > min_score_thresh and (
            np.sum(mask2) < area_threshold
        ):
          iou, union = calculate_iou(mask1, mask2)

          # Masks present in the 'detection_masks_reframed' key of both the
          # 'results_1' and 'results_2' dictionaries.
          if iou > iou_threshold:
            aligned_masks += 1
            is_similar = True
            matched_masks_list2[j] = True
            matched_masks_list1[i] = True

            detection_masks_reframed.append(union)

            avg_score, combined_box, combined_label, result_id = (
                calculate_combined_scores_boxes_classes(
                    i,
                    j,
                    results_1,
                    results_2,
                    category_indices,
                    category_index_combined,
                )
            )
            detection_scores.append(avg_score)
            detection_boxes.append(combined_box)
            detection_classes_names.append(combined_label)
            detection_classes.append(result_id)
            break

      # Masks only present in the 'detection_masks_reframed' key of the
      # 'results_1' dictionary.
      if not is_similar:
        aligned_masks += 1
        detection_masks_reframed.append(mask1)
        score, box, combined_label = calculate_single_result(
            i, results_1, category_indices[0], 'after'
        )
        detection_scores.append(score)
        detection_boxes.append(box)
        detection_classes_names.append(combined_label)
        result_id = find_id_by_name(category_index_combined, combined_label)
        detection_classes.append(result_id)

  # Masks only present in the 'detection_masks_reframed' key of the
  # 'results_2' dictionary.
  for k, mask2 in enumerate(masks_list2):
    if (
        (not matched_masks_list2[k])
        and (scores_list2[k] > min_score_thresh)
        and (np.sum(mask2) < area_threshold)
    ):
      aligned_masks += 1
      detection_masks_reframed.append(mask2)
      score, box, combined_label = calculate_single_result(
          k, results_2, category_indices[1], 'before'
      )
      detection_scores.append(score)
      detection_boxes.append(box)
      detection_classes_names.append(combined_label)
      result_id = find_id_by_name(category_index_combined, combined_label)
      detection_classes.append(result_id)

  final_result = {
      'num_detections': np.array([aligned_masks]),
      'detection_classes': np.array(detection_classes),
      'detection_scores': np.array([detection_scores]),
      'detection_boxes': np.array([detection_boxes]),
      'detection_classes_names': np.array(detection_classes_names),
      'detection_masks_reframed': np.array(detection_masks_reframed),
  }

  return final_result
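
# Illustrative usage (a sketch; the variable names are hypothetical, and both
# results are assumed to have been processed by reframing_masks first):
#
#   combined = find_similar_masks(
#       results_material, results_shape,
#       num_detections=int(results_material['num_detections'][0]),
#       min_score_thresh=0.6,
#       category_indices=[material_labels, shape_labels],
#       category_index_combined=combined_index,
#       area_threshold=0.3 * height * width,
#       iou_threshold=0.8)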


def filter_bounding_boxes(
    bounding_boxes: List[Tuple[int, int, int, int]],
    iou_threshold: float = 0.5,
    area_ratio_threshold: float = 0.8,
) -> Tuple[List[Tuple[int, int, int, int]], List[int]]:
  """Filters overlapping bounding boxes based on IoU and area ratio criteria.

  This function filters out overlapping bounding boxes from a given list based
  on Intersection over Union (IoU) and area ratio of the intersection to the
  smaller bounding box's area.

  Args:
      bounding_boxes: A list of bounding boxes, where each bounding box is
        represented as a tuple of (xmin, ymin, xmax, ymax).
      iou_threshold: Threshold for Intersection over Union. Bounding boxes with
        IoU above this threshold will be considered overlapping. Defaults to
        0.5.
      area_ratio_threshold: Threshold for the area ratio of the intersection to
        the smaller bounding box's area. Defaults to 0.8.

  Returns:
      tuple: A tuple containing:
          - filtered_boxes: A list of bounding boxes that passed the filtering
          criteria.
          - eliminated_indices: Indices of the bounding boxes that didn't pass
          the filtering criteria.

  Example:
      >>> bounding_boxes = [(10, 10, 50, 50), (20, 20, 40, 40)]
      >>> filter_bounding_boxes(bounding_boxes)
      ([(20, 20, 40, 40)], [0])
  """
  filtered_boxes = []
  eliminated_indices = []

  # Enumerate and sort the boxes based on their area in descending order
  enumerated_boxes = list(enumerate(bounding_boxes))
  sorted_boxes = sorted(
      enumerated_boxes,
      key=lambda item: (item[1][2] - item[1][0]) * (item[1][3] - item[1][1]),
      reverse=True,
  )

  for idx, bbox in sorted_boxes:
    skip_box = False

    # Calculate areas of individual bounding boxes
    area_bbox = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

    for jdx, other_bbox in sorted_boxes:
      if idx == jdx:
        continue

      # Calculate intersection coordinates
      xmin_inter = max(bbox[0], other_bbox[0])
      ymin_inter = max(bbox[1], other_bbox[1])
      xmax_inter = min(bbox[2], other_bbox[2])
      ymax_inter = min(bbox[3], other_bbox[3])

      # Calculate intersection area
      width_inter = max(0, xmax_inter - xmin_inter)
      height_inter = max(0, ymax_inter - ymin_inter)
      area_inter = width_inter * height_inter

      area_other_bbox = (other_bbox[2] - other_bbox[0]) * (
          other_bbox[3] - other_bbox[1]
      )

      # Calculate area ratio
      area_ratio = area_inter / min(area_bbox, area_other_bbox)

      # Check for overlapping and area ratio thresholds
      if area_ratio > area_ratio_threshold:
        if area_bbox > area_other_bbox:
          skip_box = True
          eliminated_indices.append(idx)
          break
      elif (
          area_inter > 0
          and area_inter / (area_bbox + area_other_bbox - area_inter)
          > iou_threshold
      ):
        if area_bbox > area_other_bbox:
          skip_box = True
          eliminated_indices.append(idx)
          break

    if not skip_box:
      filtered_boxes.append(bbox)

  return filtered_boxes, eliminated_indices
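
# Illustrative example: a parent box nested around a smaller box is dropped,
# and the surviving boxes come back in descending-area order.
#
#   boxes = [(10, 10, 50, 50), (20, 20, 40, 40), (100, 100, 120, 120)]
#   kept, dropped = filter_bounding_boxes(boxes)
#   # kept == [(20, 20, 40, 40), (100, 100, 120, 120)]
#   # dropped == [0]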