research/object_detection/data_decoders/tf_sequence_example_decoder.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence example decoder for object detection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import zip
import tensorflow.compat.v1 as tf
from tf_slim import tfexample_decoder as slim_example_decoder

from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.utils import label_map_util

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import lookup as contrib_lookup
except ImportError:
  # TF 2.0 doesn't ship with contrib.
  pass
# pylint: enable=g-import-not-at-top


class _ClassTensorHandler(slim_example_decoder.Tensor):
  """An ItemHandler to fetch class ids from class text."""

  def __init__(self,
tensor_key,
label_map_proto_file,
shape_keys=None,
shape=None,
default_value=''):
"""Initializes the LookupTensor handler.
Simply calls a vocabulary (most often, a label mapping) lookup.
Args:
tensor_key: the name of the `TFExample` feature to read the tensor from.
label_map_proto_file: File path to a text format LabelMapProto message
mapping class text to id.
shape_keys: Optional name or list of names of the TF-Example feature in
which the tensor shape is stored. If a list, then each corresponds to
one dimension of the shape.
shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
reshaped accordingly.
default_value: The value used when the `tensor_key` is not found in a
        particular `TFExample`.

    Raises:
ValueError: if both `shape_keys` and `shape` are specified.
"""
name_to_id = label_map_util.get_label_map_dict(
label_map_proto_file, use_display_name=False)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
try:
# Dynamically try to load the tf v2 lookup, falling back to contrib
lookup = tf.compat.v2.lookup
hash_table_class = tf.compat.v2.lookup.StaticHashTable
except AttributeError:
lookup = contrib_lookup
hash_table_class = contrib_lookup.HashTable
name_to_id_table = hash_table_class(
initializer=lookup.KeyValueTensorInitializer(
keys=tf.constant(list(name_to_id.keys())),
values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
default_value=-1)
self._name_to_id_table = name_to_id_table
super(_ClassTensorHandler, self).__init__(tensor_key, shape_keys, shape,
default_value)

  def tensors_to_item(self, keys_to_tensors):
    """Maps class label strings to integer ids via the label map table."""
unmapped_tensor = super(_ClassTensorHandler,
self).tensors_to_item(keys_to_tensors)
return self._name_to_id_table.lookup(unmapped_tensor)
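
# A minimal sketch (standalone, values hypothetical) of the lookup that
# _ClassTensorHandler performs: class text is mapped to integer ids with a
# static hash table built from the label map; unknown labels map to -1.
#
#   name_to_id = {'cat': 1, 'dog': 2}
#   table = tf.compat.v2.lookup.StaticHashTable(
#       tf.compat.v2.lookup.KeyValueTensorInitializer(
#           keys=tf.constant(list(name_to_id.keys())),
#           values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
#       default_value=-1)
#   ids = table.lookup(tf.constant(['dog', 'bird']))  # -> [2, -1]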


class TfSequenceExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Sequence Example proto decoder for Object Detection.

  Sequence examples contain sequences of images which share common
  features. The structure of TfSequenceExamples can be seen in
  dataset_tools/seq_example_util.py.

  For the TFODAPI, the following fields are required:
    Shared features:
      'image/format'
      'image/height'
      'image/width'
    Features with an entry for each image, where bounding box features can
    be empty lists if the image does not contain any objects:
      'image/encoded'
      'image/source_id'
      'region/bbox/xmin'
      'region/bbox/xmax'
      'region/bbox/ymin'
      'region/bbox/ymax'
      'region/label/string'

  Optionally, the sequence example can include context_features for use in
  Context R-CNN (see https://arxiv.org/abs/1912.03538):
    'image/context_features'
    'image/context_feature_length'
    'image/context_features_image_id_list'
  """

  def __init__(self,
               label_map_proto_file,
               load_context_features=False,
               load_context_image_ids=False,
               use_display_name=False,
               fully_annotated=False):
    """Constructs `TfSequenceExampleDecoder` object.

    Args:
      label_map_proto_file: a file path to an
        object_detection.protos.StringIntLabelMap proto. The label map will be
        used to map the class texts in 'region/label/string' to integer ids.
        It is assumed that 'region/label/string' will be in the data.
      load_context_features: Whether to load information from context_features,
        to provide additional context to a detection model for training and/or
        inference.
      load_context_image_ids: Whether to load the corresponding image ids for
        the context_features in order to visualize attention.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`). Only used if label_map_proto_file is
        provided.
      fully_annotated: If True, will assume that every frame (whether it has
        boxes or not) has been fully annotated. If False, a
        'region/is_annotated' field must be provided in the dataset which
        indicates which frames have annotations. Default False.
    """
# Specifies how the tf.SequenceExamples are decoded.
self._context_keys_to_features = {
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature((), tf.int64),
'image/width': tf.FixedLenFeature((), tf.int64),
}
self._sequence_keys_to_feature_lists = {
'image/encoded': tf.FixedLenSequenceFeature([], dtype=tf.string),
'image/source_id': tf.FixedLenSequenceFeature([], dtype=tf.string),
'region/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
'region/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
'region/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
'region/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
'region/label/string': tf.VarLenFeature(dtype=tf.string),
'region/label/confidence': tf.VarLenFeature(dtype=tf.float32),
}
self._items_to_handlers = {
# Context.
fields.InputDataFields.image_height:
slim_example_decoder.Tensor('image/height'),
fields.InputDataFields.image_width:
slim_example_decoder.Tensor('image/width'),
# Sequence.
fields.InputDataFields.num_groundtruth_boxes:
slim_example_decoder.NumBoxesSequence('region/bbox/xmin'),
fields.InputDataFields.groundtruth_boxes:
slim_example_decoder.BoundingBoxSequence(
prefix='region/bbox/', default_value=0.0),
fields.InputDataFields.groundtruth_weights:
slim_example_decoder.Tensor('region/label/confidence'),
}
# If the dataset is sparsely annotated, parse sequence features which
# indicate which frames have been labeled.
if not fully_annotated:
self._sequence_keys_to_feature_lists['region/is_annotated'] = (
tf.FixedLenSequenceFeature([], dtype=tf.int64))
self._items_to_handlers[fields.InputDataFields.is_annotated] = (
slim_example_decoder.Tensor('region/is_annotated'))
self._items_to_handlers[fields.InputDataFields.image] = (
slim_example_decoder.Tensor('image/encoded'))
self._items_to_handlers[fields.InputDataFields.source_id] = (
slim_example_decoder.Tensor('image/source_id'))
label_handler = _ClassTensorHandler(
'region/label/string', label_map_proto_file, default_value='')
self._items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler
if load_context_features:
self._context_keys_to_features['image/context_features'] = (
tf.VarLenFeature(dtype=tf.float32))
self._items_to_handlers[fields.InputDataFields.context_features] = (
slim_example_decoder.ItemHandlerCallback(
['image/context_features', 'image/context_feature_length'],
self._reshape_context_features))
self._context_keys_to_features['image/context_feature_length'] = (
tf.FixedLenFeature((), tf.int64))
self._items_to_handlers[fields.InputDataFields.context_feature_length] = (
slim_example_decoder.Tensor('image/context_feature_length'))
if load_context_image_ids:
self._context_keys_to_features['image/context_features_image_id_list'] = (
tf.VarLenFeature(dtype=tf.string))
self._items_to_handlers[
fields.InputDataFields.context_features_image_id_list] = (
slim_example_decoder.Tensor(
'image/context_features_image_id_list',
default_value=''))
self._fully_annotated = fully_annotated

  def decode(self, tf_seq_example_string_tensor):
    """Decodes a serialized `tf.SequenceExample` and returns a tensor dict.

    Args:
      tf_seq_example_string_tensor: a string tensor holding a serialized
        `tf.SequenceExample`.

    Returns:
      A dictionary with (at least) the following tensors:
      fields.InputDataFields.source_id: a [num_frames] string tensor with a
        unique ID for each frame.
      fields.InputDataFields.num_groundtruth_boxes: a [num_frames] int32 tensor
        specifying the number of boxes in each frame.
      fields.InputDataFields.groundtruth_boxes: a [num_frames, num_boxes, 4]
        float32 tensor with bounding boxes for each frame. Note that num_boxes
        is the maximum number of boxes seen in any individual frame. Any frames
        with fewer boxes are padded with 0.0.
      fields.InputDataFields.groundtruth_classes: a [num_frames, num_boxes]
        int32 tensor with class indices for each box in each frame.
      fields.InputDataFields.groundtruth_weights: a [num_frames, num_boxes]
        float32 tensor with weights of the groundtruth boxes.
      fields.InputDataFields.is_annotated: a [num_frames] bool tensor
        specifying whether each frame was annotated or not. If False, the
        corresponding entries in the groundtruth tensors will be ignored.
      fields.InputDataFields.context_features: a 1D float32 tensor of shape
        [context_feature_length * num_context_features].
      fields.InputDataFields.context_feature_length: an int32 tensor specifying
        the length of each feature in context_features.
      fields.InputDataFields.image: a [num_frames] string tensor with
        the encoded images.
      fields.InputDataFields.context_features_image_id_list: a 1D string tensor
        of shape [num_context_features] with the ids of the context images.
    """
serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
decoder = slim_example_decoder.TFSequenceExampleDecoder(
self._context_keys_to_features, self._sequence_keys_to_feature_lists,
self._items_to_handlers)
keys = decoder.list_items()
tensors = decoder.decode(serialized_example, items=keys)
tensor_dict = dict(list(zip(keys, tensors)))
tensor_dict[fields.InputDataFields.groundtruth_boxes].set_shape(
[None, None, 4])
tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.cast(
tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
dtype=tf.int32)
tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.cast(
tensor_dict[fields.InputDataFields.groundtruth_classes], dtype=tf.int32)
tensor_dict[fields.InputDataFields.original_image_spatial_shape] = tf.cast(
tf.stack([
tensor_dict[fields.InputDataFields.image_height],
tensor_dict[fields.InputDataFields.image_width]
]),
dtype=tf.int32)
tensor_dict.pop(fields.InputDataFields.image_height)
tensor_dict.pop(fields.InputDataFields.image_width)

    def default_groundtruth_weights():
"""Produces weights of 1.0 for each valid box, and 0.0 otherwise."""
num_boxes_per_frame = tensor_dict[
fields.InputDataFields.num_groundtruth_boxes]
max_num_boxes = tf.reduce_max(num_boxes_per_frame)
num_boxes_per_frame_tiled = tf.tile(
tf.expand_dims(num_boxes_per_frame, axis=-1),
multiples=tf.stack([1, max_num_boxes]))
range_tiled = tf.tile(
tf.expand_dims(tf.range(max_num_boxes), axis=0),
multiples=tf.stack([tf.shape(num_boxes_per_frame)[0], 1]))
return tf.cast(
tf.greater(num_boxes_per_frame_tiled, range_tiled), tf.float32)
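
    # Worked example (values illustrative): with num_groundtruth_boxes =
    # [2, 0, 3], max_num_boxes is 3, so default_groundtruth_weights() returns
    #   [[1., 1., 0.],
    #    [0., 0., 0.],
    #    [1., 1., 1.]]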
tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
tf.greater(
tf.size(tensor_dict[fields.InputDataFields.groundtruth_weights]),
0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
default_groundtruth_weights)
if self._fully_annotated:
tensor_dict[fields.InputDataFields.is_annotated] = tf.ones_like(
tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
dtype=tf.bool)
else:
tensor_dict[fields.InputDataFields.is_annotated] = tf.cast(
tensor_dict[fields.InputDataFields.is_annotated], dtype=tf.bool)
return tensor_dict
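
  # Shape illustration (values hypothetical): for a two-frame clip with
  # [1, 2] boxes per frame, num_groundtruth_boxes is [1, 2],
  # groundtruth_boxes has shape [2, 2, 4] (frame 0 is padded with 0.0) and
  # groundtruth_classes has shape [2, 2].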

  def _reshape_context_features(self, keys_to_tensors):
    """Reshapes context features.

    The instance context_features are reshaped to
    [num_context_features, context_feature_length].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 2-D float tensor of shape
      [num_context_features, context_feature_length].
    """
context_feature_length = keys_to_tensors['image/context_feature_length']
to_shape = tf.cast(tf.stack([-1, context_feature_length]), tf.int32)
context_features = keys_to_tensors['image/context_features']
if isinstance(context_features, tf.SparseTensor):
context_features = tf.sparse_tensor_to_dense(context_features)
context_features = tf.reshape(context_features, to_shape)
return context_features
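

if __name__ == '__main__':
  # A minimal usage sketch, not part of the library; the TFRecord and label
  # map paths below are hypothetical. Each record is assumed to hold a
  # serialized tf.SequenceExample with the fields listed in the class
  # docstring of TfSequenceExampleDecoder.
  decoder = TfSequenceExampleDecoder(
      label_map_proto_file='/path/to/label_map.pbtxt')
  dataset = tf.data.TFRecordDataset('/path/to/sequence_examples.tfrecord')
  dataset = dataset.map(decoder.decode)
  print(dataset)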