tensorflow/models

View on GitHub
research/lstm_object_detection/inputs/tf_sequence_example_decoder.py

Summary

Maintainability
A
35 mins
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensorflow Sequence Example proto decoder.

A decoder to decode string tensors containing serialized
tensorflow.SequenceExample protos.
"""
import tensorflow.compat.v1 as tf
import tf_slim as slim
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields

# Module alias for tf_slim's TF-Example decoder, which supplies the
# ItemHandler, Image and Tensor handler classes used in this file.
tfexample_decoder = slim.tfexample_decoder


class BoundingBoxSequence(tfexample_decoder.ItemHandler):
  """An ItemHandler that concatenates SparseTensors to Bounding Boxes.

  The four box-side features (ymin, xmin, ymax, xmax by default) are each
  given a new trailing axis of size 1 and then concatenated along that axis.
  """

  def __init__(self, keys=None, prefix=None, return_dense=True,
               default_value=-1.0):
    """Initialize the bounding box handler.

    Args:
      keys: A list of four key names representing the ymin, xmin, ymax, xmax
        in the Example or SequenceExample.
      prefix: An optional prefix for each of the bounding box keys in the
        Example or SequenceExample. If provided, `prefix` is prepended to each
        key in `keys`. If `None`, the keys are used unchanged.
      return_dense: if True, returns a dense tensor; if False, returns as
        sparse tensor.
      default_value: The value used to fill empty positions when the sparse
        boxes are converted to a dense tensor.

    Raises:
      ValueError: if keys is not `None` and also not a list of exactly 4 keys
    """
    if keys is None:
      keys = ['ymin', 'xmin', 'ymax', 'xmax']
    elif len(keys) != 4:
      raise ValueError('BoundingBoxSequence expects 4 keys but got {}'.format(
          len(keys)))
    self._prefix = prefix
    self._keys = keys
    # `prefix` defaults to None; concatenating None with a str would raise a
    # TypeError, so an absent prefix is treated as the empty string.
    key_prefix = '' if prefix is None else prefix
    self._full_keys = [key_prefix + k for k in keys]
    self._return_dense = return_dense
    self._default_value = default_value
    super(BoundingBoxSequence, self).__init__(self._full_keys)

  def tensors_to_item(self, keys_to_tensors):
    """Maps the given dictionary of tensors to a concatenated list of bboxes.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      [time, num_boxes, 4] tensor of bounding box coordinates, in order
          [y_min, x_min, y_max, x_max]. Whether the tensor is a SparseTensor
          or a dense Tensor is determined by the return_dense parameter. When
          a dense tensor is returned, empty positions are filled with the
          `default_value` given to the constructor (-1.0 by default).
    """
    sides = []
    for key in self._full_keys:
      value = keys_to_tensors[key]
      # Append a trailing dimension of size 1 so that the four sides can be
      # concatenated along it. tf.cast replaces the deprecated tf.to_int64.
      expanded_dims = tf.concat(
          [tf.cast(tf.shape(value), tf.int64),
           tf.constant([1], dtype=tf.int64)], 0)
      side = tf.sparse_reshape(value, expanded_dims)
      sides.append(side)
    bounding_boxes = tf.sparse_concat(2, sides)
    if self._return_dense:
      bounding_boxes = tf.sparse_tensor_to_dense(
          bounding_boxes, default_value=self._default_value)
    return bounding_boxes


class TFSequenceExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Sequence Example proto decoder."""

  def __init__(self):
    """Constructor sets the feature specs and the item handlers.

    Populates three public attributes:
      keys_to_context_features: parsing specs for the SequenceExample context.
      keys_to_features: parsing specs for the SequenceExample feature lists.
      items_to_handlers: mapping from output field name to the ItemHandler
        that builds it from the parsed features.
    """
    self.keys_to_context_features = {
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, 1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, 1),
    }
    self.keys_to_features = {
        'image/encoded': tf.FixedLenSequenceFeature((), tf.string),
        'bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'bbox/label/index': tf.VarLenFeature(dtype=tf.int64),
        'bbox/label/string': tf.VarLenFeature(tf.string),
        'area': tf.VarLenFeature(tf.float32),
        'is_crowd': tf.VarLenFeature(tf.int64),
        'difficult': tf.VarLenFeature(tf.int64),
        'group_of': tf.VarLenFeature(tf.int64),
    }
    self.items_to_handlers = {
        fields.InputDataFields.image:
            tfexample_decoder.Image(
                image_key='image/encoded',
                format_key='image/format',
                channels=3,
                repeated=True),
        fields.InputDataFields.source_id: (
            tfexample_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            tfexample_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            tfexample_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes:
            BoundingBoxSequence(prefix='bbox/'),
        fields.InputDataFields.groundtruth_classes: (
            tfexample_decoder.Tensor('bbox/label/index')),
        fields.InputDataFields.groundtruth_area:
            tfexample_decoder.Tensor('area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            tfexample_decoder.Tensor('is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            tfexample_decoder.Tensor('difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            tfexample_decoder.Tensor('group_of'))
    }

  def decode(self, tf_seq_example_string_tensor, items=None):
    """Decodes serialized tf.SequenceExample and returns a tensor dictionary.

    Args:
      tf_seq_example_string_tensor: A string tensor holding a serialized
        tensorflow SequenceExample proto.
      items: The items to decode. These must be a subset of the keys of
        self.items_to_handlers. If `items` is None or empty, all of the items
        in self.items_to_handlers are decoded.

    Returns:
      A dictionary mapping each requested item to its decoded tensor. With the
      default handlers the entries are:
      fields.InputDataFields.image - the images decoded from the
        'image/encoded' sequence feature with 3 channels (the handler is
        created with `repeated=True`, i.e. one image per sequence step).
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - float32 tensor of shape
        [time, num_boxes, 4] containing box corners in
        [y_min, x_min, y_max, x_max] order, as produced by
        BoundingBoxSequence.
      fields.InputDataFields.groundtruth_classes - int64 tensor built from
        the 'bbox/label/index' sequence feature.
      fields.InputDataFields.groundtruth_area - float32 tensor built from the
        'area' sequence feature.
      fields.InputDataFields.groundtruth_is_crowd - tensor built from the
        int64 'is_crowd' sequence feature.
      fields.InputDataFields.groundtruth_difficult - tensor built from the
        int64 'difficult' sequence feature.
      fields.InputDataFields.groundtruth_group_of - tensor built from the
        int64 'group_of' sequence feature.
    """
    serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
    decoder = TFSequenceExampleDecoderHelper(self.keys_to_context_features,
                                             self.keys_to_features,
                                             self.items_to_handlers)
    # `items` is consumed twice (by the helper and by zip below), so it is
    # materialized once; otherwise a one-shot iterable would yield an empty
    # result dictionary.
    items = list(items) if items else list(decoder.list_items())
    tensors = decoder.decode(serialized_example, items=items)
    tensor_dict = dict(zip(items, tensors))

    return tensor_dict


class TFSequenceExampleDecoderHelper(data_decoder.DataDecoder):
  """Decodes serialized TF-SequenceExamples through a set of ItemHandlers.

  The helper parses the context features and the sequence features of a
  SequenceExample, then gives every requested ItemHandler the parsed tensors
  for the keys it declares and collects one output per item.
  """

  def __init__(self, keys_to_context_features, keys_to_sequence_features,
               items_to_handlers):
    """Stores the feature specs and handlers after checking key uniqueness.

    Args:
      keys_to_context_features: A dictionary from TF-SequenceExample context
        keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
        See tensorflow's parsing_ops.py.
      keys_to_sequence_features: A dictionary from TF-SequenceExample sequence
        keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
      items_to_handlers: A dictionary from items (strings) to ItemHandler
        instances. Each handler declares the keys it needs in order to build
        its final item Tensor.
    Raises:
      ValueError: If a key appears in both the context features and the
        sequence features.
    """
    expected_count = (
        len(keys_to_context_features) + len(keys_to_sequence_features))
    merged_keys = set(keys_to_context_features) | set(keys_to_sequence_features)
    if len(merged_keys) != expected_count:
      # A shared key could not be resolved unambiguously when the handlers'
      # keys_to_tensors mapping is assembled during decoding.
      raise ValueError(
          'Context and sequence keys are not unique. \n'
          ' Context keys: %s \n Sequence keys: %s' %
          (list(keys_to_context_features.keys()),
           list(keys_to_sequence_features.keys())))
    self._keys_to_context_features = keys_to_context_features
    self._keys_to_sequence_features = keys_to_sequence_features
    self._items_to_handlers = items_to_handlers

  def list_items(self):
    """Returns the item names for which a handler is registered."""
    return self._items_to_handlers.keys()

  def decode(self, serialized_example, items=None):
    """Decodes the given serialized TF-SequenceExample.

    Args:
      serialized_example: A serialized TF-SequenceExample tensor.
      items: The items to decode. These must be a subset of the item keys in
        self._items_to_handlers. If `items` is None or empty, every item in
        self._items_to_handlers is decoded.
    Returns:
      A list with one decoded item per entry of `items`, in the same order.
    """
    context, feature_list = tf.parse_single_sequence_example(
        serialized_example, self._keys_to_context_features,
        self._keys_to_sequence_features)
    # Fixed-length context features are reshaped to their declared shape once,
    # up front, rather than inside each handler.
    for name, spec in self._keys_to_context_features.items():
      if isinstance(spec, tf.FixedLenFeature):
        context[name] = tf.reshape(context[name], spec.shape)
    if not items:
      items = self._items_to_handlers.keys()

    def _parsed_tensor(key):
      """Looks a key up in the context first, then in the feature lists."""
      if key in context:
        return context[key]
      return feature_list[key]

    decoded = []
    for item in items:
      handler = self._items_to_handlers[item]
      handler_inputs = {key: _parsed_tensor(key) for key in handler.keys}
      decoded.append(handler.tensors_to_item(handler_inputs))
    return decoded