tensorflow/models

View on GitHub
research/object_detection/dataset_tools/oid_hierarchical_labels_expansion.py

Summary

Maintainability
B
5 hrs
Test Coverage
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""An executable to expand image-level labels, boxes and segments.

The expansion is performed using class hierarchy, provided in JSON file.

The expected file formats are the following:
- for box and segment files: CSV file is expected to have LabelName field
- for image-level labels: CSV file is expected to have LabelName and Confidence
fields

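Note that LabelName is the only field rewritten by the expansion. For
image-level labels the Confidence field is also read: rows with Confidence 1
are expanded to the ancestors of the label, all other rows to its descendants.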

Example usage:
python models/research/object_detection/dataset_tools/\
oid_hierarchical_labels_expansion.py \
--json_hierarchy_file=<path to JSON hierarchy> \
--input_annotations=<input csv file> \
--output_annotations=<output csv file> \
--annotation_type=<1 (for boxes and segments) or 2 (for image-level labels)>
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
from absl import app
from absl import flags
import six

# Command-line flags. Every flag defaults to None; note the trailing spaces
# inside the help strings: adjacent literals are concatenated verbatim.
flags.DEFINE_string(
    'json_hierarchy_file', None,
    'Path to the file containing label hierarchy in JSON format.')
flags.DEFINE_string(
    'input_annotations', None, 'Path to Open Images annotations file '
    '(either bounding boxes, segments or image-level labels).')
flags.DEFINE_string('output_annotations', None, 'Path to the output file.')
flags.DEFINE_integer(
    'annotation_type', None,
    'Type of the input annotations: 1 - boxes or segments, '
    '2 - image-level labels.'
)

FLAGS = flags.FLAGS


def _update_dict(initial_dict, update):
  """Updates dictionary with update content.

  Args:
   initial_dict: initial dictionary.
   update: updated dictionary.
  """

  for key, value_list in update.items():
    if key in initial_dict:
      initial_dict[key].update(value_list)
    else:
      initial_dict[key] = set(value_list)


def _build_plain_hierarchy(hierarchy, skip_root=False):
  """Expands tree hierarchy representation to parent-child dictionary.

  Args:
   hierarchy: labels hierarchy as JSON file.
   skip_root: if true skips root from the processing (done for the case when all
     classes under hierarchy are collected under virtual node).

  Returns:
    keyed_parent - dictionary of parent - all its children nodes.
    keyed_child  - dictionary of children - all its parent nodes
    children - all children of the current node.
  """
  all_children = set([])
  all_keyed_parent = {}
  all_keyed_child = {}
  if 'Subcategory' in hierarchy:
    for node in hierarchy['Subcategory']:
      keyed_parent, keyed_child, children = _build_plain_hierarchy(node)
      # Update is not done through dict.update() since some children have multi-
      # ple parents in the hiearchy.
      _update_dict(all_keyed_parent, keyed_parent)
      _update_dict(all_keyed_child, keyed_child)
      all_children.update(children)

  if not skip_root:
    all_keyed_parent[hierarchy['LabelName']] = copy.deepcopy(all_children)
    all_children.add(hierarchy['LabelName'])
    for child, _ in all_keyed_child.items():
      all_keyed_child[child].add(hierarchy['LabelName'])
    all_keyed_child[hierarchy['LabelName']] = set([])

  return all_keyed_parent, all_keyed_child, all_children


class OIDHierarchicalLabelsExpansion(object):
  """Main class to perform hierarchical expansion of labels.

  The lookup tables are built once from the hierarchy in the constructor; CSV
  rows are then expanded one at a time.
  """

  def __init__(self, hierarchy):
    """Constructor.

    Args:
      hierarchy: labels hierarchy as JSON object.
    """

    # The root node is skipped, so it is neither a known label nor an
    # expansion target.
    # _hierarchy_keyed_parent: label -> set of all labels below it.
    # _hierarchy_keyed_child: label -> set of all labels above it.
    self._hierarchy_keyed_parent, self._hierarchy_keyed_child, _ = (
        _build_plain_hierarchy(hierarchy, skip_root=True))

  @staticmethod
  def _expand_row(csv_row, split_csv_row, labelname_column_index,
                  related_labels):
    """Returns csv_row followed by one copy per label related to its own.

    Args:
      csv_row: the original row; it is returned unchanged as the first item.
      split_csv_row: list of the comma-separated fields of csv_row.
      labelname_column_index: 0-based index of LabelName in split_csv_row.
      related_labels: dictionary mapping a label to the set of labels that
        should be substituted for it.

    Returns:
      a list starting with csv_row, followed by one comma-joined row per
      related label with only the LabelName field replaced.
    """
    label = split_csv_row[labelname_column_index]
    assert label in related_labels, (
        'LabelName %r is not present in the hierarchy.' % (label,))
    # Work on a copy so that the list passed in by the caller stays intact.
    expanded_fields = list(split_csv_row)
    result = [csv_row]
    for related_label in related_labels[label]:
      expanded_fields[labelname_column_index] = related_label
      result.append(','.join(expanded_fields))
    return result

  def expand_boxes_or_segments_from_csv(self, csv_row,
                                        labelname_column_index=1):
    """Expands a row containing bounding boxes/segments from CSV file.

    The row is repeated once for every ancestor of its label in the hierarchy.
    Fields are obtained by splitting on ',' only, so a line terminator at the
    end of csv_row stays attached to the last field and is carried over to the
    expanded rows.

    Args:
      csv_row: a single row of Open Images released groundtruth file.
      labelname_column_index: 0-based index of LabelName column in CSV file.

    Returns:
      a list of strings (including the initial row) corresponding to the ground
      truth expanded to multiple annotation for evaluation with Open Images
      Challenge 2018/2019 metrics.

    Raises:
      AssertionError: if the label of the row is not present in the hierarchy
        (with assertions disabled the lookup raises KeyError instead).
    """
    # Row header is expected to be the following for boxes:
    # ImageID,LabelName,Confidence,XMin,XMax,YMin,YMax,IsGroupOf
    # Row header is expected to be the following for segments:
    # ImageID,LabelName,ImageWidth,ImageHeight,XMin,XMax,YMin,YMax,
    # IsGroupOf,Mask
    split_csv_row = six.ensure_str(csv_row).split(',')
    return self._expand_row(csv_row, split_csv_row, labelname_column_index,
                            self._hierarchy_keyed_child)

  def expand_labels_from_csv(self,
                             csv_row,
                             labelname_column_index=1,
                             confidence_column_index=2):
    """Expands a row containing labels from CSV file.

    A row whose Confidence equals 1 is repeated for every ancestor of its
    label; any other row is repeated for every descendant of its label.

    Args:
      csv_row: a single row of Open Images released groundtruth file.
      labelname_column_index: 0-based index of LabelName column in CSV file.
      confidence_column_index: 0-based index of Confidence column in CSV file.

    Returns:
      a list of strings (including the initial row) corresponding to the ground
      truth expanded to multiple annotation for evaluation with Open Images
      Challenge 2018/2019 metrics.

    Raises:
      ValueError: if the Confidence field cannot be parsed by int().
      AssertionError: if the label of the row is not present in the hierarchy
        (with assertions disabled the lookup raises KeyError instead).
    """
    # NOTE: for a file with the header ImageID,Source,LabelName,Confidence the
    # LabelName and Confidence columns are at indices 2 and 3, which differs
    # from the default arguments; callers have to pass the indices that match
    # their file.
    split_csv_row = six.ensure_str(csv_row).split(',')
    if int(split_csv_row[confidence_column_index]) == 1:
      # Positive label: it also holds for every ancestor class.
      related_labels = self._hierarchy_keyed_child
    else:
      # Non-positive label: it is propagated to every descendant class.
      related_labels = self._hierarchy_keyed_parent
    return self._expand_row(csv_row, split_csv_row, labelname_column_index,
                            related_labels)


def main(unused_args):
  """Expands the annotations of the input CSV file into the output CSV file.

  The header line is copied as is and used to locate the LabelName column
  (and the Confidence column for image-level labels); every following row is
  written together with its hierarchical expansions.

  Args:
    unused_args: positional command-line arguments; ignored.

  Returns:
    -1 if --annotation_type is neither 1 nor 2, None otherwise.
  """

  del unused_args

  with open(FLAGS.json_hierarchy_file) as f:
    hierarchy = json.load(f)
  expansion_generator = OIDHierarchicalLabelsExpansion(hierarchy)
  if FLAGS.annotation_type not in (1, 2):
    print('--annotation_type expected value is 1 or 2.')
    return -1
  labels_file = FLAGS.annotation_type == 2

  with open(FLAGS.input_annotations, 'r') as source:
    with open(FLAGS.output_annotations, 'w') as target:
      header = source.readline()
      target.writelines([header])
      column_names = header.strip().split(',')
      labelname_column_index = column_names.index('LabelName')
      confidence_column_index = -1
      if labels_file:
        confidence_column_index = column_names.index('Confidence')
      for line in source:
        if not line.strip():
          # Blank lines (e.g. at the end of the file) carry no annotation.
          continue
        if not line.endswith('\n'):
          # writelines() adds no separators: without a terminator on the last
          # line of the file, the row and all of its expansions would be
          # glued together into a single corrupt record.
          line += '\n'
        if labels_file:
          expanded_lines = expansion_generator.expand_labels_from_csv(
              line, labelname_column_index, confidence_column_index)
        else:
          expanded_lines = (
              expansion_generator.expand_boxes_or_segments_from_csv(
                  line, labelname_column_index))
        target.writelines(expanded_lines)


if __name__ == '__main__':
  # None of these flags has a usable default, so each has to be passed on the
  # command line.
  for required_flag in ('json_hierarchy_file', 'input_annotations',
                        'output_annotations', 'annotation_type'):
    flags.mark_flag_as_required(required_flag)

  app.run(main)