# Source: tensorflow/models (View on GitHub),
# research/deeplab/evaluation/panoptic_quality.py
#
# Code-quality summary: Maintainability C (1 day); Test Coverage.
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Panoptic Quality metric.

Panoptic Quality is an instance-based metric for evaluating the task of
image parsing, aka panoptic segmentation.

Please see the paper for details:
"Panoptic Segmentation", Alexander Kirillov, Kaiming He, Ross Girshick,
Carsten Rother and Piotr Dollar. arXiv:1801.00868, 2018.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
import prettytable
import six

from deeplab.evaluation import base_metric


def _ids_to_counts(id_array):
  """Counts how many times each distinct value occurs in `id_array`.

  Args:
    id_array: A numpy array. `np.unique` flattens it, so any shape is
      accepted.

  Returns:
    A dict whose keys are the unique values of `id_array` (as produced by
    `np.unique`) and whose values are their occurrence counts.
  """
  unique_values, occurrence_counts = np.unique(id_array, return_counts=True)
  return {
      value: count
      for value, count in zip(unique_values, occurrence_counts)
  }


class PanopticQuality(base_metric.SegmentationMetric):
  """Metric class for Panoptic Quality.

  "Panoptic Segmentation" by Alexander Kirillov, Kaiming He, Ross Girshick,
  Carsten Rother, Piotr Dollar.
  https://arxiv.org/abs/1801.00868

  Per-category statistics are accumulated across calls to
  `compare_and_accumulate` in four float64 arrays of length `num_categories`
  (allocated in `reset`):
    iou_per_class: sum of the IoUs of all matched (true positive) pairs.
    tp_per_class: number of matched groundtruth/predicted segment pairs.
    fn_per_class: number of unmatched groundtruth segments.
    fp_per_class: number of unmatched predicted segments.

  From these, per-category segmentation quality SQ = iou / tp, recognition
  quality RQ = tp / (tp + fn / 2 + fp / 2) and PQ = SQ * RQ are derived.

  The attributes `num_categories`, `ignored_label`,
  `max_instances_per_category` and `offset` are read but never set in this
  class; presumably the base class constructor provides them.
  """

  def compare_and_accumulate(
      self, groundtruth_category_array, groundtruth_instance_array,
      predicted_category_array, predicted_instance_array):
    """See base class.

    Matches groundtruth and predicted segments of the same category whose IoU
    is greater than 0.5. It then adds the resulting true positive, false
    negative and false positive counts (and the matched IoUs) to the
    per-category accumulators.

    Args:
      groundtruth_category_array: Numpy array of groundtruth category labels.
      groundtruth_instance_array: Numpy array of groundtruth instance labels.
      predicted_category_array: Numpy array of predicted category labels.
      predicted_instance_array: Numpy array of predicted instance labels.
      All four arrays are combined element-wise, so they are expected to have
      matching shapes.

    Returns:
      `self.result()`: the PQ averaged over valid categories, computed from
      everything accumulated so far (not just this call).
    """
    # First, combine the category and instance labels so that every unique
    # value for (category, instance) is assigned a unique integer label.
    # The decoding below (`// self.max_instances_per_category`) implies the
    # combined id is category * max_instances_per_category + instance.
    pred_segment_id = self._naively_combine_labels(predicted_category_array,
                                                   predicted_instance_array)
    gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
                                                 groundtruth_instance_array)

    # Pre-calculate areas (pixel counts) for all groundtruth and predicted
    # segments.
    gt_segment_areas = _ids_to_counts(gt_segment_id)
    pred_segment_areas = _ids_to_counts(pred_segment_id)

    # We assume there is only one void segment and it has instance id = 0.
    void_segment_id = self.ignored_label * self.max_instances_per_category

    # There may be other ignored groundtruth segments with instance id > 0, find
    # those ids using the unique segment ids extracted with the area computation
    # above.
    ignored_segment_ids = {
        gt_segment_id for gt_segment_id in six.iterkeys(gt_segment_areas)
        if (gt_segment_id //
            self.max_instances_per_category) == self.ignored_label
    }

    # Next, combine the groundtruth and predicted labels. Dividing up the pixels
    # based on which groundtruth segment and which predicted segment they belong
    # to, this will assign a different 32-bit integer label to each choice
    # of (groundtruth segment, predicted segment), encoded as
    #   gt_segment_id * offset + pred_segment_id.
    # NOTE(review): the arithmetic is done in uint32, so it can wrap if
    # gt_segment_id * offset + pred_segment_id reaches 2**32. The decoding
    # below also requires every predicted segment id to be < offset.
    intersection_id_array = (
        gt_segment_id.astype(np.uint32) * self.offset +
        pred_segment_id.astype(np.uint32))

    # For every combination of (groundtruth segment, predicted segment) with a
    # non-empty intersection, this counts the number of pixels in that
    # intersection.
    intersection_areas = _ids_to_counts(intersection_id_array)

    # Helper function that computes the area of the overlap between a predicted
    # segment and the ground-truth void segment (instance id 0 only).
    def prediction_void_overlap(pred_segment_id):
      void_intersection_id = void_segment_id * self.offset + pred_segment_id
      return intersection_areas.get(void_intersection_id, 0)

    # Total overlap between a predicted segment and *all* ignored groundtruth
    # segments (any instance id of the ignored category).
    def prediction_ignored_overlap(pred_segment_id):
      total_ignored_overlap = 0
      for ignored_segment_id in ignored_segment_ids:
        intersection_id = ignored_segment_id * self.offset + pred_segment_id
        total_ignored_overlap += intersection_areas.get(intersection_id, 0)
      return total_ignored_overlap

    # Sets that are populated with which segments groundtruth/predicted segments
    # have been matched with overlapping predicted/groundtruth segments
    # respectively.
    gt_matched = set()
    pred_matched = set()

    # Calculate IoU per pair of intersecting segments of the same category.
    for intersection_id, intersection_area in six.iteritems(intersection_areas):
      gt_segment_id = intersection_id // self.offset
      pred_segment_id = intersection_id % self.offset

      gt_category = gt_segment_id // self.max_instances_per_category
      pred_category = pred_segment_id // self.max_instances_per_category
      if gt_category != pred_category:
        continue

      # Union between the groundtruth and predicted segments being compared does
      # not include the portion of the predicted segment that consists of
      # groundtruth "void" pixels.
      # NOTE(review): when gt_segment_id is the void segment itself, the void
      # overlap equals intersection_area, so `union` can be 0 and the IoU
      # meaningless. This only affects the ignored category's accumulators,
      # and indexing them raises IndexError if ignored_label >= num_categories.
      # The behavior is left unchanged because existing results for that
      # index may be relied upon.
      union = (
          gt_segment_areas[gt_segment_id] +
          pred_segment_areas[pred_segment_id] - intersection_area -
          prediction_void_overlap(pred_segment_id))
      iou = intersection_area / union
      # Per the cited paper, IoU > 0.5 means each segment can be matched at
      # most once.
      if iou > 0.5:
        self.tp_per_class[gt_category] += 1
        self.iou_per_class[gt_category] += iou
        gt_matched.add(gt_segment_id)
        pred_matched.add(pred_segment_id)

    # Count false negatives for each category.
    for gt_segment_id in six.iterkeys(gt_segment_areas):
      if gt_segment_id in gt_matched:
        continue
      category = gt_segment_id // self.max_instances_per_category
      # Failing to detect a void segment is not a false negative.
      if category == self.ignored_label:
        continue
      self.fn_per_class[category] += 1

    # Count false positives for each category.
    for pred_segment_id in six.iterkeys(pred_segment_areas):
      if pred_segment_id in pred_matched:
        continue
      # A false positive is not penalized if it is mostly (> 50% of its area)
      # ignored in the groundtruth.
      if (prediction_ignored_overlap(pred_segment_id) /
          pred_segment_areas[pred_segment_id]) > 0.5:
        continue
      category = pred_segment_id // self.max_instances_per_category
      self.fp_per_class[category] += 1

    return self.result()

  def _valid_categories(self):
    """Returns a mask of categories that have a well-defined metric value.

    A category is valid if `tp + fn + fp > 0`. The `ignored_label` category is
    always excluded (when it is a valid index into the per-class arrays).

    Returns:
      Boolean array of shape `[num_categories]`.
    """
    valid_categories = np.not_equal(
        self.tp_per_class + self.fn_per_class + self.fp_per_class, 0)
    if self.ignored_label >= 0 and self.ignored_label < self.num_categories:
      valid_categories[self.ignored_label] = False
    return valid_categories

  def _sq_rq_per_category(self):
    """Returns per-category (SQ, RQ) arrays derived from the accumulators.

    SQ = iou / tp and RQ = tp / (tp + fn / 2 + fp / 2); zero denominators are
    handled by `base_metric.realdiv_maybe_zero`.
    """
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    return sq, rq

  def detailed_results(self, is_thing=None):
    """See base class.

    Args:
      is_thing: Optional boolean array of shape `[num_categories]`. When given,
        results are additionally broken down into 'Things' and 'Stuff'.

    Returns:
      An ordered mapping from subset name ('All', and optionally 'Things' and
      'Stuff') to a dict with the mean 'pq', 'sq' and 'rq' over the valid
      categories in that subset, plus 'n', the number of such categories.
      Empty subsets report zeros.
    """
    valid_categories = self._valid_categories()

    # If known, break down which categories are valid _and_ things/stuff.
    category_sets = collections.OrderedDict()
    category_sets['All'] = valid_categories
    if is_thing is not None:
      category_sets['Things'] = np.logical_and(valid_categories, is_thing)
      category_sets['Stuff'] = np.logical_and(valid_categories,
                                              np.logical_not(is_thing))

    # Compute individual per-class metrics that constitute factors of PQ.
    sq, rq = self._sq_rq_per_category()
    pq = np.multiply(sq, rq)

    # Assemble detailed results dictionary.
    results = {}
    for category_set_name, in_category_set in six.iteritems(category_sets):
      if np.any(in_category_set):
        results[category_set_name] = {
            'pq': np.mean(pq[in_category_set]),
            'sq': np.mean(sq[in_category_set]),
            'rq': np.mean(rq[in_category_set]),
            # The number of categories in this subset.
            'n': np.sum(in_category_set.astype(np.int32)),
        }
      else:
        results[category_set_name] = {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0}

    return results

  def result_per_category(self):
    """See base class.

    Returns:
      Array of shape `[num_categories]` with PQ = SQ * RQ for every category,
      including invalid and ignored ones.
    """
    sq, rq = self._sq_rq_per_category()
    return np.multiply(sq, rq)

  def print_detailed_results(self, is_thing=None, print_digits=3):
    """See base class.

    Prints a table of PQ/SQ/RQ (as percentages) and category counts for each
    subset returned by `detailed_results`.

    Args:
      is_thing: Optional boolean array forwarded to `detailed_results`.
      print_digits: Number of decimal digits kept for the fractional metric
        values before conversion to a percentage (3 -> one decimal place in
        percent).
    """
    results = self.detailed_results(is_thing=is_thing)

    tab = prettytable.PrettyTable()

    tab.add_column('', [], align='l')
    for fieldname in ['PQ', 'SQ', 'RQ', 'N']:
      tab.add_column(fieldname, [], align='r')

    for category_set, subset_results in six.iteritems(results):
      # Scale to percent *before* rounding. Rounding the fraction first and
      # then multiplying by 100 reintroduces binary floating-point noise
      # (e.g. round(0.567, 3) * 100 == 56.699999999999996). Rounding the
      # percentage to `print_digits - 2` decimals keeps the same precision.
      data_cols = [
          round(subset_results[col_key] * 100, print_digits - 2)
          for col_key in ['pq', 'sq', 'rq']
      ]
      data_cols += [subset_results['n']]
      tab.add_row([category_set] + data_cols)

    print(tab)

  def result(self):
    """See base class.

    Returns:
      The mean per-category PQ over valid categories, or 0. if there are none.
    """
    pq_per_class = self.result_per_category()
    valid_categories = self._valid_categories()
    if not np.any(valid_categories):
      return 0.
    return np.mean(pq_per_class[valid_categories])

  def merge(self, other_instance):
    """See base class.

    Adds the accumulators of `other_instance` (expected to share this
    instance's `num_categories`) into this instance in place.
    """
    self.iou_per_class += other_instance.iou_per_class
    self.tp_per_class += other_instance.tp_per_class
    self.fn_per_class += other_instance.fn_per_class
    self.fp_per_class += other_instance.fp_per_class

  def reset(self):
    """See base class.

    Allocates zeroed float64 accumulators of shape `[num_categories]`.
    """
    self.iou_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.tp_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fn_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fp_per_class = np.zeros(self.num_categories, dtype=np.float64)