research/object_detection/meta_architectures/center_net_meta_arch.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The CenterNet meta architecture as described in the "Objects as Points" paper [1].
[1]: https://arxiv.org/abs/1904.07850
"""
import abc
import collections
import functools
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import keypoint_ops
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner as cn_assigner
from object_detection.utils import shape_utils
from object_detection.utils import target_assigner_utils as ta_utils
from object_detection.utils import tf_version
# Number of output channels used to predict offsets and sizes. The helpers in
# this file unpack offsets as (y, x) pairs and sizes as (height, width) pairs.
NUM_OFFSET_CHANNELS = 2
NUM_SIZE_CHANNELS = 2
# Tolerance used when comparing a heatmap value against its max-pooled
# neighborhood to decide whether that location is a local peak.
PEAK_EPSILON = 1e-6
class CenterNetFeatureExtractor(tf.keras.Model):
  """Abstract Keras model shared by CenterNet feature extractors.

  Subclasses are expected to override the _output_model property, which
  returns one or more tensors predicted by the feature extractor.
  """
  __metaclass__ = abc.ABCMeta

  def __init__(self, name=None, channel_means=(0., 0., 0.),
               channel_stds=(1., 1., 1.), bgr_ordering=False):
    """Stores the normalization settings applied by `preprocess`.

    Args:
      name: str, the name used for the underlying keras model.
      channel_means: A tuple of floats holding the per-channel mean that is
        subtracted from the input. None or an empty value selects 0s.
      channel_stds: A tuple of floats holding the per-channel standard
        deviation the input is divided by. None or an empty value selects 1s.
      bgr_ordering: bool, if set the channel order is flipped to
        [blue, green, red].
    """
    super(CenterNetFeatureExtractor, self).__init__(name=name)

    # Explicit length tests are kept (instead of truthiness) as in the pylint
    # directive below.
    means_missing = channel_means is None or len(channel_means) == 0  # pylint:disable=g-explicit-length-test
    if means_missing:
      channel_means = [0., 0., 0.]

    stds_missing = channel_stds is None or len(channel_stds) == 0  # pylint:disable=g-explicit-length-test
    if stds_missing:
      channel_stds = [1., 1., 1.]

    self._channel_means = channel_means
    self._channel_stds = channel_stds
    self._bgr_ordering = bgr_ordering

  def preprocess(self, inputs):
    """Normalizes a batch of raw images for consumption by the model.

    The channels are optionally flipped (when `bgr_ordering` was set at
    construction time) and then standardized with the `channel_means` and
    `channel_stds` given to the constructor.

    Args:
      inputs: a [batch, height, width, channels] float32 tensor

    Returns:
      outputs: a [batch, height, width, channels] float32 tensor
    """
    if self._bgr_ordering:
      # Unpacking into exactly three names requires a 3-channel input.
      r_plane, g_plane, b_plane = tf.unstack(inputs, axis=3)
      inputs = tf.stack([b_plane, g_plane, r_plane], axis=3)

    # Reshape the statistics so they broadcast over batch, height and width.
    means = tf.constant(self._channel_means)
    means = tf.reshape(means, [1, 1, 1, -1])
    stds = tf.constant(self._channel_stds)
    stds = tf.reshape(stds, [1, 1, 1, -1])

    centered = inputs - means
    return centered / stds

  def preprocess_reverse(self, preprocessed_inputs):
    """Inverts `preprocess` to recover the raw image.

    This is a convenience function for some algorithms that require access
    to the raw inputs.

    Args:
      preprocessed_inputs: A [batch_size, height, width, channels] float
        tensor produced by the `preprocess` function.

    Returns:
      images: A [batch_size, height, width, channels] float tensor with
        the preprocessing removed.
    """
    means = tf.constant(self._channel_means)
    means = tf.reshape(means, [1, 1, 1, -1])
    stds = tf.constant(self._channel_stds)
    stds = tf.reshape(stds, [1, 1, 1, -1])

    restored = (preprocessed_inputs * stds) + means

    if self._bgr_ordering:
      # Undo the channel flip that `preprocess` applied.
      b_plane, g_plane, r_plane = tf.unstack(restored, axis=3)
      restored = tf.stack([r_plane, g_plane, b_plane], axis=3)

    return restored

  @property
  @abc.abstractmethod
  def out_stride(self):
    """The stride in the output image of the network."""

  @property
  @abc.abstractmethod
  def num_feature_outputs(self):
    """The number of feature outputs returned by the feature extractor."""

  @property
  def classification_backbone(self):
    """Not provided by the base class; accessing it raises."""
    message = 'Classification backbone not supported for {}'.format(
        type(self))
    raise NotImplementedError(message)
def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
                        bias_fill=None, use_depthwise=False, name=None,
                        unit_height_conv=True):
  """Creates a network to predict the given number of output channels.

  This function is intended to make the prediction heads for the CenterNet
  meta architecture. The network is a stack of (conv, ReLU) pairs followed by
  a final 1x1 convolution that produces `num_out_channels` channels.

  Args:
    num_out_channels: Number of output channels.
    kernel_sizes: An int, or a list representing the sizes of the conv kernel
      in the intermediate layers. The length of the list indicates the number
      of intermediate conv layers and it must be the same as the length of
      num_filters. A bare int is treated as a single-element list.
    num_filters: An int, or a list representing the number of filters in the
      intermediate conv layers. The length of the list indicates the number of
      intermediate conv layers. A bare int is treated as a single-element
      list.
    bias_fill: If not None, is used to initialize the bias in the final conv
      layer.
    use_depthwise: If true, use SeparableConv2D to construct the Sequential
      layers instead of Conv2D.
    name: Optional name for the prediction net.
    unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

  Returns:
    net: A keras module which when called on an input tensor of size
      [batch_size, height, width, num_in_channels] returns an output
      of size [batch_size, height, width, num_out_channels]

  Raises:
    AssertionError: if kernel_sizes and num_filters differ in length.
  """
  # Normalize each argument on its own. Previously an int was only wrapped
  # when BOTH arguments were ints, so mixing an int with a list crashed on
  # len(int).
  if isinstance(kernel_sizes, int):
    kernel_sizes = [kernel_sizes]
  if isinstance(num_filters, int):
    num_filters = [num_filters]
  assert len(kernel_sizes) == len(num_filters), (
      'kernel_sizes and num_filters must have the same length, got %d and %d'
      % (len(kernel_sizes), len(num_filters)))

  if use_depthwise:
    conv_fn = tf.keras.layers.SeparableConv2D
  else:
    conv_fn = tf.keras.layers.Conv2D

  # We name the convolution operations explicitly because Keras, by default,
  # uses different names during training and evaluation. By setting the names
  # here, we avoid unexpected pipeline breakage in TF1.
  # The output convolution is intentionally created before the intermediate
  # layers, exactly as before, so that creation order stays unchanged.
  out_conv = tf.keras.layers.Conv2D(
      num_out_channels,
      kernel_size=1,
      name='conv1' if tf_version.is_tf1() else None)

  if bias_fill is not None:
    out_conv.bias_initializer = tf.keras.initializers.constant(bias_fill)

  layers = []
  for idx, (kernel_size,
            num_filter) in enumerate(zip(kernel_sizes, num_filters)):
    layers.append(
        conv_fn(
            num_filter,
            # With unit_height_conv the kernel only spans the width axis.
            kernel_size=[1, kernel_size] if unit_height_conv else kernel_size,
            padding='same',
            name='conv2_%d' % idx if tf_version.is_tf1() else None))
    layers.append(tf.keras.layers.ReLU())
  layers.append(out_conv)
  net = tf.keras.Sequential(layers, name=name)
  return net
def _to_float32(x):
  """Casts `x` to a float32 tensor."""
  return tf.cast(x, dtype=tf.float32)
def _get_shape(tensor, num_dims):
  """Returns the shape of `tensor` after asserting it has `num_dims` dims.

  Args:
    tensor: The tensor whose shape is requested.
    num_dims: The number of dimensions `tensor` is required to have.

  Returns:
    The result of shape_utils.combined_static_and_dynamic_shape for `tensor`.
  """
  static_dims = tensor.shape.as_list()
  assert len(static_dims) == num_dims
  return shape_utils.combined_static_and_dynamic_shape(tensor)
def _flatten_spatial_dimensions(batch_images):
  """Merges the height and width axes of a rank-4 tensor into one axis.

  Args:
    batch_images: A rank-4 tensor, unpacked here as
      [batch_size, height, width, channels].

  Returns:
    A tensor reshaped to [batch_size, height * width, channels].
  """
  num_images, rows, cols, depth = _get_shape(batch_images, 4)
  flattened_shape = [num_images, rows * cols, depth]
  return tf.reshape(batch_images, flattened_shape)
def _multi_range(limit,
                 value_repetitions=1,
                 range_repetitions=1,
                 dtype=tf.int32):
  """Builds a range whose values and whose whole span may be repeated.

  As an example (see the Args section for more details),
  _multi_range(limit=2, value_repetitions=3, range_repetitions=4) returns:

  [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]

  Args:
    limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive.
    value_repetitions: Integer. The number of times a value in the sequence is
      repeated. With value_repetitions=3, the result is [0, 0, 0, 1, 1, 1, ..].
    range_repetitions: Integer. The number of times the range is repeated. With
      range_repetitions=3, the result is [0, 1, 2, .., 0, 1, 2, ..].
    dtype: The type of the elements of the resulting tensor.

  Returns:
    A 1-D tensor of type `dtype` and size
      [`limit` * `value_repetitions` * `range_repetitions`] that contains the
      specified range with given repetitions.
  """
  # Column vector [limit, 1]; tiling along axis 1 duplicates each value and
  # tiling along axis 0 repeats the whole range.
  base_column = tf.expand_dims(tf.range(limit, dtype=dtype), axis=-1)
  tiled = tf.tile(
      base_column, multiples=[range_repetitions, value_repetitions])
  return tf.reshape(tiled, [-1])
def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
                                per_channel=False):
  """Finds the k highest-scoring locations of a feature map.

  The top k values (based on activation) of the feature map are returned. If
  `per_channel` is True, the top k values **per channel** are returned. When k
  equals 1, reduce_max and argmax are used instead of top_k, which is more
  efficient.

  The `max_pool_kernel_size` argument allows for selecting local peaks in a
  region. This filtering is done per channel, so nothing prevents two values at
  the same location from being returned.

  Args:
    feature_map: [batch, height, width, channels] float32 feature map.
    max_pool_kernel_size: integer, the max pool kernel size to use to pull off
      peak score locations in a neighborhood (independently for each channel).
      For example, to make sure no two neighboring values (in the same channel)
      are returned, set max_pool_kernel_size=3. If None or 1, will not apply max
      pooling.
    k: The number of highest scoring locations to return.
    per_channel: If True, will return the top k scores and locations per
      feature map channel. If False, the top k across the entire feature map
      (height x width x channels) are returned.

  Returns:
    Tuple of
    scores: A [batch, N] float32 tensor with scores from the feature map in
      descending order. If per_channel is False, N = k. Otherwise,
      N = k * channels, and the first k elements correspond to channel 0, the
      second k correspond to channel 1, etc.
    y_indices: A [batch, N] int tensor with y indices of the top k feature map
      locations. If per_channel is False, N = k. Otherwise,
      N = k * channels.
    x_indices: A [batch, N] int tensor with x indices of the top k feature map
      locations. If per_channel is False, N = k. Otherwise,
      N = k * channels.
    channel_indices: A [batch, N] int tensor with channel indices of the top k
      feature map locations. If per_channel is False, N = k. Otherwise,
      N = k * channels.
  """
  pooling_disabled = not max_pool_kernel_size or max_pool_kernel_size == 1
  if pooling_disabled:
    peaks = feature_map
  else:
    pooled = tf.nn.max_pool(
        feature_map, ksize=max_pool_kernel_size, strides=1, padding='SAME')
    # A location is a peak when it matches the maximum of its neighborhood.
    is_peak = tf.math.abs(feature_map - pooled) < PEAK_EPSILON
    # Everything that is not a peak is zeroed out.
    peaks = feature_map * _to_float32(is_peak)

  batch_size, _, width, num_channels = _get_shape(feature_map, 4)

  if not per_channel:
    flat_peaks = tf.reshape(peaks, [batch_size, -1])
    if k == 1:
      scores = tf.math.reduce_max(flat_peaks, axis=1, keepdims=True)
      flat_indices = tf.expand_dims(
          tf.math.argmax(flat_peaks, axis=1, output_type=tf.dtypes.int32),
          axis=-1)
    else:
      # Unlike the per-channel branch, no padding is applied here, so fewer
      # than k entries come back when the map has fewer than k positions.
      safe_k = tf.minimum(k, tf.shape(flat_peaks)[1])
      scores, flat_indices = tf.math.top_k(flat_peaks, k=safe_k)
  else:
    if k == 1:
      spatial_flat = tf.reshape(peaks, [batch_size, -1, num_channels])
      scores = tf.math.reduce_max(spatial_flat, axis=1)
      flat_indices = tf.expand_dims(
          tf.math.argmax(
              spatial_flat, axis=1, output_type=tf.dtypes.int32),
          axis=-1)
    else:
      # Perform top k over batch and channels.
      channels_first = tf.transpose(peaks, perm=[0, 3, 1, 2])
      channels_first = tf.reshape(
          channels_first, [batch_size, num_channels, -1])
      # safe_k will be used whenever there are fewer positions in the heatmap
      # than the requested number of locations to score. In that case, all
      # positions are returned in sorted order. To ensure consistent shapes
      # for downstream ops the outputs are padded with zeros. Safe_k is also
      # fine for TPU because TPUs require a fixed input size so the number of
      # positions will also be fixed.
      safe_k = tf.minimum(k, tf.shape(channels_first)[-1])
      scores, flat_indices = tf.math.top_k(channels_first, k=safe_k)
      tail_padding = [(0, 0), (0, 0), (0, k - safe_k)]
      scores = tf.pad(scores, tail_padding)
      flat_indices = tf.pad(flat_indices, tail_padding)
      padded_shape = (batch_size, num_channels, k)
      scores = tf.ensure_shape(scores, padded_shape)
      flat_indices = tf.ensure_shape(flat_indices, padded_shape)
    # Convert the indices such that they represent the location in the full
    # (flattened) feature map of size [batch, height * width * channels].
    channel_offsets = tf.range(num_channels)[tf.newaxis, :, tf.newaxis]
    flat_indices = num_channels * flat_indices + channel_offsets
    scores = tf.reshape(scores, [batch_size, -1])
    flat_indices = tf.reshape(flat_indices, [batch_size, -1])

  # Get x, y and channel indices corresponding to the top indices in the flat
  # array.
  y_indices, x_indices, channel_indices = (
      row_col_channel_indices_from_flattened_indices(
          flat_indices, width, num_channels))
  return scores, y_indices, x_indices, channel_indices
def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions,
                                offset_predictions):
  """Decodes boxes from center locations plus size and offset predictions.

  Args:
    y_indices: A [batch, num_boxes] int32 tensor with y indices corresponding to
      object center locations (expressed in output coordinate frame).
    x_indices: A [batch, num_boxes] int32 tensor with x indices corresponding to
      object center locations (expressed in output coordinate frame).
    height_width_predictions: A float tensor of shape [batch_size, height,
      width, 2] representing the height and width of a box centered at each
      pixel.
    offset_predictions: A float tensor of shape [batch_size, height, width, 2]
      representing the y and x offsets of a box centered at each pixel. This
      helps reduce the error from downsampling.

  Returns:
    detection_boxes: A tensor of shape [batch_size, num_boxes, 4] holding the
      raw bounding box coordinates of boxes, as [ymin, xmin, ymax, xmax]
      clipped to the extent of the prediction map.
  """
  batch_size, num_boxes = _get_shape(y_indices, 2)
  _, map_height, map_width, _ = _get_shape(height_width_predictions, 4)
  map_height = tf.cast(map_height, tf.float32)
  map_width = tf.cast(map_width, tf.float32)

  # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
  # tf_gather_nd instead and here we prepare the indices for that.
  batch_column = _multi_range(batch_size, value_repetitions=num_boxes)
  y_column = tf.reshape(y_indices, [-1])
  x_column = tf.reshape(x_indices, [-1])
  gather_indices = tf.stack([batch_column, y_column, x_column], axis=1)

  per_box_shape = [batch_size, num_boxes, 2]
  sizes = tf.reshape(
      tf.gather_nd(height_width_predictions, gather_indices), per_box_shape)
  offsets = tf.reshape(
      tf.gather_nd(offset_predictions, gather_indices), per_box_shape)

  y_centers = _to_float32(y_indices)
  x_centers = _to_float32(x_indices)

  # Negative size predictions are treated as zero-sized boxes.
  sizes = tf.maximum(sizes, 0)
  box_heights, box_widths = tf.unstack(sizes, axis=2)
  y_shift, x_shift = tf.unstack(offsets, axis=2)

  refined_y = y_centers + y_shift
  refined_x = x_centers + x_shift
  half_heights = box_heights / 2.0
  half_widths = box_widths / 2.0

  ymin = tf.clip_by_value(refined_y - half_heights, 0., map_height)
  xmin = tf.clip_by_value(refined_x - half_widths, 0., map_width)
  ymax = tf.clip_by_value(refined_y + half_heights, 0., map_height)
  xmax = tf.clip_by_value(refined_x + half_widths, 0., map_width)

  return tf.stack([ymin, xmin, ymax, xmax], axis=2)
def prediction_tensors_to_temporal_offsets(
    y_indices, x_indices, offset_predictions):
  """Gathers temporal offset predictions at the given center locations.

  This function is similar to the box offset conversion function, as both
  temporal offsets and box offsets are size-2 vectors.

  Args:
    y_indices: A [batch, num_boxes] int32 tensor with y indices corresponding to
      object center locations (expressed in output coordinate frame).
    x_indices: A [batch, num_boxes] int32 tensor with x indices corresponding to
      object center locations (expressed in output coordinate frame).
    offset_predictions: A float tensor of shape [batch_size, height, width, 2]
      representing the y and x offsets of a box's center across adjacent frames.

  Returns:
    offsets: A tensor of shape [batch_size, num_boxes, 2] holding the
      object temporal offsets of (y, x) dimensions.
  """
  batch_size, num_boxes = _get_shape(y_indices, 2)

  # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
  # tf_gather_nd instead and here we prepare the indices for that.
  batch_column = _multi_range(batch_size, value_repetitions=num_boxes)
  y_column = tf.reshape(y_indices, [-1])
  x_column = tf.reshape(x_indices, [-1])
  gather_indices = tf.stack([batch_column, y_column, x_column], axis=1)

  gathered = tf.gather_nd(offset_predictions, gather_indices)
  return tf.reshape(gathered, [batch_size, num_boxes, -1])
def prediction_tensors_to_keypoint_candidates(keypoint_heatmap_predictions,
                                              keypoint_heatmap_offsets,
                                              keypoint_score_threshold=0.1,
                                              max_pool_kernel_size=1,
                                              max_candidates=20,
                                              keypoint_depths=None):
  """Convert keypoint heatmap predictions and offsets to keypoint candidates.

  Args:
    keypoint_heatmap_predictions: A float tensor of shape [batch_size, height,
      width, num_keypoints] representing the per-keypoint heatmaps.
    keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
      width, 2] (or [batch_size, height, width, 2 * num_keypoints] if
      'per_keypoint_offset' is set True) representing the per-keypoint offsets.
    keypoint_score_threshold: float, the threshold for considering a keypoint a
      candidate.
    max_pool_kernel_size: integer, the max pool kernel size to use to pull off
      peak score locations in a neighborhood. For example, to make sure no two
      neighboring values for the same keypoint are returned, set
      max_pool_kernel_size=3. If None or 1, will not apply any local filtering.
    max_candidates: integer, maximum number of keypoint candidates per keypoint
      type.
    keypoint_depths: (optional) A float tensor of shape [batch_size, height,
      width, 1] (or [batch_size, height, width, num_keypoints] if
      'per_keypoint_depth' is set True) representing the per-keypoint depths.

  Returns:
    keypoint_candidates: A tensor of shape
      [batch_size, max_candidates, num_keypoints, 2] holding the
      location of keypoint candidates in [y, x] format (expressed in absolute
      coordinates in the output coordinate frame).
    keypoint_scores: A float tensor of shape
      [batch_size, max_candidates, num_keypoints] with the scores for each
      keypoint candidate. The scores come directly from the heatmap predictions.
    num_keypoint_candidates: An integer tensor of shape
      [batch_size, num_keypoints] with the number of candidates for each
      keypoint type, as it's possible to filter some candidates due to the score
      threshold.
    depth_candidates: A tensor of shape [batch_size, max_candidates,
      num_keypoints] representing the estimated depth of each keypoint
      candidate. Return None if the input keypoint_depths is None.
  """
  batch_size, _, _, num_keypoints = _get_shape(keypoint_heatmap_predictions, 4)
  # Get x, y and channel indices corresponding to the top indices in the
  # keypoint heatmap predictions.
  # Note that the top k candidates are produced for **each keypoint type**.
  # Might be worth eventually trying top k in the feature map, independent of
  # the keypoint type.
  keypoint_scores, y_indices, x_indices, channel_indices = (
      top_k_feature_map_locations(keypoint_heatmap_predictions,
                                  max_pool_kernel_size=max_pool_kernel_size,
                                  k=max_candidates,
                                  per_channel=True))

  # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
  # tf_gather_nd instead and here we prepare the indices for that.
  _, num_indices = _get_shape(y_indices, 2)
  combined_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_indices),
      tf.reshape(y_indices, [-1]),
      tf.reshape(x_indices, [-1])
  ], axis=1)

  selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
                                       combined_indices)
  selected_offsets = tf.reshape(selected_offsets_flat,
                                [batch_size, num_indices, -1])

  y_indices = _to_float32(y_indices)
  x_indices = _to_float32(x_indices)

  _, _, num_channels = _get_shape(selected_offsets, 3)
  if num_channels > 2:
    # Offsets are per keypoint and the last dimension of selected_offsets
    # contains all those offsets, so reshape the offsets to make sure that the
    # last dimension contains (y_offset, x_offset) for a single keypoint.
    reshaped_offsets = tf.reshape(selected_offsets,
                                  [batch_size, num_indices, -1, 2])

    # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
    # tf_gather_nd instead and here we prepare the indices for that. In this
    # case, channel_indices indicates which keypoint to use the offset from.
    channel_combined_indices = tf.stack([
        _multi_range(batch_size, value_repetitions=num_indices),
        _multi_range(num_indices, range_repetitions=batch_size),
        tf.reshape(channel_indices, [-1])
    ], axis=1)

    offsets = tf.gather_nd(reshaped_offsets, channel_combined_indices)
    offsets = tf.reshape(offsets, [batch_size, num_indices, -1])
  else:
    offsets = selected_offsets
  y_offsets, x_offsets = tf.unstack(offsets, axis=2)

  keypoint_candidates = tf.stack([y_indices + y_offsets,
                                  x_indices + x_offsets], axis=2)
  # The per-channel top-k output is laid out keypoint-major, so split it into
  # [num_keypoints, max_candidates] before moving candidates to axis 1.
  keypoint_candidates = tf.reshape(
      keypoint_candidates,
      [batch_size, num_keypoints, max_candidates, 2])
  keypoint_candidates = tf.transpose(keypoint_candidates, [0, 2, 1, 3])
  keypoint_scores = tf.reshape(
      keypoint_scores,
      [batch_size, num_keypoints, max_candidates])
  keypoint_scores = tf.transpose(keypoint_scores, [0, 2, 1])
  # tf.cast replaces the deprecated tf.to_int32; the result is the same.
  num_candidates = tf.reduce_sum(
      tf.cast(keypoint_scores >= keypoint_score_threshold, tf.int32), axis=1)

  depth_candidates = None
  if keypoint_depths is not None:
    selected_depth_flat = tf.gather_nd(keypoint_depths, combined_indices)
    selected_depth = tf.reshape(selected_depth_flat,
                                [batch_size, num_indices, -1])
    _, _, num_depth_channels = _get_shape(selected_depth, 3)
    if num_depth_channels > 1:
      # Depths are per keypoint: pick the channel that matches each candidate.
      # A distinct name is used so the spatial gather indices above are not
      # shadowed.
      depth_channel_indices = tf.stack([
          _multi_range(batch_size, value_repetitions=num_indices),
          _multi_range(num_indices, range_repetitions=batch_size),
          tf.reshape(channel_indices, [-1])
      ], axis=1)
      depth = tf.gather_nd(selected_depth, depth_channel_indices)
      depth = tf.reshape(depth, [batch_size, num_indices, -1])
    else:
      depth = selected_depth
    depth_candidates = tf.reshape(depth,
                                  [batch_size, num_keypoints, max_candidates])
    depth_candidates = tf.transpose(depth_candidates, [0, 2, 1])

  return keypoint_candidates, keypoint_scores, num_candidates, depth_candidates
def argmax_feature_map_locations(feature_map):
  """Returns the peak locations in the feature map.

  Args:
    feature_map: A rank-4 tensor, unpacked here as
      [batch_size, height, width, num_channels].

  Returns:
    A tuple (y_indices, x_indices, channel_indices). The y and x indices are
    derived from the per-channel argmax over the flattened spatial positions;
    channel_indices is a [batch_size, num_channels] tensor that holds
    range(num_channels) in every row.
  """
  batch_size, _, width, num_channels = _get_shape(feature_map, 4)

  # Collapse height and width so a single argmax per channel finds the peak.
  spatially_flat = tf.reshape(feature_map, [batch_size, -1, num_channels])
  flat_peak_positions = tf.math.argmax(
      spatially_flat, axis=1, output_type=tf.dtypes.int32)

  # Get x and y indices corresponding to the top indices in the flat array.
  y_indices, x_indices = row_col_indices_from_flattened_indices(
      flat_peak_positions, width)

  channel_row = tf.range(num_channels)[tf.newaxis, :]
  channel_indices = tf.tile(channel_row, [batch_size, 1])
  return y_indices, x_indices, channel_indices
def prediction_tensors_to_single_instance_kpts(
    keypoint_heatmap_predictions,
    keypoint_heatmap_offsets,
    keypoint_score_heatmap=None):
  """Convert keypoint heatmap predictions and offsets to keypoint candidates.

  Exactly one candidate (the argmax of the heatmap) is produced per keypoint
  type.

  NOTE: the reshape of the gathered offsets to
  [num_keypoints, num_keypoints, -1] and the final expand_dims calls appear to
  assume batch_size == 1 and per-keypoint offsets
  (2 * num_keypoints channels) -- confirm against the callers.

  Args:
    keypoint_heatmap_predictions: A float tensor of shape [batch_size, height,
      width, num_keypoints] representing the per-keypoint heatmaps which is
      used for finding the best keypoint candidate locations.
    keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
      width, 2] (or [batch_size, height, width, 2 * num_keypoints] if
      'per_keypoint_offset' is set True) representing the per-keypoint offsets.
    keypoint_score_heatmap: (optional) A float tensor of shape [batch_size,
      height, width, num_keypoints] representing the heatmap which is used for
      reporting the confidence scores. If not provided, then the values in the
      keypoint_heatmap_predictions will be used.

  Returns:
    A tuple of two tensors:
    keypoint_candidates: A tensor of shape [1, 1, num_keypoints, 2] holding
      the location of the keypoint candidates in [y, x] format (expressed in
      absolute coordinates in the output coordinate frame).
    keypoint_scores: A float tensor of shape [1, 1, num_keypoints] with the
      score of each keypoint candidate, read from `keypoint_score_heatmap`
      when given and from `keypoint_heatmap_predictions` otherwise.
  """
  # Only the batch size is needed here; num_keypoints is taken from the
  # argmax output below.
  batch_size, _, _, _ = _get_shape(keypoint_heatmap_predictions, 4)
  # Get x, y and channel indices corresponding to the top indices in the
  # keypoint heatmap predictions.
  y_indices, x_indices, channel_indices = argmax_feature_map_locations(
      keypoint_heatmap_predictions)

  # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
  # tf_gather_nd instead and here we prepare the indices for that.
  _, num_keypoints = _get_shape(y_indices, 2)
  combined_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_keypoints),
      tf.reshape(y_indices, [-1]),
      tf.reshape(x_indices, [-1]),
  ], axis=1)

  # shape: [num_keypoints, num_keypoints * 2]
  selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
                                       combined_indices)
  # shape: [num_keypoints, num_keypoints, 2].
  selected_offsets_flat = tf.reshape(
      selected_offsets_flat, [num_keypoints, num_keypoints, -1])
  # shape: [num_keypoints].
  channel_indices = tf.keras.backend.flatten(channel_indices)
  # shape: [num_keypoints, 2].
  # Pairing each channel index with itself selects the "diagonal": the offset
  # of keypoint i read at the peak location of keypoint i.
  retrieve_indices = tf.stack([channel_indices, channel_indices], axis=1)
  # shape: [num_keypoints, 2]
  selected_offsets = tf.gather_nd(selected_offsets_flat, retrieve_indices)
  y_offsets, x_offsets = tf.unstack(selected_offsets, axis=1)

  keypoint_candidates = tf.stack([
      tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0),
      tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0)
  ], axis=2)
  keypoint_candidates = tf.expand_dims(keypoint_candidates, axis=0)

  # Append the channel indices back to retrieve the keypoint scores from the
  # heatmap.
  combined_indices = tf.concat(
      [combined_indices, tf.expand_dims(channel_indices, axis=-1)], axis=1)
  if keypoint_score_heatmap is None:
    keypoint_scores = tf.gather_nd(
        keypoint_heatmap_predictions, combined_indices)
  else:
    keypoint_scores = tf.gather_nd(keypoint_score_heatmap, combined_indices)
  keypoint_scores = tf.expand_dims(
      tf.expand_dims(keypoint_scores, axis=0), axis=0)
  return keypoint_candidates, keypoint_scores
def _score_to_distance_map(y_grid, x_grid, heatmap, points_y, points_x,
                           score_distance_offset):
  """Re-weights heatmap scores by their distance to per-channel targets.

  Every score is replaced by:
    score / (d + score_distance_offset)
  where d is the distance from the pixel location to the target point
  location.

  Args:
    y_grid: A float tensor with shape [height, width] representing the
      y-coordinate of each pixel grid.
    x_grid: A float tensor with shape [height, width] representing the
      x-coordinate of each pixel grid.
    heatmap: A float tensor with shape [1, height, width, channel]
      representing the heatmap to be rescored.
    points_y: A float tensor with shape [channel] representing the y
      coordinates of the target points for each channel.
    points_x: A float tensor with shape [channel] representing the x
      coordinates of the target points for each channel.
    score_distance_offset: A constant used in the above formula.

  Returns:
    A float tensor with shape [1, height, width, channel] representing the
      rescored heatmap.
  """
  # A trailing axis lets the grids broadcast against the per-channel points.
  delta_y = y_grid[:, :, tf.newaxis] - points_y
  delta_x = x_grid[:, :, tf.newaxis] - points_x
  squared_distance = delta_y**2 + delta_x**2
  denominator = tf.math.sqrt(squared_distance) + score_distance_offset
  return tf.math.divide(heatmap, denominator)
def prediction_to_single_instance_keypoints(
    object_heatmap,
    keypoint_heatmap,
    keypoint_offset,
    keypoint_regression,
    kp_params,
    keypoint_depths=None):
  """Postprocess function to predict single instance keypoints.

  This is a simplified postprocessing function based on the assumption that
  there is only one instance in the image. If there are multiple instances in
  the image, the model prefers to predict the one that is closest to the image
  center. Here is a high-level description of what this function does:
    1) Object heatmap re-weighted by the distance between each pixel to the
      image center is used to determine the instance center.
    2) Regressed keypoint locations are retrieved from the instance center.
      The keypoint heatmap is re-weighted by the distance of each pixel to
      the regressed keypoint locations (score / (distance + offset)). This is
      to select the keypoints that are associated with the center instance
      without using top_k op.
    3) The keypoint locations are computed by the re-weighted keypoint heatmap
      and the keypoint offset.

  Args:
    object_heatmap: A float tensor of shape [1, height, width, 1] representing
      the heatmap of the class.
    keypoint_heatmap: A float tensor of shape [1, height, width, num_keypoints]
      representing the per-keypoint heatmaps.
    keypoint_offset: A float tensor of shape [1, height, width, 2] (or [1,
      height, width, 2 * num_keypoints] if 'per_keypoint_offset' is set True)
      representing the per-keypoint offsets.
    keypoint_regression: A float tensor of shape [1, height, width, 2 *
      num_keypoints] representing the joint regression prediction.
    kp_params: A `KeypointEstimationParams` object with parameters for a single
      keypoint class. This function reads `keypoint_std_dev` (only its
      length), `score_distance_offset` and `candidate_ranking_mode`.
    keypoint_depths: (optional) A float tensor of shape [batch_size, height,
      width, 1] (or [batch_size, height, width, num_keypoints] if
      'per_keypoint_depth' is set True) representing the per-keypoint depths.
      Currently ignored.

  Returns:
    A tuple of three elements:
      keypoint_candidates: A float tensor with shape [1, 1, num_keypoints, 2]
        representing the yx-coordinates of the keypoints in the output feature
        map space.
      keypoint_scores: A float tensor with shape [1, 1, num_keypoints]
        representing the keypoint prediction scores.
      depth_candidates: Always None, since depth prediction is not
        implemented in this function.

  Raises:
    ValueError: if kp_params.candidate_ranking_mode is anything other than
      'score_distance_ratio'.
  """
  # TODO(yuhuic): add the keypoint depth prediction logics in the browser
  # postprocessing back.
  del keypoint_depths

  num_keypoints = len(kp_params.keypoint_std_dev)
  batch_size, height, width, _ = _get_shape(keypoint_heatmap, 4)

  # Create the image center location.
  image_center_y = tf.convert_to_tensor([0.5 * height], dtype=tf.float32)
  image_center_x = tf.convert_to_tensor([0.5 * width], dtype=tf.float32)
  (y_grid, x_grid) = ta_utils.image_shape_to_grids(height, width)
  # Rescore the object heatmap by the distance to the image center.
  object_heatmap = _score_to_distance_map(
      y_grid, x_grid, object_heatmap, image_center_y,
      image_center_x, kp_params.score_distance_offset)

  # Pick the highest score and location of the weighted object heatmap.
  y_indices, x_indices, _ = argmax_feature_map_locations(object_heatmap)
  _, num_indices = _get_shape(y_indices, 2)
  combined_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_indices),
      tf.reshape(y_indices, [-1]),
      tf.reshape(x_indices, [-1])
  ], axis=1)

  # Select the regression vectors from the object center.
  selected_regression_flat = tf.gather_nd(keypoint_regression, combined_indices)
  # shape: [num_keypoints, 2]
  regression_offsets = tf.reshape(selected_regression_flat, [num_keypoints, -1])
  (y_reg, x_reg) = tf.unstack(regression_offsets, axis=1)
  y_regressed = tf.cast(y_indices, dtype=tf.float32) + y_reg
  x_regressed = tf.cast(x_indices, dtype=tf.float32) + x_reg

  if kp_params.candidate_ranking_mode == 'score_distance_ratio':
    # Favor heatmap responses that lie close to the regressed locations.
    reweighted_keypoint_heatmap = _score_to_distance_map(
        y_grid, x_grid, keypoint_heatmap, y_regressed, x_regressed,
        kp_params.score_distance_offset)
  else:
    raise ValueError('Unsupported candidate_ranking_mode: %s' %
                     kp_params.candidate_ranking_mode)

  # Get the keypoint locations/scores:
  #   keypoint_candidates: [1, 1, num_keypoints, 2]
  #   keypoint_scores: [1, 1, num_keypoints]
  # Locations come from the re-weighted heatmap, while the reported scores
  # are read from the original keypoint heatmap.
  (keypoint_candidates, keypoint_scores
  ) = prediction_tensors_to_single_instance_kpts(
      reweighted_keypoint_heatmap,
      keypoint_offset,
      keypoint_score_heatmap=keypoint_heatmap)
  return keypoint_candidates, keypoint_scores, None
def _gaussian_weighted_map_const_multi(
    y_grid, x_grid, heatmap, points_y, points_x, boxes,
    gaussian_denom_ratio):
  """Builds per-instance keypoint heatmaps weighted by a constant Gaussian.

  This is the rescoring used for the 'gaussian_weighted_const'
  candidate_ranking_mode. For every (instance, keypoint) pair the heatmap is
  multiplied by a Gaussian centered at the target point and by a box mask:

    heatmap_score * exp(-distance^2 / gaussian_denom) * in_box

  where
    gaussian_denom = min(height, width) * gaussian_denom_ratio
  (height and width being those of `heatmap`), 'distance' is the distance
  between a grid location and the target point, and 'in_box' is 1.0 for grid
  locations satisfying y_min <= y < y_max and x_min <= x < x_max and 0.0
  elsewhere.

  The 'const' postfix refers to the denominator being the same for all
  instances: it depends only on the feature map size, not on the size of each
  instance.

  Args:
    y_grid: A float tensor with shape [height, width] representing the
      y-coordinate of each pixel grid.
    x_grid: A float tensor with shape [height, width] representing the
      x-coordinate of each pixel grid.
    heatmap: A float tensor with shape [height, width, num_keypoints]
      representing the heatmap to be rescored.
    points_y: A float tensor with shape [num_instances, num_keypoints]
      representing the y coordinates of the target points for each channel.
    points_x: A float tensor with shape [num_instances, num_keypoints]
      representing the x coordinates of the target points for each channel.
    boxes: A tensor of shape [num_instances, 4] with predicted bounding boxes
      for each instance, ordered as [y_min, x_min, y_max, x_max] and expressed
      in the output coordinate frame.
    gaussian_denom_ratio: A constant used in the above formula that determines
      the denominator of the Gaussian kernel.

  Returns:
    A float tensor with shape [height, width, num_instances, num_keypoints]
    representing the rescored heatmap of every instance.
  """
  num_instances, _ = _get_shape(boxes, 2)
  height, width, num_keypoints = _get_shape(heatmap, 3)

  # All broadcasting is done through explicit tf.reshape calls; tf.newaxis is
  # intentionally avoided because the TfLite converter doesn't like it.
  grid_shape_4d = [height, width, 1, 1]
  points_shape_4d = [1, 1, num_instances, num_keypoints]
  # Shape: [height, width, num_instances, num_keypoints].
  y_diff = (tf.reshape(y_grid, grid_shape_4d) -
            tf.reshape(points_y, points_shape_4d))
  x_diff = (tf.reshape(x_grid, grid_shape_4d) -
            tf.reshape(points_x, points_shape_4d))
  distance_square = y_diff * y_diff + x_diff * x_diff

  # Box mask, shape: [height, width, num_instances].
  y_min, x_min, y_max, x_max = tf.split(boxes, 4, axis=1)
  grid_shape_3d = [height, width, 1]
  box_shape_3d = [1, 1, num_instances]
  rows = tf.reshape(y_grid, grid_shape_3d)
  cols = tf.reshape(x_grid, grid_shape_3d)
  rows_inside = tf.math.logical_and(
      rows >= tf.reshape(y_min, box_shape_3d),
      rows < tf.reshape(y_max, box_shape_3d))
  cols_inside = tf.math.logical_and(
      cols >= tf.reshape(x_min, box_shape_3d),
      cols < tf.reshape(x_max, box_shape_3d))
  in_boxes = tf.cast(
      tf.math.logical_and(rows_inside, cols_inside), dtype=tf.float32)

  shorter_side = tf.cast(tf.minimum(height, width), dtype=tf.float32)
  gaussian_denom = shorter_side * gaussian_denom_ratio
  # Shape: [height, width, num_instances, num_keypoints].
  gaussian_map = tf.exp((-1 * distance_square) / gaussian_denom)

  heatmap_per_instance = tf.expand_dims(heatmap, axis=2)
  box_mask = tf.reshape(in_boxes, [height, width, num_instances, 1])
  return heatmap_per_instance * gaussian_map * box_mask
def prediction_tensors_to_multi_instance_kpts(
    keypoint_heatmap_predictions,
    keypoint_heatmap_offsets,
    keypoint_score_heatmap=None):
  """Picks one keypoint candidate per (instance, keypoint) heatmap via argmax.

  Multi-instance counterpart of 'prediction_tensors_to_single_instance_kpts':
  the input heatmap carries an additional 'num_instances' dimension and the
  peak of every [height, width] slice is located with a single argmax.

  Args:
    keypoint_heatmap_predictions: A float tensor of shape [height,
      width, num_instances, num_keypoints] representing the per-keypoint and
      per-instance heatmaps which is used for finding the best keypoint
      candidate locations.
    keypoint_heatmap_offsets: A float tensor of shape [height,
      width, 2 * num_keypoints] representing the per-keypoint offsets.
    keypoint_score_heatmap: (optional) A float tensor of shape [height, width,
      num_keypoints] representing the heatmap which is used for reporting the
      confidence scores. If not provided, the maximum of
      keypoint_heatmap_predictions over the instance dimension is used.

  Returns:
    keypoint_candidates: A tensor of shape
      [1, num_instances, num_keypoints, 2] holding the location of the
      selected keypoint candidates in [y, x] format (peak index plus the
      offset gathered at the peak, in the output coordinate frame).
    keypoint_scores: A float tensor of shape
      [1, num_instances, num_keypoints] with the score read at each selected
      peak location.
  """
  height, width, num_instances, num_keypoints = _get_shape(
      keypoint_heatmap_predictions, 4)

  # Collapse the spatial dimensions so that one argmax finds every peak.
  # Shape: [height * width, num_instances * num_keypoints].
  flat_heatmap = tf.reshape(
      keypoint_heatmap_predictions, [-1, num_instances * num_keypoints])
  # Shape: [num_instances * num_keypoints].
  peak_positions = tf.math.argmax(
      flat_heatmap, axis=0, output_type=tf.dtypes.int32)
  peak_rows, peak_cols = row_col_indices_from_flattened_indices(
      peak_positions, width)
  peak_rows = tf.reshape(peak_rows, [-1])
  peak_cols = tf.reshape(peak_cols, [-1])

  # Keypoint channel of every peak: the range [0, num_keypoints) repeated
  # num_instances times, matching the instance-major flattened layout.
  keypoint_ids = _multi_range(
      limit=num_keypoints, value_repetitions=1,
      range_repetitions=num_instances)
  gather_indices = tf.stack([peak_rows, peak_cols, keypoint_ids], axis=1)

  # Shape after gathering: [num_instances * num_keypoints, 2].
  offsets_by_keypoint = tf.reshape(
      keypoint_heatmap_offsets, [height, width, num_keypoints, 2])
  peak_offsets = tf.gather_nd(offsets_by_keypoint, gather_indices)
  offset_y, offset_x = tf.unstack(peak_offsets, axis=1)

  candidate_y = (
      tf.cast(peak_rows, dtype=tf.float32) + tf.expand_dims(offset_y, axis=0))
  candidate_x = (
      tf.cast(peak_cols, dtype=tf.float32) + tf.expand_dims(offset_x, axis=0))
  keypoint_candidates = tf.reshape(
      tf.stack([candidate_y, candidate_x], axis=2),
      [num_instances, num_keypoints, 2])

  if keypoint_score_heatmap is not None:
    score_source = keypoint_score_heatmap
  else:
    score_source = tf.reduce_max(keypoint_heatmap_predictions, axis=2)
  keypoint_scores = tf.gather_nd(score_source, gather_indices)

  return (tf.expand_dims(keypoint_candidates, axis=0),
          tf.reshape(keypoint_scores, [1, num_instances, num_keypoints]))
def prediction_to_keypoints_argmax(
    prediction_dict,
    object_y_indices,
    object_x_indices,
    boxes,
    task_name,
    kp_params):
  """Postprocess function to predict multi instance keypoints with argmax op.

  Alternative to the original keypoint postprocessing that replaces the topk
  op with argmax, since topk runs much slower in the browser. The batch
  dimension of the predictions and of `boxes` is squeezed away (batch_size is
  assumed to be 1) to avoid 5D tensors, which cause issues when converting to
  the TfLite model.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. This dictionary should contain keypoint prediction
      feature maps for each keypoint task.
    object_y_indices: A tensor of shape [batch_size, max_instances]
      representing the location indices of the object centers. It is used
      directly as tf.gather_nd indices, so it must have an integer dtype.
    object_x_indices: A tensor of shape [batch_size, max_instances]
      representing the location indices of the object centers. It is used
      directly as tf.gather_nd indices, so it must have an integer dtype.
    boxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, expressed in the output coordinate
      frame.
    task_name: string, the name of the task this namedtuple corresponds to.
      Note that it should be an unique identifier of the task.
    kp_params: A `KeypointEstimationParams` object with parameters for a single
      keypoint class.

  Returns:
    A tuple of two tensors:
      keypoint_candidates: A float tensor with shape [batch_size,
        num_instances, num_keypoints, 2] representing the yx-coordinates of
        the keypoints in the output feature map space.
      keypoint_scores: A float tensor with shape [batch_size, num_instances,
        num_keypoints] representing the keypoint prediction scores.

  Raises:
    ValueError: if the candidate_ranking_mode is not supported.
  """

  def _last_prediction(field_name):
    # Each entry of prediction_dict is indexed with [-1], i.e. only its last
    # element is consumed here.
    return prediction_dict[get_keypoint_name(task_name, field_name)][-1]

  keypoint_heatmap = tf.squeeze(
      tf.nn.sigmoid(_last_prediction(KEYPOINT_HEATMAP)), axis=0)
  keypoint_offset = tf.squeeze(_last_prediction(KEYPOINT_OFFSET), axis=0)
  keypoint_regression = tf.squeeze(
      _last_prediction(KEYPOINT_REGRESSION), axis=0)

  height, width, num_keypoints = _get_shape(keypoint_heatmap, 3)
  # y/x coordinate grids, each of shape [height, width].
  y_grid, x_grid = ta_utils.image_shape_to_grids(height, width)

  # Look up the regression vectors stored at each object center.
  num_instances = _get_shape(object_y_indices, 2)[1]
  center_indices = tf.stack(
      [tf.reshape(object_y_indices, [-1]),
       tf.reshape(object_x_indices, [-1])],
      axis=1)
  regression_at_centers = tf.reshape(
      tf.gather_nd(keypoint_regression, center_indices),
      [num_instances, num_keypoints, 2])
  y_reg, x_reg = tf.unstack(regression_at_centers, axis=2)

  # Regressed keypoint locations, shape: [num_instances, num_keypoints].
  center_y = tf.cast(
      tf.reshape(object_y_indices, [num_instances, 1]), dtype=tf.float32)
  center_x = tf.cast(
      tf.reshape(object_x_indices, [num_instances, 1]), dtype=tf.float32)
  y_regressed = center_y + y_reg
  x_regressed = center_x + x_reg

  if kp_params.candidate_ranking_mode != 'gaussian_weighted_const':
    raise ValueError(
        'Unsupported ranking mode in the multipose no topk method: %s' %
        kp_params.candidate_ranking_mode)

  # Shape: [height, width, num_instances, num_keypoints].
  rescored_heatmap = _gaussian_weighted_map_const_multi(
      y_grid, x_grid, keypoint_heatmap, y_regressed, x_regressed,
      tf.squeeze(boxes, axis=0), kp_params.gaussian_denom_ratio)
  # Shape: [height, width, num_keypoints].
  keypoint_score_heatmap = tf.math.reduce_max(rescored_heatmap, axis=2)

  keypoint_candidates, keypoint_scores = (
      prediction_tensors_to_multi_instance_kpts(
          keypoint_heatmap_predictions=rescored_heatmap,
          keypoint_heatmap_offsets=keypoint_offset,
          keypoint_score_heatmap=keypoint_score_heatmap))
  return keypoint_candidates, keypoint_scores
def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
                                          y_indices, x_indices):
  """Gathers regressed keypoints at object centers as absolute coordinates.

  The keypoint regression map stores, at every feature map location, keypoint
  coordinates relative to that location. This function reads the vectors at
  the given object centers and adds the center coordinates to them, so the
  result is expressed in absolute coordinates in the output frame.

  Args:
    regressed_keypoint_predictions: A float tensor of shape
      [batch_size, height, width, 2 * num_keypoints] holding regressed
      keypoints. The last dimension has keypoint coordinates ordered as follows:
      [y0, x0, y1, x1, ..., y{J-1}, x{J-1}] where J is the number of keypoints.
    y_indices: A [batch, num_instances] int tensor holding y indices for object
      centers. These indices correspond to locations in the output feature map.
    x_indices: A [batch, num_instances] int tensor holding x indices for object
      centers. These indices correspond to locations in the output feature map.

  Returns:
    A float tensor of shape [batch_size, num_objects, 2 * num_keypoints] where
    regressed keypoints are gathered at the provided locations, and converted
    to absolute coordinates in the output coordinate frame.
  """
  batch_size, num_instances = _get_shape(y_indices, 2)

  # TF Lite does not support tf.gather with batch_dims > 0, so explicit
  # (batch, y, x) triplets are assembled for tf.gather_nd instead.
  batch_ids = _multi_range(batch_size, value_repetitions=num_instances)
  flat_rows = tf.reshape(y_indices, [-1])
  flat_cols = tf.reshape(x_indices, [-1])
  gather_indices = tf.stack([batch_ids, flat_rows, flat_cols], axis=1)

  # Shape: [batch_size, num_instances, num_keypoints, 2].
  relative_keypoints = tf.reshape(
      tf.gather_nd(regressed_keypoint_predictions, gather_indices),
      [batch_size, num_instances, -1, 2])
  relative_y, relative_x = tf.unstack(relative_keypoints, axis=3)

  # Center coordinates broadcast over the keypoint dimension.
  center_y = _to_float32(tf.expand_dims(y_indices, axis=-1))
  center_x = _to_float32(tf.expand_dims(x_indices, axis=-1))
  absolute_keypoints = tf.stack(
      [center_y + relative_y, center_x + relative_x], axis=3)

  # Flatten back to the interleaved [y0, x0, y1, x1, ...] layout.
  return tf.reshape(absolute_keypoints, [batch_size, num_instances, -1])
def sdr_scaled_ranking_score(
    keypoint_scores, distances, bboxes, score_distance_multiplier):
  """Score-to-distance-ratio method to rank keypoint candidates.

  This corresponds to the ranking method: 'score_scaled_distance_ratio'. The
  keypoint candidates are ranked using the formula:

    ranking_score = score / (distance + offset)

  where 'score' is the keypoint heatmap score, 'distance' is the distance
  between the heatmap peak location and the regressed joint location, and
  'offset' scales with the predicted bounding box of the instance:

    offset = max(bbox height, bbox width) * score_distance_multiplier

  The ranking score is used to find the best keypoint candidate for snapping
  regressed joints.

  Args:
    keypoint_scores: A float tensor of shape
      [batch_size, max_candidates, num_keypoints] indicating the scores for
      keypoint candidates.
    distances: A float tensor of shape
      [batch_size, num_instances, max_candidates, num_keypoints] indicating the
      distances between the keypoint candidates and the joint regression
      locations of each instances.
    bboxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, ordered as [ymin, xmin, ymax, xmax] and
      expressed in the output coordinate frame.
    score_distance_multiplier: A scalar used to multiply the bounding box size
      to be the offset in the score-to-distance-ratio formula.

  Returns:
    A float tensor of shape [batch_size, num_instances, max_candidates,
    num_keypoints] representing the ranking scores of each keypoint candidates.
  """
  # Each coordinate has shape [batch_size, num_instances].
  ymin, xmin, ymax, xmax = tf.unstack(bboxes, axis=2)
  box_height = ymax - ymin
  box_width = xmax - xmin
  # Shape: [batch_size, num_instances].
  offsets = tf.math.maximum(box_height, box_width) * score_distance_multiplier

  # Broadcast scores over instances and offsets over candidates/keypoints.
  # Result shape: [batch_size, num_instances, max_candidates, num_keypoints].
  scores_per_instance = keypoint_scores[:, tf.newaxis, :, :]
  denominator = distances + offsets[:, :, tf.newaxis, tf.newaxis]
  return scores_per_instance / denominator
def gaussian_weighted_score(
    keypoint_scores, distances, keypoint_std_dev, bboxes):
  """Gaussian weighted method to rank keypoint candidates.

  This corresponds to the ranking method: 'gaussian_weighted'. The
  keypoint candidates are ranked using the formula:

    score * exp((-distances^2) / (2 * sigma^2))

  where 'score' is the keypoint heatmap score and 'distances' is the distance
  between the heatmap peak location and the regressed joint location. 'sigma'
  is the standard deviation computed from the instance's box height and width
  by the target assigner's `_compute_std_dev_from_box_size` (called with
  min_overlap=0.7), multiplied by the per-keypoint `keypoint_std_dev`.

  The ranking score is used to find the best keypoint candidate for snapping
  regressed joints.

  Args:
    keypoint_scores: A float tensor of shape
      [batch_size, max_candidates, num_keypoints] indicating the scores for
      keypoint candidates.
    distances: A float tensor of shape
      [batch_size, num_instances, max_candidates, num_keypoints] indicating the
      distances between the keypoint candidates and the joint regression
      locations of each instances.
    keypoint_std_dev: A list of float, one entry per keypoint type, that
      scales the box-derived standard deviation. It provides the flexibility
      of using different sizes of Gaussian kernel for each keypoint class.
    bboxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, ordered as [ymin, xmin, ymax, xmax] and
      expressed in the output coordinate frame.

  Returns:
    A float tensor of shape [batch_size, num_instances, max_candidates,
    num_keypoints] representing the ranking scores of each keypoint candidates.
  """
  # Each coordinate has shape [batch_size, num_instances].
  ymin, xmin, ymax, xmax = tf.unstack(bboxes, axis=2)
  box_height = ymax - ymin
  box_width = xmax - xmin

  # Shape: [num_keypoints].
  std_dev_per_keypoint = tf.constant(keypoint_std_dev)
  # Shape: [batch_size, num_instances].
  box_sigma = cn_assigner._compute_std_dev_from_box_size(  # pylint: disable=protected-access
      box_height, box_width, min_overlap=0.7)
  # Shape: [batch_size, num_instances, num_keypoints].
  sigma = (std_dev_per_keypoint[tf.newaxis, tf.newaxis, :] *
           box_sigma[:, :, tf.newaxis])

  # Repeat sigma explicitly along the candidate dimension.
  # Shape: [batch_size, num_instances, max_candidates, num_keypoints].
  _, _, max_candidates, _ = _get_shape(distances, 4)
  sigma = tf.tile(
      sigma[:, :, tf.newaxis, :], multiples=[1, 1, max_candidates, 1])

  exponent = (-1 * distances * distances) / (2 * sigma * sigma)
  gaussian_map = tf.exp(exponent)
  return keypoint_scores[:, tf.newaxis, :, :] * gaussian_map
def refine_keypoints(regressed_keypoints,
                     keypoint_candidates,
                     keypoint_scores,
                     num_keypoint_candidates,
                     bboxes=None,
                     unmatched_keypoint_score=0.1,
                     box_scale=1.2,
                     candidate_search_scale=0.3,
                     candidate_ranking_mode='min_distance',
                     score_distance_offset=1e-6,
                     keypoint_depth_candidates=None,
                     keypoint_score_threshold=0.1,
                     score_distance_multiplier=0.1,
                     keypoint_std_dev=None):
  """Refines regressed keypoints by snapping to the best candidate keypoints.

  The initial regressed keypoints represent a full set of keypoints regressed
  from the centers of the objects. The keypoint candidates are estimated
  independently from heatmaps, and are not associated with any object instances.
  This function refines the regressed keypoints by "snapping" to the
  nearest/highest score-distance ratio/highest Gaussian weighted score
  (depending on the candidate_ranking_mode) candidate of the same keypoint type
  (e.g. "nose"). If no suitable candidate is found, the regressed keypoint
  remains unchanged.

  A candidate is always selected among candidates of the same keypoint type,
  and it is rejected (i.e. the regressed keypoint is kept) when its score is
  below `keypoint_score_threshold`. When `bboxes` is provided, a candidate is
  additionally rejected when:
    - it lies outside the instance's box scaled by `box_scale` in height and
      width, or
    - the distance from the regressed keypoint to its closest candidate
      exceeds candidate_search_scale * max(height, width) of the scaled box.

  Note that the same candidate keypoint is allowed to snap to regressed
  keypoints in different instances.

  Args:
    regressed_keypoints: A float tensor of shape
      [batch_size, num_instances, num_keypoints, 2] with the initial regressed
      keypoints.
    keypoint_candidates: A tensor of shape
      [batch_size, max_candidates, num_keypoints, 2] holding the location of
      keypoint candidates in [y, x] format (expressed in absolute coordinates in
      the output coordinate frame). The max_candidates dimension is read from
      the static shape.
    keypoint_scores: A float tensor of shape
      [batch_size, max_candidates, num_keypoints] indicating the scores for
      keypoint candidates.
    num_keypoint_candidates: An integer tensor of shape
      [batch_size, num_keypoints] indicating the number of valid candidates for
      each keypoint type, as there may be padding (dim 1) of
      `keypoint_candidates` and `keypoint_scores`.
    bboxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, expressed in the output coordinate
      frame. If not provided, the box containment and search radius checks are
      skipped and only `keypoint_score_threshold` is applied. Required when
      candidate_ranking_mode is 'score_scaled_distance_ratio' or
      'gaussian_weighted'.
    unmatched_keypoint_score: float, the default score to use for regressed
      keypoints that are not successfully snapped to a nearby candidate.
    box_scale: float, the multiplier applied to the height and width of the
      provided boxes before the containment and search radius checks.
    candidate_search_scale: float, the scale parameter that multiplies the
      largest dimension of a bounding box. The resulting distance becomes a
      search radius for candidates in the vicinity of each regressed keypoint.
    candidate_ranking_mode: A string as one of ['min_distance',
      'score_distance_ratio', 'score_scaled_distance_ratio',
      'gaussian_weighted'] indicating how to select the candidate. If invalid
      value is provided, an ValueError will be raised.
    score_distance_offset: The distance offset to apply in the denominator when
      candidate_ranking_mode is 'score_distance_ratio'. The metric to maximize
      in this scenario is score / (distance + score_distance_offset). Larger
      values of score_distance_offset make the keypoint score gain more relative
      importance.
    keypoint_depth_candidates: (optional) A float tensor of shape
      [batch_size, max_candidates, num_keypoints] indicating the depths for
      keypoint candidates.
    keypoint_score_threshold: float, The heatmap score threshold for
      a keypoint to become a valid candidate.
    score_distance_multiplier: A scalar used to multiply the bounding box size
      to be the offset in the score-to-distance-ratio formula. Only applicable
      when the candidate_ranking_mode equals to 'score_scaled_distance_ratio'.
    keypoint_std_dev: A list of float represent the standard deviation of the
      Gaussian kernel used to rank the keypoint candidates. It offers the
      flexibility of using different sizes of Gaussian kernel for each keypoint
      class. Required (and only used) when the candidate_ranking_mode equals to
      'gaussian_weighted'.

  Returns:
    A tuple with:
    refined_keypoints: A float tensor of shape
      [batch_size, num_instances, num_keypoints, 2] with the final, refined
      keypoints.
    refined_scores: A float tensor of shape
      [batch_size, num_instances, num_keypoints] with scores associated with all
      instances and keypoints in `refined_keypoints`. In 'gaussian_weighted'
      mode the Gaussian weighted ranking score is reported for snapped
      keypoints.
    refined_depths: A float tensor of shape
      [batch_size, num_instances, num_keypoints] with the depth of the snapped
      candidates (zero where the regressed keypoint was kept), or None if
      `keypoint_depth_candidates` is None.

  Raises:
    ValueError: if candidate_ranking_mode is not one of ['min_distance',
      'score_distance_ratio', 'score_scaled_distance_ratio',
      'gaussian_weighted'], if `bboxes` is None for a ranking mode that needs
      the boxes, or if `keypoint_std_dev` is None in 'gaussian_weighted' mode.
  """
  # Validate the configuration before any op is built so that a bad setup
  # fails immediately with an actionable message instead of an opaque
  # TensorFlow conversion error raised deep inside the ranking helpers.
  supported_ranking_modes = (
      'min_distance', 'score_distance_ratio', 'score_scaled_distance_ratio',
      'gaussian_weighted')
  if candidate_ranking_mode not in supported_ranking_modes:
    raise ValueError(
        'Not recognized candidate_ranking_mode: %s' % candidate_ranking_mode
    )
  box_dependent_modes = ('score_scaled_distance_ratio', 'gaussian_weighted')
  if bboxes is None and candidate_ranking_mode in box_dependent_modes:
    raise ValueError(
        'bboxes must be provided when candidate_ranking_mode is %s' %
        candidate_ranking_mode)
  if candidate_ranking_mode == 'gaussian_weighted' and keypoint_std_dev is None:
    raise ValueError(
        'keypoint_std_dev must be provided when candidate_ranking_mode is '
        'gaussian_weighted')

  batch_size, num_instances, num_keypoints, _ = (
      shape_utils.combined_static_and_dynamic_shape(regressed_keypoints))
  max_candidates = keypoint_candidates.shape[1]

  # Flag the invalid (i.e. padded) keypoint candidates: slots whose index is
  # not smaller than the number of valid candidates of that keypoint type.
  # Shape [batch_size, max_candidates, num_keypoints].
  range_tiled = tf.tile(
      tf.reshape(tf.range(max_candidates), [1, max_candidates, 1]),
      [batch_size, 1, num_keypoints])
  num_candidates_tiled = tf.tile(tf.expand_dims(num_keypoint_candidates, 1),
                                 [1, max_candidates, 1])
  invalid_candidates = range_tiled >= num_candidates_tiled

  # Pairwise distances between regressed keypoints and candidate keypoints
  # (for a single keypoint type).
  # Shape [batch_size, num_instances, 1, num_keypoints, 2].
  regressed_keypoint_expanded = tf.expand_dims(regressed_keypoints,
                                               axis=2)
  # Shape [batch_size, 1, max_candidates, num_keypoints, 2].
  keypoint_candidates_expanded = tf.expand_dims(
      keypoint_candidates, axis=1)
  # Use explicit tensor shape broadcasting (since the tensor dimensions are
  # expanded to 5D) to make it tf.lite compatible.
  regressed_keypoint_expanded = tf.tile(
      regressed_keypoint_expanded, multiples=[1, 1, max_candidates, 1, 1])
  keypoint_candidates_expanded = tf.tile(
      keypoint_candidates_expanded, multiples=[1, num_instances, 1, 1, 1])
  # Replace tf.math.squared_difference by "-" operator and tf.multiply ops since
  # tf.lite convert doesn't support squared_difference with undetermined
  # dimension.
  diff = regressed_keypoint_expanded - keypoint_candidates_expanded
  sqrd_distances = tf.math.reduce_sum(tf.multiply(diff, diff), axis=-1)
  distances = tf.math.sqrt(sqrd_distances)

  # Give the invalid candidates a large constant distance (10^5) so that the
  # following reduce_min/argmin/argmax do not favor them.
  # Shape [batch_size, num_instances, max_candidates, num_keypoints].
  max_dist = 1e5
  distances = tf.where(
      tf.tile(
          tf.expand_dims(invalid_candidates, axis=1),
          multiples=[1, num_instances, 1, 1]),
      tf.ones_like(distances) * max_dist,
      distances
  )

  # Distance from each regressed keypoint to its closest candidate.
  # Shape [batch_size, num_instances, num_keypoints].
  min_distances = tf.math.reduce_min(distances, axis=2)

  # Select one candidate index per (instance, keypoint).
  weighted_scores = None
  if candidate_ranking_mode == 'min_distance':
    nearby_candidate_inds = tf.math.argmin(distances, axis=2)
  else:
    if candidate_ranking_mode == 'score_distance_ratio':
      # tiled_keypoint_scores:
      # Shape [batch_size, num_instances, max_candidates, num_keypoints].
      tiled_keypoint_scores = tf.tile(
          tf.expand_dims(keypoint_scores, axis=1),
          multiples=[1, num_instances, 1, 1],
      )
      ranking_scores = tiled_keypoint_scores / (
          distances + score_distance_offset)
    elif candidate_ranking_mode == 'score_scaled_distance_ratio':
      ranking_scores = sdr_scaled_ranking_score(
          keypoint_scores, distances, bboxes, score_distance_multiplier
      )
    else:
      # 'gaussian_weighted' (the mode was validated on entry).
      ranking_scores = gaussian_weighted_score(
          keypoint_scores, distances, keypoint_std_dev, bboxes
      )
      weighted_scores = tf.math.reduce_max(ranking_scores, axis=2)
    nearby_candidate_inds = tf.math.argmax(
        ranking_scores, axis=2, output_type=tf.int32
    )

  # Gather the coordinates and scores corresponding to the selected candidates.
  # Shape of tensors are [batch_size, num_instances, num_keypoints, 2] and
  # [batch_size, num_instances, num_keypoints], respectively.
  (nearby_candidate_coords, nearby_candidate_scores,
   nearby_candidate_depths) = (
       _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
                                     nearby_candidate_inds,
                                     keypoint_depth_candidates))

  # If the ranking mode is 'gaussian_weighted', we use the ranking scores as the
  # final keypoint confidence since their values are in between [0, 1].
  if candidate_ranking_mode == 'gaussian_weighted':
    nearby_candidate_scores = weighted_scores

  # Build a mask that is positive wherever the original regressed keypoint
  # should be used instead of the selected candidate.
  # Shape [batch_size, num_instances, num_keypoints].
  if bboxes is None:
    # Without boxes, only candidates scoring below the keypoint score
    # threshold are filtered out.
    mask = tf.cast(nearby_candidate_scores <
                   keypoint_score_threshold, tf.int32)
  else:
    bboxes_flattened = tf.reshape(bboxes, [-1, 4])

    # Scale the bounding boxes.
    # Shape [batch_size, num_instances, 4].
    boxlist = box_list.BoxList(bboxes_flattened)
    boxlist_scaled = box_list_ops.scale_height_width(
        boxlist, box_scale, box_scale)
    bboxes_scaled = boxlist_scaled.get()
    bboxes = tf.reshape(bboxes_scaled, [batch_size, num_instances, 4])

    # Get ymin, xmin, ymax, xmax bounding box coordinates, tiled per keypoint.
    # Shape [batch_size, num_instances, num_keypoints].
    bboxes_tiled = tf.tile(tf.expand_dims(bboxes, 2), [1, 1, num_keypoints, 1])
    ymin, xmin, ymax, xmax = tf.unstack(bboxes_tiled, axis=3)

    search_radius = (
        tf.math.maximum(ymax - ymin, xmax - xmin) * candidate_search_scale)
    mask = (tf.cast(nearby_candidate_coords[:, :, :, 0] < ymin, tf.int32) +
            tf.cast(nearby_candidate_coords[:, :, :, 0] > ymax, tf.int32) +
            tf.cast(nearby_candidate_coords[:, :, :, 1] < xmin, tf.int32) +
            tf.cast(nearby_candidate_coords[:, :, :, 1] > xmax, tf.int32) +
            # Filter out the chosen candidate with score lower than the
            # keypoint score threshold.
            tf.cast(nearby_candidate_scores <
                    keypoint_score_threshold, tf.int32) +
            tf.cast(min_distances > search_radius, tf.int32))
  mask = mask > 0

  # Create refined keypoints where candidate keypoints replace original
  # regressed keypoints unless they were rejected by the mask.
  # Shape [batch_size, num_instances, num_keypoints, 2].
  refined_keypoints = tf.where(
      tf.tile(tf.expand_dims(mask, -1), [1, 1, 1, 2]),
      regressed_keypoints,
      nearby_candidate_coords)

  # Update keypoints scores. In the case where we use the original regressed
  # keypoints, we use a default score of `unmatched_keypoint_score`.
  # Shape [batch_size, num_instances, num_keypoints].
  refined_scores = tf.where(
      mask,
      unmatched_keypoint_score * tf.ones_like(nearby_candidate_scores),
      nearby_candidate_scores)

  # Depths are optional; kept regressed keypoints get a depth of zero.
  refined_depths = None
  if nearby_candidate_depths is not None:
    refined_depths = tf.where(mask, tf.zeros_like(nearby_candidate_depths),
                              nearby_candidate_depths)

  return refined_keypoints, refined_scores, refined_depths
def _pad_to_full_keypoint_dim(keypoint_coords, keypoint_scores, keypoint_inds,
                              num_total_keypoints):
  """Scatters per-class keypoints into the full keypoint dimension.

  The keypoint dimension is moved to the front, the entries are written with
  tf.scatter_nd at the positions given by `keypoint_inds`, and the original
  dimension order is restored afterwards.

  Args:
    keypoint_coords: a [batch_size, num_instances, num_keypoints, 2] float32
      tensor.
    keypoint_scores: a [batch_size, num_instances, num_keypoints] float32
      tensor.
    keypoint_inds: a list of integers that indicate the keypoint indices for
      this specific keypoint class. These indices are used to scatter into
      tensors that have a `num_total_keypoints` dimension.
    num_total_keypoints: The total number of keypoints that this model predicts.

  Returns:
    A tuple with
    keypoint_coords_padded: a
      [batch_size, num_instances, num_total_keypoints,2] float32 tensor.
    keypoint_scores_padded: a [batch_size, num_instances, num_total_keypoints]
      float32 tensor.
  """
  batch_size, num_instances, _, _ = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_coords))

  # tf.scatter_nd indexes the leading dimension, so the keypoint dimension is
  # brought to the front before scattering.
  coords_keypoint_major = tf.transpose(keypoint_coords, [2, 0, 1, 3])
  scores_keypoint_major = tf.transpose(keypoint_scores, [2, 0, 1])
  # One single-element index per keypoint: shape [num_keypoints, 1].
  scatter_indices = tf.expand_dims(keypoint_inds, axis=-1)

  coords_full = tf.scatter_nd(
      indices=scatter_indices,
      updates=coords_keypoint_major,
      shape=[num_total_keypoints, batch_size, num_instances, 2])
  scores_full = tf.scatter_nd(
      indices=scatter_indices,
      updates=scores_keypoint_major,
      shape=[num_total_keypoints, batch_size, num_instances])

  # Move the (now full) keypoint dimension back to its original position.
  return (tf.transpose(coords_full, [1, 2, 0, 3]),
          tf.transpose(scores_full, [1, 2, 0]))
def _pad_to_full_instance_dim(keypoint_coords, keypoint_scores, instance_inds,
                              max_instances):
  """Scatters keypoints of selected instances into the full instance dimension.

  The instance dimension is moved to the front, the entries are written with
  tf.scatter_nd at the positions given by `instance_inds`, and the original
  dimension order is restored afterwards.

  Args:
    keypoint_coords: a [batch_size, num_instances, num_keypoints, 2] float32
      tensor.
    keypoint_scores: a [batch_size, num_instances, num_keypoints] float32
      tensor.
    instance_inds: a list of integers that indicate the instance indices for
      these keypoints. These indices are used to scatter into tensors
      that have a `max_instances` dimension.
    max_instances: The maximum number of instances detected by the model.

  Returns:
    A tuple with
    keypoint_coords_padded: a [batch_size, max_instances, num_keypoints, 2]
      float32 tensor.
    keypoint_scores_padded: a [batch_size, max_instances, num_keypoints]
      float32 tensor.
  """
  batch_size, _, num_keypoints, _ = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_coords))

  # tf.scatter_nd indexes the leading dimension, so the instance dimension is
  # brought to the front before scattering.
  coords_instance_major = tf.transpose(keypoint_coords, [1, 0, 2, 3])
  scores_instance_major = tf.transpose(keypoint_scores, [1, 0, 2])
  # One single-element index per instance: shape [num_instances, 1].
  scatter_indices = tf.expand_dims(instance_inds, axis=-1)

  coords_full = tf.scatter_nd(
      indices=scatter_indices,
      updates=coords_instance_major,
      shape=[max_instances, batch_size, num_keypoints, 2])
  scores_full = tf.scatter_nd(
      indices=scatter_indices,
      updates=scores_instance_major,
      shape=[max_instances, batch_size, num_keypoints])

  # Swap the batch and (now full) instance dimensions back.
  return (tf.transpose(coords_full, [1, 0, 2, 3]),
          tf.transpose(scores_full, [1, 0, 2]))
def _gather_candidates_at_indices(keypoint_candidates,
                                  keypoint_scores,
                                  indices,
                                  keypoint_depth_candidates=None):
  """Selects keypoint candidates (coordinates, scores, depths) by index.

  Args:
    keypoint_candidates: a float tensor of shape [batch_size, max_candidates,
      num_keypoints, 2] with candidate coordinates.
    keypoint_scores: a float tensor of shape [batch_size, max_candidates,
      num_keypoints] with keypoint scores.
    indices: an integer tensor of shape [batch_size, num_indices, num_keypoints]
      with the candidate index to pick for each (index slot, keypoint) pair.
    keypoint_depth_candidates: (optional) a float tensor of shape [batch_size,
      max_candidates, num_keypoints] with keypoint depths.

  Returns:
    A tuple with
    gathered_keypoint_candidates: a float tensor of shape [batch_size,
      num_indices, num_keypoints, 2] with gathered coordinates.
    gathered_keypoint_scores: a float tensor of shape [batch_size,
      num_indices, num_keypoints].
    gathered_keypoint_depths: a float tensor of shape [batch_size,
      num_indices, num_keypoints], or None when `keypoint_depth_candidates`
      is None.
  """
  batch_size, num_indices, num_keypoints = _get_shape(indices, 3)

  # Put the keypoint axis right after the batch axis so that a
  # (batch, keypoint, candidate) triple addresses a single candidate.
  candidates_by_keypoint = tf.transpose(keypoint_candidates, [0, 2, 1, 3])
  scores_by_keypoint = tf.transpose(keypoint_scores, [0, 2, 1])
  indices_by_keypoint = tf.transpose(indices, [0, 2, 1])

  # tf.gather with batch_dims > 0 is not supported by TF Lite, so explicit
  # index triples are assembled here for tf.gather_nd instead.
  batch_ids = _multi_range(
      batch_size,
      value_repetitions=num_keypoints * num_indices)
  keypoint_ids = _multi_range(
      num_keypoints,
      value_repetitions=num_indices,
      range_repetitions=batch_size)
  candidate_ids = tf.reshape(tf.cast(indices_by_keypoint, tf.int32), [-1])
  gather_triples = tf.stack([batch_ids, keypoint_ids, candidate_ids], axis=1)
  per_keypoint_shape = [batch_size, num_keypoints, num_indices]

  picked_coords = tf.reshape(
      tf.gather_nd(candidates_by_keypoint, gather_triples),
      [batch_size, num_keypoints, num_indices, -1])
  picked_scores = tf.reshape(
      tf.gather_nd(scores_by_keypoint, gather_triples), per_keypoint_shape)

  gathered_keypoint_candidates = tf.transpose(picked_coords, [0, 2, 1, 3])
  # The `-1` reshape above may result in a singleton last dimension, while
  # downstream code requires that dimension to be at least 2-valued, so the
  # shape is re-declared with a lower bound of 2 on the last axis.
  runtime_shape = tf.shape(gathered_keypoint_candidates)
  bounded_shape = tf.concat(
      (runtime_shape[:3], [tf.maximum(runtime_shape[3], 2)]), 0)
  gathered_keypoint_candidates = tf.reshape(gathered_keypoint_candidates,
                                            bounded_shape)
  gathered_keypoint_scores = tf.transpose(picked_scores, [0, 2, 1])

  if keypoint_depth_candidates is None:
    gathered_keypoint_depths = None
  else:
    depths_by_keypoint = tf.transpose(keypoint_depth_candidates, [0, 2, 1])
    picked_depths = tf.reshape(
        tf.gather_nd(depths_by_keypoint, gather_triples), per_keypoint_shape)
    gathered_keypoint_depths = tf.transpose(picked_depths, [0, 2, 1])

  return (gathered_keypoint_candidates, gathered_keypoint_scores,
          gathered_keypoint_depths)
def flattened_indices_from_row_col_indices(row_indices, col_indices, num_cols):
  """Returns row-major flattened indices for the given (row, column) indices.

  Args:
    row_indices: Row indices.
    col_indices: Column indices.
    num_cols: Number of columns in one row.

  Returns:
    `row_indices * num_cols + col_indices`.
  """
  row_offsets = row_indices * num_cols
  return row_offsets + col_indices
def row_col_channel_indices_from_flattened_indices(indices, num_cols,
                                                   num_channels):
  """Splits flattened indices into row, column and channel indices.

  Args:
    indices: An integer tensor of any shape holding the indices in the flattened
      space.
    num_cols: Number of columns in the image (width).
    num_channels: Number of channels in the image.

  Returns:
    row_indices: The row indices corresponding to each of the input indices.
      Same shape as indices.
    col_indices: The column indices corresponding to each of the input indices.
      Same shape as indices.
    channel_indices: The channel indices corresponding to each of the input
      indices.
  """
  # Take care when running a model in float16 precision (e.g. TF.js with
  # WebGL): large array indices may not be represented accurately, which
  # produces incorrect channel indices. See:
  # https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_integer_values
  #
  # Remainders are computed with floor-division and subtraction rather than
  # the mod operator, to stay compatible with more environments, e.g. WASM.
  pixel_indices = indices // num_channels
  row_indices = pixel_indices // num_cols
  col_indices = pixel_indices - row_indices * num_cols
  channel_indices = indices - pixel_indices * num_channels
  return row_indices, col_indices, channel_indices
def row_col_indices_from_flattened_indices(indices, num_cols):
  """Splits flattened indices into row and column indices.

  Args:
    indices: An integer tensor of any shape holding the indices in the flattened
      space.
    num_cols: Number of columns in the image (width).

  Returns:
    row_indices: The row indices corresponding to each of the input indices.
      Same shape as indices.
    col_indices: The column indices corresponding to each of the input indices.
      Same shape as indices.
  """
  # The column is obtained by subtraction instead of the mod operator so that
  # the ops stay compatible with more environments, e.g. WASM.
  row_indices = indices // num_cols
  return row_indices, indices - row_indices * num_cols
def get_valid_anchor_weights_in_flattened_image(true_image_shapes, height,
                                                width):
  """Builds per-location weights that are 1 only inside the unpadded image.

  Useful for penalizing only the valid area of an image when padding is used.
  The weights are laid out for a loss that is applied after the spatial
  dimensions have been flattened.

  Args:
    true_image_shapes: An integer tensor of shape [batch_size, 3] representing
      the true image shape (without padding) for each sample in the batch.
    height: height of the prediction from the network.
    width: width of the prediction from the network.

  Returns:
    valid_anchor_weights: a float tensor of shape [batch_size, height * width]
    with 1s in locations where the spatial coordinates fall within the height
    and width in true_image_shapes.
  """
  flat_positions = tf.reshape(tf.range(height * width), [1, -1])
  num_images = tf.shape(true_image_shapes)[0]
  # Broadcast the single row of positions to one row per image.
  per_image_positions = (
      tf.ones((num_images, 1), dtype=tf.int32) * flat_positions)
  rows, cols, _ = row_col_channel_indices_from_flattened_indices(
      per_image_positions, width, 1)

  true_heights = true_image_shapes[:, 0]
  true_widths = true_image_shapes[:, 1]
  width_limit = _to_float32(tf.expand_dims(true_widths, 1))
  height_limit = _to_float32(tf.expand_dims(true_heights, 1))

  inside_width = _to_float32(cols) < width_limit
  inside_height = _to_float32(rows) < height_limit
  return _to_float32(tf.math.logical_and(inside_width, inside_height))
def convert_strided_predictions_to_normalized_boxes(boxes, stride,
                                                    true_image_shapes):
  """Maps output-space boxes to normalized coordinates, clipped to [0, 1].

  Boxes that extend past the valid image boundary are clipped onto it.

  Args:
    boxes: A tensor of shape [batch_size, num_boxes, 4] holding the raw
      coordinates of boxes in the model's output space.
    stride: The stride in the output space.
    true_image_shapes: A tensor of shape [batch_size, 3] representing the true
      shape of the input not considering padding.

  Returns:
    boxes: A tensor of shape [batch_size, num_boxes, 4] representing the
      coordinates of the normalized boxes.
  """
  # Plain tf ops are used here instead of box_list_ops helpers so that the
  # function keeps working with a dynamic batch size.
  boxes_at_input_scale = boxes * stride
  # The first two shape entries of every image, repeated twice, give a
  # [batch_size, 1, 4] divisor that broadcasts over the boxes.
  per_image_divisor = tf.tile(
      true_image_shapes[:, tf.newaxis, :2], [1, 1, 2])
  normalized_boxes = (
      boxes_at_input_scale / tf.cast(per_image_divisor, tf.float32))
  return tf.clip_by_value(normalized_boxes, 0.0, 1.0)
def convert_strided_predictions_to_normalized_keypoints(
    keypoint_coords, keypoint_scores, stride, true_image_shapes,
    clip_out_of_frame_keypoints=False):
  """Maps output-space keypoints to normalized image coordinates.

  With clip_out_of_frame_keypoints=False, keypoints that fall outside the
  valid image boundary are normalized but left where they are. With
  clip_out_of_frame_keypoints=True, such keypoints are clipped to the closest
  image boundary and their scores are set to 0.0.

  Args:
    keypoint_coords: A tensor of shape
      [batch_size, num_instances, num_keypoints, 2] holding the raw coordinates
      of keypoints in the model's output space.
    keypoint_scores: A tensor of shape
      [batch_size, num_instances, num_keypoints] holding the keypoint scores.
    stride: The stride in the output space.
    true_image_shapes: A tensor of shape [batch_size, 3] representing the true
      shape of the input not considering padding.
    clip_out_of_frame_keypoints: A boolean indicating whether keypoints outside
      the image boundary should be clipped. If True, keypoint coords will be
      clipped to image boundary. If False, keypoints are normalized but not
      filtered based on their location.

  Returns:
    keypoint_coords_normalized: A tensor of shape
      [batch_size, num_instances, num_keypoints, 2] representing the coordinates
      of the normalized keypoints.
    keypoint_scores: A tensor of shape
      [batch_size, num_instances, num_keypoints] representing the updated
      keypoint scores.
  """
  batch_size, _, _, _ = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_coords))

  # Per-image (y, x) factors taking output-space coordinates to normalized
  # coordinates relative to the true image size.
  true_heights, true_widths, _ = tf.unstack(true_image_shapes, axis=1)
  y_factor = float(stride) / tf.cast(true_heights, tf.float32)
  x_factor = float(stride) / tf.cast(true_widths, tf.float32)
  yx_factor = tf.stack([y_factor, x_factor], axis=1)
  keypoint_coords_normalized = keypoint_coords * tf.reshape(
      yx_factor, [batch_size, 1, 1, 2])

  if not clip_out_of_frame_keypoints:
    return keypoint_coords_normalized, keypoint_scores

  # Record which keypoints lie inside the true image region before clipping,
  # since every keypoint is inside the window once it has been clipped.
  y_normalized = keypoint_coords_normalized[:, :, :, 0]
  x_normalized = keypoint_coords_normalized[:, :, :, 1]
  in_frame = tf.logical_and(
      tf.logical_and(y_normalized >= 0.0, y_normalized <= 1.0),
      tf.logical_and(x_normalized >= 0.0, x_normalized <= 1.0))

  unit_windows = tf.tile(
      tf.constant([[0.0, 0.0, 1.0, 1.0]], dtype=tf.float32),
      multiples=[batch_size, 1])

  def _clip_single_image(per_image_inputs):
    image_keypoints, image_window = per_image_inputs
    return keypoint_ops.clip_to_window(image_keypoints, image_window)

  keypoint_coords_normalized = shape_utils.static_or_dynamic_map_fn(
      _clip_single_image, [keypoint_coords_normalized, unit_windows],
      dtype=tf.float32, back_prop=False)
  keypoint_scores = tf.where(in_frame, keypoint_scores,
                             tf.zeros_like(keypoint_scores))
  return keypoint_coords_normalized, keypoint_scores
def convert_strided_predictions_to_instance_masks(
    boxes, classes, masks, true_image_shapes,
    densepose_part_heatmap=None, densepose_surface_coords=None, stride=4,
    mask_height=256, mask_width=256, score_threshold=0.5,
    densepose_class_index=-1):
  """Turns predicted full-image masks into per-detection instance masks.

  For each predicted detection box:
    * The predicted mask (and optionally the DensePose coordinates) is cropped
      and resized according to the detected box coordinates and class
      prediction, using bilinear resampling.
    * The mask is binarized with the provided score threshold.

  Args:
    boxes: A tensor of shape [batch, max_detections, 4] holding the predicted
      boxes, in normalized coordinates (relative to the true image dimensions).
    classes: An integer tensor of shape [batch, max_detections] containing the
      detected class for each box (0-indexed).
    masks: A [batch, output_height, output_width, num_classes] float32
      tensor with class probabilities.
    true_image_shapes: A tensor of shape [batch, 3] representing the true
      shape of the inputs not considering padding.
    densepose_part_heatmap: (Optional) A [batch, output_height, output_width,
      num_parts] float32 tensor with part scores (i.e. logits).
    densepose_surface_coords: (Optional) A [batch, output_height, output_width,
      2 * num_parts] float32 tensor with predicted part coordinates (in
      vu-format).
    stride: The stride in the output space.
    mask_height: The desired resized height for instance masks.
    mask_width: The desired resized width for instance masks.
    score_threshold: The threshold at which to convert predicted mask
      into foreground pixels.
    densepose_class_index: The class index (0-indexed) corresponding to the
      class which has DensePose labels (e.g. person class).

  Returns:
    A tuple of masks and surface_coords.
    instance_masks: A [batch_size, max_detections, mask_height, mask_width]
      uint8 tensor with predicted foreground mask for each
      instance. If DensePose tensors are provided, then each pixel value in the
      mask encodes the 1-indexed part.
    surface_coords: A [batch_size, max_detections, mask_height, mask_width, 2]
      float32 tensor with (v, u) coordinates. Note that v, u coordinates are
      only defined on instance masks, and the coordinates at each location of
      the foreground mask correspond to coordinates on a local part coordinate
      system (the specific part can be inferred from the `instance_masks`
      output). If DensePose feature maps are not passed to this function, this
      output will be None.

  Raises:
    ValueError: If one but not both of `densepose_part_heatmap` and
      `densepose_surface_coords` is provided.
  """
  batch_size, output_height, output_width, _ = (
      shape_utils.combined_static_and_dynamic_shape(masks))
  input_height = stride * output_height
  input_width = stride * output_width
  true_heights, true_widths, _ = tf.unstack(true_image_shapes, axis=1)

  has_part_heatmap = densepose_part_heatmap is not None
  has_surface_coords = densepose_surface_coords is not None
  if has_part_heatmap != has_surface_coords:
    raise ValueError('To use DensePose, both `densepose_part_heatmap` and '
                     '`densepose_surface_coords` must be provided')

  # Past the check above the two tensors are either both given or both absent.
  densepose_present = has_part_heatmap
  if not densepose_present:
    # Dummy DensePose tensors keep the per-image function uniform.
    densepose_part_heatmap = tf.zeros(
        (batch_size, output_height, output_width, 1), dtype=tf.float32)
    densepose_surface_coords = tf.zeros(
        (batch_size, output_height, output_width, 2), dtype=tf.float32)

  def _crop_and_threshold_single_image(per_image_elems):
    return crop_and_threshold_masks(
        per_image_elems, input_height=input_height, input_width=input_width,
        mask_height=mask_height, mask_width=mask_width,
        score_threshold=score_threshold,
        densepose_class_index=densepose_class_index)

  instance_masks, surface_coords = shape_utils.static_or_dynamic_map_fn(
      _crop_and_threshold_single_image,
      elems=[boxes, classes, masks, densepose_part_heatmap,
             densepose_surface_coords, true_heights, true_widths],
      dtype=[tf.uint8, tf.float32],
      back_prop=False)

  if not densepose_present:
    surface_coords = None
  return instance_masks, surface_coords
def crop_and_threshold_masks(elems, input_height, input_width, mask_height=256,
                             mask_width=256, score_threshold=0.5,
                             densepose_class_index=-1):
  """Crops and thresholds masks based on detection boxes.

  Operates on the tensors of a single image. For every detection, the feature
  map slice of its predicted class is cropped to the detection box, resized to
  [mask_height, mask_width] and thresholded. Detections of the DensePose class
  additionally get per-pixel part ids and (v, u) surface coordinates.

  Args:
    elems: A tuple of
      boxes - float32 tensor of shape [max_detections, 4]
      classes - int32 tensor of shape [max_detections] (0-indexed)
      masks - float32 tensor of shape [output_height, output_width, num_classes]
      part_heatmap - float32 tensor of shape [output_height, output_width,
        num_parts]
      surf_coords - float32 tensor of shape [output_height, output_width,
        2 * num_parts]
      true_height - scalar int tensor
      true_width - scalar int tensor
    input_height: Input height to network.
    input_width: Input width to network.
    mask_height: Height for resizing mask crops.
    mask_width: Width for resizing mask crops.
    score_threshold: The threshold at which to convert predicted mask
      into foreground pixels.
    densepose_class_index: scalar int tensor with the class index (0-indexed)
      for DensePose.

  Returns:
    A tuple of
    all_instances: A [max_detections, mask_height, mask_width] uint8 tensor
      with a predicted foreground mask for each instance. Background is encoded
      as 0, and foreground is encoded as a positive integer. Specific part
      indices are encoded as 1-indexed parts (for classes that have part
      information).
    surface_coords: A [max_detections, mask_height, mask_width, 2]
      float32 tensor with (v, u) coordinates for the highest scoring part at
      each location.
  """
  # NOTE(review): `all_instances` is assembled from int32 tensors below, while
  # the docstring above advertises uint8 - confirm where the conversion
  # happens (presumably in the caller's map function).
  (boxes, classes, masks, part_heatmap, surf_coords, true_height,
   true_width) = elems
  # Boxes are in normalized coordinates relative to true image shapes. Convert
  # coordinates to be normalized relative to input image shapes (since masks
  # may still have padding).
  boxlist = box_list.BoxList(boxes)
  y_scale = true_height / input_height
  x_scale = true_width / input_width
  boxlist = box_list_ops.scale(boxlist, y_scale, x_scale)
  boxes = boxlist.get()
  # Convert masks from [output_height, output_width, num_classes] to
  # [num_classes, output_height, output_width, 1].
  num_classes = tf.shape(masks)[-1]
  masks_4d = tf.transpose(masks, perm=[2, 0, 1])[:, :, :, tf.newaxis]
  # Tile part and surface coordinate masks for all classes, so that every
  # class slice carries the same DensePose channels next to its own mask.
  part_heatmap_4d = tf.tile(part_heatmap[tf.newaxis, :, :, :],
                            multiples=[num_classes, 1, 1, 1])
  surf_coords_4d = tf.tile(surf_coords[tf.newaxis, :, :, :],
                           multiples=[num_classes, 1, 1, 1])
  feature_maps_concat = tf.concat([masks_4d, part_heatmap_4d, surf_coords_4d],
                                  axis=-1)
  # The leading axis of `feature_maps_concat` is the class axis, so passing
  # `classes` as `box_indices` crops each detection from its own class slice.
  # The following tensor has shape
  # [max_detections, mask_height, mask_width, 1 + 3 * num_parts].
  cropped_masks = tf2.image.crop_and_resize(
      feature_maps_concat,
      boxes=boxes,
      box_indices=classes,
      crop_size=[mask_height, mask_width],
      method='bilinear')
  # Split the cropped masks back into instance masks, part masks, and surface
  # coordinates.
  num_parts = tf.shape(part_heatmap)[-1]
  instance_masks, part_heatmap_cropped, surface_coords_cropped = tf.split(
      cropped_masks, [1, num_parts, 2 * num_parts], axis=-1)
  # Threshold the instance masks. Resulting tensor has shape
  # [max_detections, mask_height, mask_width, 1].
  instance_masks_int = tf.cast(
      tf.math.greater_equal(instance_masks, score_threshold), dtype=tf.int32)
  # Produce a binary mask that is 1 only:
  #  - in the foreground region for an instance
  #  - in detections corresponding to the DensePose class
  det_with_parts = tf.equal(classes, densepose_class_index)
  det_with_parts = tf.cast(
      tf.reshape(det_with_parts, [-1, 1, 1, 1]), dtype=tf.int32)
  instance_masks_with_parts = tf.math.multiply(instance_masks_int,
                                               det_with_parts)
  # Similarly, produce a binary mask that holds the foreground masks only for
  # instances without parts (i.e. non-DensePose classes).
  det_without_parts = 1 - det_with_parts
  instance_masks_without_parts = tf.math.multiply(instance_masks_int,
                                                  det_without_parts)
  # Assemble a tensor that has standard instance segmentation masks for
  # non-DensePose classes (with values in [0, 1]), and part segmentation masks
  # for DensePose classes (with values in [0, 1, ..., num_parts]).
  part_mask_int_zero_indexed = tf.math.argmax(
      part_heatmap_cropped, axis=-1, output_type=tf.int32)[:, :, :, tf.newaxis]
  part_mask_int_one_indexed = part_mask_int_zero_indexed + 1
  all_instances = (instance_masks_without_parts +
                   instance_masks_with_parts * part_mask_int_one_indexed)
  # Gather the surface coordinates for the parts: expose the part axis, then
  # keep only the (v, u) pair of the highest scoring part at each location.
  surface_coords_cropped = tf.reshape(
      surface_coords_cropped, [-1, mask_height, mask_width, num_parts, 2])
  surface_coords = gather_surface_coords_for_parts(surface_coords_cropped,
                                                   part_mask_int_zero_indexed)
  # Zero the coordinates outside the foreground of DensePose-class detections.
  surface_coords = (
      surface_coords * tf.cast(instance_masks_with_parts, tf.float32))
  return [tf.squeeze(all_instances, axis=3), surface_coords]
def gather_surface_coords_for_parts(surface_coords_cropped,
                                    highest_scoring_part):
  """Picks the (v, u) coordinates of the highest scoring DensePose part.

  Args:
    surface_coords_cropped: A [max_detections, height, width, num_parts, 2]
      float32 tensor with (v, u) surface coordinates.
    highest_scoring_part: A [max_detections, height, width] integer tensor with
      the highest scoring part (0-indexed) indices for each location.

  Returns:
    A [max_detections, height, width, 2] float32 tensor with the (v, u)
    coordinates selected from the highest scoring parts.
  """
  max_detections, height, width, num_parts, _ = (
      shape_utils.combined_static_and_dynamic_shape(surface_coords_cropped))
  vu_table = tf.reshape(surface_coords_cropped, [-1, 2])
  winning_part_ids = tf.reshape(highest_scoring_part, [-1])

  # Each location owns `num_parts` consecutive rows of `vu_table`; the row of
  # the winning part is the start of that run plus the part id.
  num_locations = max_detections * height * width
  lookup_rows = num_parts * tf.range(num_locations) + winning_part_ids
  selected_vu = tf.gather(vu_table, lookup_rows, axis=0)
  return tf.reshape(selected_vu, [max_detections, height, width, 2])
def predicted_embeddings_at_object_centers(embedding_predictions,
                                           y_indices, x_indices):
  """Gathers the predicted embeddings located at the given object centers.

  Args:
    embedding_predictions: A float tensor of shape [batch_size, height, width,
      reid_embed_size] holding predicted embeddings.
    y_indices: A [batch, num_instances] int tensor holding y indices for object
      centers. These indices correspond to locations in the output feature map.
    x_indices: A [batch, num_instances] int tensor holding x indices for object
      centers. These indices correspond to locations in the output feature map.

  Returns:
    A float tensor of shape [batch_size, num_instances, reid_embed_size] where
    predicted embeddings are gathered at the provided locations.
  """
  batch_size, _, width, _ = _get_shape(embedding_predictions, 4)
  center_positions = flattened_indices_from_row_col_indices(
      y_indices, x_indices, width)
  _, num_instances = _get_shape(center_positions, 2)

  flat_embeddings = _flatten_spatial_dimensions(embedding_predictions)
  gathered = tf.gather(flat_embeddings, center_positions, batch_dims=1)
  return tf.reshape(gathered, [batch_size, num_instances, -1])
def mask_from_true_image_shape(data_shape, true_image_shapes):
  """Builds a binary mask marking the unpadded region of each image.

  Args:
    data_shape: a possibly static (4,) tensor for the shape of the feature
      map.
    true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
      the form [height, width, channels] indicating the shapes of true
      images in the resized images, as resized images can be padded with
      zeros.

  Returns:
    a [batch, data_height, data_width, 1] tensor that is 1.0 wherever the row
    index is less than the true height and the column index is less than the
    true width, and 0.0 elsewhere.
  """
  true_heights = true_image_shapes[:, tf.newaxis, 0]
  true_widths = true_image_shapes[:, tf.newaxis, 1]
  rows_valid = tf.cast(tf.range(data_shape[1]) < true_heights, tf.float32)
  cols_valid = tf.cast(tf.range(data_shape[2]) < true_widths, tf.float32)
  # The outer product of the row and column masks gives the 2-D valid region.
  valid_region = rows_valid[:, :, tf.newaxis] * cols_valid[:, tf.newaxis, :]
  return tf.expand_dims(valid_region, 3)
class ObjectDetectionParams(
    collections.namedtuple('ObjectDetectionParams', [
        'localization_loss',
        'scale_loss_weight',
        'offset_loss_weight',
        'task_loss_weight',
        'scale_head_num_filters',
        'scale_head_kernel_sizes',
        'offset_head_num_filters',
        'offset_head_kernel_sizes',
    ])):
  """Immutable bundle of the object detection task settings.

  Groups the hyper-parameters and the loss function needed by the object
  detection task. Instances cannot be modified once constructed. See `__new__`
  for the meaning of every field.
  """

  __slots__ = ()

  def __new__(cls,
              localization_loss,
              scale_loss_weight,
              offset_loss_weight,
              task_loss_weight=1.0,
              scale_head_num_filters=(256),
              scale_head_kernel_sizes=(3),
              offset_head_num_filters=(256),
              offset_head_kernel_sizes=(3)):
    """Constructor with default values for ObjectDetectionParams.

    Note that the defaults written as `(256)` and `(3)` are plain integers,
    not one-element tuples.

    Args:
      localization_loss: a object_detection.core.losses.Loss object to compute
        the loss for the center offset and height/width predictions in
        CenterNet.
      scale_loss_weight: float, The weight for localizing box size. Note that
        the scale loss is dependent on the input image size, since we penalize
        the raw height and width. This constant may need to be adjusted
        depending on the input size.
      offset_loss_weight: float, The weight for localizing center offsets.
      task_loss_weight: float, the weight of the object detection loss.
      scale_head_num_filters: filter numbers of the convolutional layers used
        by the object detection box scale prediction head.
      scale_head_kernel_sizes: kernel size of the convolutional layers used
        by the object detection box scale prediction head.
      offset_head_num_filters: filter numbers of the convolutional layers used
        by the object detection box offset prediction head.
      offset_head_kernel_sizes: kernel size of the convolutional layers used
        by the object detection box offset prediction head.

    Returns:
      An initialized ObjectDetectionParams namedtuple.
    """
    return super(ObjectDetectionParams, cls).__new__(
        cls,
        localization_loss=localization_loss,
        scale_loss_weight=scale_loss_weight,
        offset_loss_weight=offset_loss_weight,
        task_loss_weight=task_loss_weight,
        scale_head_num_filters=scale_head_num_filters,
        scale_head_kernel_sizes=scale_head_kernel_sizes,
        offset_head_num_filters=offset_head_num_filters,
        offset_head_kernel_sizes=offset_head_kernel_sizes)
class KeypointEstimationParams(
    collections.namedtuple('KeypointEstimationParams', [
        'task_name', 'class_id', 'keypoint_indices', 'classification_loss',
        'localization_loss', 'keypoint_labels', 'keypoint_std_dev',
        'keypoint_heatmap_loss_weight', 'keypoint_offset_loss_weight',
        'keypoint_regression_loss_weight', 'keypoint_candidate_score_threshold',
        'heatmap_bias_init', 'num_candidates_per_keypoint', 'task_loss_weight',
        'peak_max_pool_kernel_size', 'unmatched_keypoint_score', 'box_scale',
        'candidate_search_scale', 'candidate_ranking_mode',
        'offset_peak_radius', 'per_keypoint_offset', 'predict_depth',
        'per_keypoint_depth', 'keypoint_depth_loss_weight',
        'score_distance_offset', 'clip_out_of_frame_keypoints',
        'rescore_instances', 'heatmap_head_num_filters',
        'heatmap_head_kernel_sizes', 'offset_head_num_filters',
        'offset_head_kernel_sizes', 'regress_head_num_filters',
        'regress_head_kernel_sizes', 'score_distance_multiplier',
        'std_dev_multiplier', 'rescoring_threshold', 'gaussian_denom_ratio',
        'argmax_postprocessing'
    ])):
  """Namedtuple to host keypoint estimation related parameters.

  This is a wrapper class over the fields that are either the hyper-parameters
  or the loss functions needed for the keypoint estimation task. The class is
  immutable after constructed. Please see the __new__ function for detailed
  information for each fields.
  """

  __slots__ = ()

  def __new__(cls,
              task_name,
              class_id,
              keypoint_indices,
              classification_loss,
              localization_loss,
              keypoint_labels=None,
              keypoint_std_dev=None,
              keypoint_heatmap_loss_weight=1.0,
              keypoint_offset_loss_weight=1.0,
              keypoint_regression_loss_weight=1.0,
              keypoint_candidate_score_threshold=0.1,
              heatmap_bias_init=-2.19,
              num_candidates_per_keypoint=100,
              task_loss_weight=1.0,
              peak_max_pool_kernel_size=3,
              unmatched_keypoint_score=0.1,
              box_scale=1.2,
              candidate_search_scale=0.3,
              candidate_ranking_mode='min_distance',
              offset_peak_radius=0,
              per_keypoint_offset=False,
              predict_depth=False,
              per_keypoint_depth=False,
              keypoint_depth_loss_weight=1.0,
              score_distance_offset=1e-6,
              clip_out_of_frame_keypoints=False,
              rescore_instances=False,
              heatmap_head_num_filters=(256),
              heatmap_head_kernel_sizes=(3),
              offset_head_num_filters=(256),
              offset_head_kernel_sizes=(3),
              regress_head_num_filters=(256),
              regress_head_kernel_sizes=(3),
              score_distance_multiplier=0.1,
              std_dev_multiplier=1.0,
              rescoring_threshold=0.0,
              argmax_postprocessing=False,
              gaussian_denom_ratio=0.1):
    """Constructor with default values for KeypointEstimationParams.

    Note that the defaults written as `(256)` and `(3)` are plain integers,
    not one-element tuples.

    Args:
      task_name: string, the name of the task this namedtuple corresponds to.
        Note that it should be an unique identifier of the task.
      class_id: int, the ID of the class that contains the target keypoints to
        considered in this task. For example, if the task is human pose
        estimation, the class id should correspond to the "human" class. Note
        that the ID is 0-based, meaning that class 0 corresponds to the first
        non-background object class.
      keypoint_indices: A list of integers representing the indices of the
        keypoints to be considered in this task. This is used to retrieve the
        subset of the keypoints from gt_keypoints that should be considered in
        this task.
      classification_loss: an object_detection.core.losses.Loss object to
        compute the loss for the class predictions in CenterNet.
      localization_loss: an object_detection.core.losses.Loss object to compute
        the loss for the center offset and height/width predictions in
        CenterNet.
      keypoint_labels: A list of strings representing the label text of each
        keypoint, e.g. "nose", "left_shoulder". Note that the length of this
        list should be equal to keypoint_indices.
      keypoint_std_dev: A list of float represent the standard deviation of the
        Gaussian kernel used to generate the keypoint heatmap. It is to provide
        the flexibility of using different sizes of Gaussian kernel for each
        keypoint class.
      keypoint_heatmap_loss_weight: float, The weight for the keypoint heatmap.
      keypoint_offset_loss_weight: float, The weight for the keypoint offsets
        loss.
      keypoint_regression_loss_weight: float, The weight for keypoint regression
        loss. Note that the loss is dependent on the input image size, since we
        penalize the raw height and width. This constant may need to be adjusted
        depending on the input size.
      keypoint_candidate_score_threshold: float, The heatmap score threshold for
        a keypoint to become a valid candidate.
      heatmap_bias_init: float, the initial value of bias in the convolutional
        kernel of the class prediction head. If set to None, the bias is
        initialized with zeros.
      num_candidates_per_keypoint: The maximum number of candidates to retrieve
        for each keypoint.
      task_loss_weight: float, the weight of the keypoint estimation loss.
      peak_max_pool_kernel_size: Max pool kernel size to use to pull off peak
        score locations in a neighborhood (independently for each keypoint
        types).
      unmatched_keypoint_score: The default score to use for regressed keypoints
        that are not successfully snapped to a nearby candidate.
      box_scale: The multiplier to expand the bounding boxes (either the
        provided boxes or those which tightly cover the regressed keypoints).
      candidate_search_scale: The scale parameter that multiplies the largest
        dimension of a bounding box. The resulting distance becomes a search
        radius for candidates in the vicinity of each regressed keypoint.
      candidate_ranking_mode: One of ['min_distance', 'score_distance_ratio',
        'score_scaled_distance_ratio', 'gaussian_weighted'] indicating how to
        select the keypoint candidate.
      offset_peak_radius: The radius (in the unit of output pixel) around
        groundtruth heatmap peak to assign the offset targets. If set 0, then
        the offset target will only be assigned to the heatmap peak (same
        behavior as the original paper).
      per_keypoint_offset: A bool indicates whether to assign offsets for each
        keypoint channel separately. If set False, the output offset target has
        the shape [batch_size, out_height, out_width, 2] (same behavior as the
        original paper). If set True, the output offset target has the shape
        [batch_size, out_height, out_width, 2 * num_keypoints] (recommended when
        the offset_peak_radius is not zero).
      predict_depth: A bool indicates whether to predict the depth of each
        keypoints.
      per_keypoint_depth: A bool indicates whether the model predicts the depth
        of each keypoints in independent channels. Similar to
        per_keypoint_offset but for the keypoint depth.
      keypoint_depth_loss_weight: The weight of the keypoint depth loss.
      score_distance_offset: The distance offset to apply in the denominator
        when candidate_ranking_mode is 'score_distance_ratio'. The metric to
        maximize in this scenario is score / (distance + score_distance_offset).
        Larger values of score_distance_offset make the keypoint score gain more
        relative importance.
      clip_out_of_frame_keypoints: Whether keypoints outside the image frame
        should be clipped back to the image boundary. If True, the keypoints
        that are clipped have scores set to 0.0.
      rescore_instances: Whether to rescore instances based on a combination of
        detection score and keypoint scores.
      heatmap_head_num_filters: filter numbers of the convolutional layers used
        by the keypoint heatmap prediction head.
      heatmap_head_kernel_sizes: kernel size of the convolutional layers used
        by the keypoint heatmap prediction head.
      offset_head_num_filters: filter numbers of the convolutional layers used
        by the keypoint offset prediction head.
      offset_head_kernel_sizes: kernel size of the convolutional layers used
        by the keypoint offset prediction head.
      regress_head_num_filters: filter numbers of the convolutional layers used
        by the keypoint regression prediction head.
      regress_head_kernel_sizes: kernel size of the convolutional layers used
        by the keypoint regression prediction head.
      score_distance_multiplier: A scalar used to multiply the bounding box size
        to be used as the offset in the score-to-distance-ratio formula.
      std_dev_multiplier: A scalar used to multiply the standard deviation to
        control the Gaussian kernel which used to weight the candidates.
      rescoring_threshold: A scalar used when "rescore_instances" is set to
        True. The detection score of an instance is set to be the average over
        the scores of the keypoints which their scores higher than the
        threshold.
      argmax_postprocessing: Whether to use the keypoint postprocessing logic
        that replaces the topk op with argmax. Usually used when exporting the
        model for predicting keypoints of multiple instances in the browser.
      gaussian_denom_ratio: The ratio used to multiply the image size to
        determine the denominator of the Gaussian formula. Only applicable when
        the candidate_ranking_mode is set to be 'gaussian_weighted_const'.

    Returns:
      An initialized KeypointEstimationParams namedtuple.
    """
    # Every field is forwarded by keyword. The parameter order of this
    # constructor (argmax_postprocessing before gaussian_denom_ratio) differs
    # from the namedtuple field order (gaussian_denom_ratio before
    # argmax_postprocessing), so positional forwarding would store the two
    # values in each other's fields.
    return super(KeypointEstimationParams, cls).__new__(
        cls,
        task_name=task_name,
        class_id=class_id,
        keypoint_indices=keypoint_indices,
        classification_loss=classification_loss,
        localization_loss=localization_loss,
        keypoint_labels=keypoint_labels,
        keypoint_std_dev=keypoint_std_dev,
        keypoint_heatmap_loss_weight=keypoint_heatmap_loss_weight,
        keypoint_offset_loss_weight=keypoint_offset_loss_weight,
        keypoint_regression_loss_weight=keypoint_regression_loss_weight,
        keypoint_candidate_score_threshold=keypoint_candidate_score_threshold,
        heatmap_bias_init=heatmap_bias_init,
        num_candidates_per_keypoint=num_candidates_per_keypoint,
        task_loss_weight=task_loss_weight,
        peak_max_pool_kernel_size=peak_max_pool_kernel_size,
        unmatched_keypoint_score=unmatched_keypoint_score,
        box_scale=box_scale,
        candidate_search_scale=candidate_search_scale,
        candidate_ranking_mode=candidate_ranking_mode,
        offset_peak_radius=offset_peak_radius,
        per_keypoint_offset=per_keypoint_offset,
        predict_depth=predict_depth,
        per_keypoint_depth=per_keypoint_depth,
        keypoint_depth_loss_weight=keypoint_depth_loss_weight,
        score_distance_offset=score_distance_offset,
        clip_out_of_frame_keypoints=clip_out_of_frame_keypoints,
        rescore_instances=rescore_instances,
        heatmap_head_num_filters=heatmap_head_num_filters,
        heatmap_head_kernel_sizes=heatmap_head_kernel_sizes,
        offset_head_num_filters=offset_head_num_filters,
        offset_head_kernel_sizes=offset_head_kernel_sizes,
        regress_head_num_filters=regress_head_num_filters,
        regress_head_kernel_sizes=regress_head_kernel_sizes,
        score_distance_multiplier=score_distance_multiplier,
        std_dev_multiplier=std_dev_multiplier,
        rescoring_threshold=rescoring_threshold,
        gaussian_denom_ratio=gaussian_denom_ratio,
        argmax_postprocessing=argmax_postprocessing)
class ObjectCenterParams(
    collections.namedtuple('ObjectCenterParams', [
        'classification_loss',
        'object_center_loss_weight',
        'heatmap_bias_init',
        'min_box_overlap_iou',
        'max_box_predictions',
        'use_labeled_classes',
        'keypoint_weights_for_center',
        'center_head_num_filters',
        'center_head_kernel_sizes',
        'peak_max_pool_kernel_size',
    ])):
  """Immutable record of the object center prediction hyper-parameters."""

  __slots__ = ()

  def __new__(cls,
              classification_loss,
              object_center_loss_weight,
              heatmap_bias_init=-2.19,
              min_box_overlap_iou=0.7,
              max_box_predictions=100,
              use_labeled_classes=False,
              keypoint_weights_for_center=None,
              center_head_num_filters=(256),
              center_head_kernel_sizes=(3),
              peak_max_pool_kernel_size=3):
    """Creates an ObjectCenterParams, filling in defaults where omitted.

    Note that the defaults `(256)` and `(3)` are plain Python ints (the
    parentheses do not create a tuple).

    Args:
      classification_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the class predictions in CenterNet.
      object_center_loss_weight: float, the weight applied to the object center
        loss.
      heatmap_bias_init: float, initial bias value of the convolutional kernel
        in the object center prediction head. If None, the bias is initialized
        with zeros.
      min_box_overlap_iou: float, the minimum IOU overlap that predicted boxes
        need to have with groundtruth boxes in order not to be penalized. Used
        when computing the class specific center heatmaps.
      max_box_predictions: int, the maximum number of boxes to predict.
      use_labeled_classes: boolean, whether the loss is computed only on the
        labeled classes.
      keypoint_weights_for_center: (optional) keypoint weights used to locate
        the object center. When provided, there must be one weight per
        keypoint and the center is the weighted mean of the keypoint
        locations. When omitted, the center of the bounding box is used
        (default behavior).
      center_head_num_filters: filter numbers of the convolutional layers used
        by the object center prediction head.
      center_head_kernel_sizes: kernel size of the convolutional layers used
        by the object center prediction head.
      peak_max_pool_kernel_size: max pool kernel size used to pull off peak
        score locations in a neighborhood of the object detection heatmap.

    Returns:
      An initialized ObjectCenterParams namedtuple.
    """
    return super(ObjectCenterParams, cls).__new__(
        cls,
        classification_loss=classification_loss,
        object_center_loss_weight=object_center_loss_weight,
        heatmap_bias_init=heatmap_bias_init,
        min_box_overlap_iou=min_box_overlap_iou,
        max_box_predictions=max_box_predictions,
        use_labeled_classes=use_labeled_classes,
        keypoint_weights_for_center=keypoint_weights_for_center,
        center_head_num_filters=center_head_num_filters,
        center_head_kernel_sizes=center_head_kernel_sizes,
        peak_max_pool_kernel_size=peak_max_pool_kernel_size)
class MaskParams(
    collections.namedtuple('MaskParams', [
        'classification_loss',
        'task_loss_weight',
        'mask_height',
        'mask_width',
        'score_threshold',
        'heatmap_bias_init',
        'mask_head_num_filters',
        'mask_head_kernel_sizes',
    ])):
  """Immutable record of the mask prediction hyper-parameters."""

  __slots__ = ()

  def __new__(cls,
              classification_loss,
              task_loss_weight=1.0,
              mask_height=256,
              mask_width=256,
              score_threshold=0.5,
              heatmap_bias_init=-2.19,
              mask_head_num_filters=(256),
              mask_head_kernel_sizes=(3)):
    """Creates a MaskParams, filling in defaults where omitted.

    Note that the defaults `(256)` and `(3)` are plain Python ints (the
    parentheses do not create a tuple).

    Args:
      classification_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the semantic segmentation predictions in
        CenterNet.
      task_loss_weight: float, the loss weight for the segmentation task.
      mask_height: the height of the resized instance segmentation mask.
      mask_width: the width of the resized instance segmentation mask.
      score_threshold: the threshold at which predicted mask probabilities
        (after passing through sigmoid) are converted into foreground pixels.
      heatmap_bias_init: float, initial bias value of the convolutional kernel
        in the semantic segmentation prediction head. If None, the bias is
        initialized with zeros.
      mask_head_num_filters: filter numbers of the convolutional layers used
        by the mask prediction head.
      mask_head_kernel_sizes: kernel size of the convolutional layers used by
        the mask prediction head.

    Returns:
      An initialized MaskParams namedtuple.
    """
    return super(MaskParams, cls).__new__(
        cls,
        classification_loss=classification_loss,
        task_loss_weight=task_loss_weight,
        mask_height=mask_height,
        mask_width=mask_width,
        score_threshold=score_threshold,
        heatmap_bias_init=heatmap_bias_init,
        mask_head_num_filters=mask_head_num_filters,
        mask_head_kernel_sizes=mask_head_kernel_sizes)
class DensePoseParams(
    collections.namedtuple('DensePoseParams', [
        'class_id',
        'classification_loss',
        'localization_loss',
        'part_loss_weight',
        'coordinate_loss_weight',
        'num_parts',
        'task_loss_weight',
        'upsample_to_input_res',
        'upsample_method',
        'heatmap_bias_init',
    ])):
  """Immutable record of the DensePose prediction hyper-parameters."""

  __slots__ = ()

  def __new__(cls,
              class_id,
              classification_loss,
              localization_loss,
              part_loss_weight=1.0,
              coordinate_loss_weight=1.0,
              num_parts=24,
              task_loss_weight=1.0,
              upsample_to_input_res=True,
              upsample_method='bilinear',
              heatmap_bias_init=-2.19):
    """Creates a DensePoseParams, filling in defaults where omitted.

    Args:
      class_id: the ID of the class that contains the DensePose groundtruth.
        This should typically correspond to the "person" class. The ID is
        0-based, i.e. class 0 is the first non-background object class.
      classification_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the body part predictions in CenterNet.
      localization_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the surface coordinate regression in CenterNet.
      part_loss_weight: the loss weight applied to part prediction.
      coordinate_loss_weight: the loss weight applied to surface coordinate
        prediction.
      num_parts: the number of DensePose parts to predict.
      task_loss_weight: float, the loss weight for the DensePose task.
      upsample_to_input_res: whether to upsample the DensePose feature maps to
        the input resolution before applying the loss. The prediction outputs
        are still at the standard CenterNet output stride.
      upsample_method: method for upsampling DensePose feature maps, either
        'bilinear' or 'nearest'. Has no effect when `upsample_to_input_res`
        is False.
      heatmap_bias_init: float, initial bias value of the convolutional kernel
        in the part prediction head. If None, the bias is initialized with
        zeros.

    Returns:
      An initialized DensePoseParams namedtuple.
    """
    return super(DensePoseParams, cls).__new__(
        cls,
        class_id=class_id,
        classification_loss=classification_loss,
        localization_loss=localization_loss,
        part_loss_weight=part_loss_weight,
        coordinate_loss_weight=coordinate_loss_weight,
        num_parts=num_parts,
        task_loss_weight=task_loss_weight,
        upsample_to_input_res=upsample_to_input_res,
        upsample_method=upsample_method,
        heatmap_bias_init=heatmap_bias_init)
class TrackParams(
    collections.namedtuple('TrackParams', [
        'num_track_ids',
        'reid_embed_size',
        'num_fc_layers',
        'classification_loss',
        'task_loss_weight',
    ])):
  """Immutable record of the tracking prediction hyper-parameters."""

  __slots__ = ()

  def __new__(cls,
              num_track_ids,
              reid_embed_size,
              num_fc_layers,
              classification_loss,
              task_loss_weight=1.0):
    """Creates a TrackParams, filling in defaults where omitted.

    Args:
      num_track_ids: int. The maximum track ID in the dataset. Used for the
        ReID embedding classification task.
      reid_embed_size: int. The embedding size for the ReID task.
      num_fc_layers: int. The number of (fully-connected, batch-norm, relu)
        layers for the track ID classification head.
      classification_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the ReID embedding in CenterNet.
      task_loss_weight: float, the loss weight for the tracking task.

    Returns:
      An initialized TrackParams namedtuple.
    """
    return super(TrackParams, cls).__new__(
        cls,
        num_track_ids=num_track_ids,
        reid_embed_size=reid_embed_size,
        num_fc_layers=num_fc_layers,
        classification_loss=classification_loss,
        task_loss_weight=task_loss_weight)
class TemporalOffsetParams(
    collections.namedtuple('TemporalOffsetParams', [
        'localization_loss',
        'task_loss_weight',
    ])):
  """Immutable record of the temporal offset hyper-parameters."""

  __slots__ = ()

  def __new__(cls,
              localization_loss,
              task_loss_weight=1.0):
    """Creates a TemporalOffsetParams, filling in defaults where omitted.

    Args:
      localization_loss: an object_detection.core.losses.Loss object used to
        compute the loss for the temporal offset in CenterNet.
      task_loss_weight: float, the loss weight for the temporal offset task.

    Returns:
      An initialized TemporalOffsetParams namedtuple.
    """
    return super(TemporalOffsetParams, cls).__new__(
        cls,
        localization_loss=localization_loss,
        task_loss_weight=task_loss_weight)
# The following constants are used to generate the keys of the
# (prediction, loss, target assigner, ...) dictionaries used in the
# CenterNetMetaArch class.

# Object detection: DETECTION_TASK keys the box target assigner; the others
# key the object center, box size and box offset prediction heads.
DETECTION_TASK = 'detection_task'
OBJECT_CENTER = 'object_center'
BOX_SCALE = 'box/scale'
BOX_OFFSET = 'box/offset'
# Keypoint estimation head names. These are combined with a task name via
# get_keypoint_name() to form the actual dictionary keys.
KEYPOINT_REGRESSION = 'keypoint/regression'
KEYPOINT_HEATMAP = 'keypoint/heatmap'
KEYPOINT_OFFSET = 'keypoint/offset'
KEYPOINT_DEPTH = 'keypoint/depth'
# Segmentation: target assigner key and prediction head key.
SEGMENTATION_TASK = 'segmentation_task'
SEGMENTATION_HEATMAP = 'segmentation/heatmap'
# DensePose: target assigner key and the two prediction head keys.
DENSEPOSE_TASK = 'densepose_task'
DENSEPOSE_HEATMAP = 'densepose/heatmap'
DENSEPOSE_REGRESSION = 'densepose/regression'
# NOTE(review): presumably prepended to loss dictionary keys; its use is not
# visible in this part of the file -- confirm against the loss() method.
LOSS_KEY_PREFIX = 'Loss'
# Tracking (ReID embedding): target assigner key and prediction head key.
TRACK_TASK = 'track_task'
TRACK_REID = 'track/reid'
# Temporal offset: target assigner key and prediction head key.
TEMPORALOFFSET_TASK = 'temporal_offset_task'
TEMPORAL_OFFSET = 'track/offset'
def get_keypoint_name(task_name, head_name):
  """Builds the dictionary key for one head of a keypoint estimation task.

  Args:
    task_name: name of the keypoint estimation task.
    head_name: name of the head, e.g. KEYPOINT_HEATMAP.

  Returns:
    The string '<task_name>/<head_name>'.
  """
  return '{}/{}'.format(task_name, head_name)
def get_num_instances_from_weights(groundtruth_weights_list):
  """Counts the instances/boxes present in a batch from their weights.

  Args:
    groundtruth_weights_list: A list of float tensors with shape
      [max_num_instances] representing whether there is an actual instance in
      the image (with non-zero value) or is padded to match the
      max_num_instances (with value 0.0). The list represents the batch
      dimension.

  Returns:
    A scalar integer tensor indicating how many instances/boxes are in the
    images in the batch. This function is usually used to normalize the loss,
    so the returned value is clamped to be at least 1.
  """
  # One non-zero count per image in the batch.
  per_image_counts = [
      tf.math.count_nonzero(image_weights)
      for image_weights in groundtruth_weights_list
  ]
  batch_total = tf.reduce_sum(per_image_counts)
  # Clamp so that callers can safely divide by the result.
  return tf.maximum(batch_total, 1)
class CenterNetMetaArch(model.DetectionModel):
"""The CenterNet meta architecture [1].
[1]: https://arxiv.org/abs/1904.07850
"""
def __init__(self,
             is_training,
             add_summaries,
             num_classes,
             feature_extractor,
             image_resizer_fn,
             object_center_params,
             object_detection_params=None,
             keypoint_params_dict=None,
             mask_params=None,
             densepose_params=None,
             track_params=None,
             temporal_offset_params=None,
             use_depthwise=False,
             compute_heatmap_sparse=False,
             non_max_suppression_fn=None,
             unit_height_conv=False,
             output_prediction_dict=False):
  """Initializes a CenterNet model.

  Stores the configuration, builds the prediction heads and the target
  assigners, and finally calls the base class constructor.

  Args:
    is_training: Set to True if this model is being built for training.
    add_summaries: Whether to add tf summaries in the model. (Not referenced
      inside this constructor.)
    num_classes: int, The number of classes that the model should predict.
    feature_extractor: A CenterNetFeatureExtractor to use to extract features
      from an image.
    image_resizer_fn: a callable for image resizing. This callable always
      takes a rank-3 image tensor (corresponding to a single image) and
      returns a rank-3 image tensor, possibly with new spatial dimensions and
      a 1-D tensor of shape [3] indicating shape of true image within the
      resized image tensor as the resized image tensor could be padded. See
      builders/image_resizer_builder.py.
    object_center_params: An ObjectCenterParams namedtuple. This object holds
      the hyper-parameters for object center prediction. This is required by
      either object detection or keypoint estimation tasks.
    object_detection_params: An ObjectDetectionParams namedtuple. This object
      holds the hyper-parameters necessary for object detection. Please see
      the class definition for more details.
    keypoint_params_dict: A dictionary that maps from task name to the
      corresponding KeypointEstimationParams namedtuple. This object holds the
      hyper-parameters necessary for multiple keypoint estimations. Please
      see the class definition for more details.
    mask_params: A MaskParams namedtuple. This object
      holds the hyper-parameters for segmentation. Please see the class
      definition for more details.
    densepose_params: A DensePoseParams namedtuple. This object holds the
      hyper-parameters for DensePose prediction. Please see the class
      definition for more details. Note that if this is provided, it is
      expected that `mask_params` is also provided.
    track_params: A TrackParams namedtuple. This object
      holds the hyper-parameters for tracking. Please see the class
      definition for more details.
    temporal_offset_params: A TemporalOffsetParams namedtuple. This object
      holds the hyper-parameters for offset prediction based tracking.
    use_depthwise: If true, all task heads will be constructed using
      separable_conv. Otherwise, standard convolutions will be used.
    compute_heatmap_sparse: bool, whether or not to use the sparse version of
      the Op that computes the center heatmaps. The sparse version scales
      better with number of channels in the heatmap, but in some cases is
      known to cause an OOM error. See b/170989061.
    non_max_suppression_fn: Optional Non Max Suppression function to apply.
    unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
      kernels with height=1.
    output_prediction_dict: If true, combines all items from the dictionary
      returned by predict() function into the output of postprocess().

  Raises:
    AssertionError: if neither `object_detection_params` nor
      `keypoint_params_dict` is provided.
    ValueError: if `densepose_params` is provided without `mask_params`.
  """
  # At least one of object detection / keypoint estimation must be configured.
  assert object_detection_params or keypoint_params_dict
  self._is_training = is_training
  self._num_classes = num_classes
  self._feature_extractor = feature_extractor
  # The Objects as Points paper attaches loss functions to multiple
  # (`num_feature_outputs`) feature maps in the backbone. E.g.
  # for the hourglass backbone, `num_feature_outputs` is 2.
  self._num_feature_outputs = feature_extractor.num_feature_outputs
  self._stride = self._feature_extractor.out_stride
  self._image_resizer_fn = image_resizer_fn
  # Shortened attribute names for convenience and better formatting.
  self._center_params = object_center_params
  self._od_params = object_detection_params
  self._kp_params_dict = keypoint_params_dict
  self._mask_params = mask_params
  if densepose_params is not None and mask_params is None:
    raise ValueError('To run DensePose prediction, `mask_params` must also '
                     'be supplied.')
  self._densepose_params = densepose_params
  self._track_params = track_params
  self._temporal_offset_params = temporal_offset_params
  self._use_depthwise = use_depthwise
  self._compute_heatmap_sparse = compute_heatmap_sparse
  self._output_prediction_dict = output_prediction_dict

  # subclasses may not implement the unit_height_conv arg, so only provide it
  # as a kwarg if it is True.
  kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
  # Construct the prediction head nets. This reads the parameter attributes
  # assigned above, so it must come after them.
  self._prediction_head_dict = self._construct_prediction_heads(
      num_classes,
      self._num_feature_outputs,
      class_prediction_bias_init=self._center_params.heatmap_bias_init,
      **kwargs)
  # Initialize the target assigners.
  self._target_assigner_dict = self._initialize_target_assigners(
      stride=self._stride,
      min_box_overlap_iou=self._center_params.min_box_overlap_iou)

  # Will be used in VOD single_frame_meta_arch for tensor reshape. Starts
  # empty; the `batched_prediction_tensor_names` property raises until it
  # has been filled.
  self._batched_prediction_tensor_names = []
  self._non_max_suppression_fn = non_max_suppression_fn

  # The base class constructor is deliberately invoked last, after all the
  # attributes above have been set.
  super(CenterNetMetaArch, self).__init__(num_classes)
def set_trainability_by_layer_traversal(self, trainable):
  """Applies `trainable` to every child layer, one layer at a time.

  The commonly-seen `model.trainable = False` method does not traverse
  the children layer. For example, if the parent is not trainable, we won't
  be able to set individual layers as trainable/non-trainable differentially.

  Args:
    trainable: (bool) The value assigned to the `trainable` attribute of each
      layer yielded by `_flatten_layers`, excluding the parent itself.
  """
  child_layers = self._flatten_layers(include_self=False)
  for child in child_layers:
    child.trainable = trainable
@property
def prediction_head_dict(self):
  """The dictionary of prediction heads built in the constructor."""
  return self._prediction_head_dict
@property
def batched_prediction_tensor_names(self):
  """Names of the batched prediction tensors.

  Returns:
    The stored `_batched_prediction_tensor_names` value.

  Raises:
    RuntimeError: if the stored value is empty, i.e. it has not been
      populated yet.
  """
  tensor_names = self._batched_prediction_tensor_names
  if tensor_names:
    return tensor_names
  raise RuntimeError('Must call predict() method to get batched prediction '
                     'tensor names.')
def _make_prediction_net_list(self, num_feature_outputs, num_out_channels,
                              kernel_sizes=(3), num_filters=(256),
                              bias_fill=None, name=None,
                              unit_height_conv=False):
  """Builds one prediction net per backbone feature output.

  Args:
    num_feature_outputs: int, how many prediction nets to create.
    num_out_channels: forwarded to make_prediction_net as the number of
      output channels.
    kernel_sizes: forwarded to make_prediction_net.
    num_filters: forwarded to make_prediction_net.
    bias_fill: forwarded to make_prediction_net.
    name: optional base name. When truthy, the i-th net is named
      '<name>_<i>'; otherwise `name` is forwarded unchanged.
    unit_height_conv: forwarded to make_prediction_net.

  Returns:
    A list with `num_feature_outputs` objects returned by
    make_prediction_net.
  """

  def _indexed_name(index):
    # Only suffix the index when a base name was actually supplied.
    return '{}_{}'.format(name, index) if name else name

  return [
      make_prediction_net(
          num_out_channels,
          kernel_sizes=kernel_sizes,
          num_filters=num_filters,
          bias_fill=bias_fill,
          use_depthwise=self._use_depthwise,
          name=_indexed_name(index),
          unit_height_conv=unit_height_conv)
      for index in range(num_feature_outputs)
  ]
def _construct_prediction_heads(self, num_classes, num_feature_outputs,
                                class_prediction_bias_init,
                                unit_height_conv=False):
  """Builds the prediction heads for every configured task.

  Args:
    num_classes: An integer indicating how many classes in total to predict.
    num_feature_outputs: An integer indicating how many feature outputs to use
      for calculating the loss. The Objects as Points paper attaches loss
      functions to multiple (`num_feature_outputs`) feature maps in the
      backbone. E.g. for the hourglass backbone, `num_feature_outputs` is 2.
    class_prediction_bias_init: float, the initial value of bias in the
      convolutional kernel of the class prediction head. If set to None, the
      bias is initialized with zeros.
    unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

  Returns:
    A dictionary mapping head keys to the lists returned by
    `_make_prediction_net_list`. When tracking is configured, this method also
    creates and sets `self.track_reid_classification_net`.
  """
  # Every head shares the feature-output count and the unit-height flag.
  build_heads = functools.partial(
      self._make_prediction_net_list,
      num_feature_outputs,
      unit_height_conv=unit_height_conv)

  center_params = self._center_params
  heads = {
      OBJECT_CENTER:
          build_heads(
              num_classes,
              kernel_sizes=center_params.center_head_kernel_sizes,
              num_filters=center_params.center_head_num_filters,
              bias_fill=class_prediction_bias_init,
              name='center'),
  }

  od_params = self._od_params
  if od_params is not None:
    heads[BOX_SCALE] = build_heads(
        NUM_SIZE_CHANNELS,
        kernel_sizes=od_params.scale_head_kernel_sizes,
        num_filters=od_params.scale_head_num_filters,
        name='box_scale')
    heads[BOX_OFFSET] = build_heads(
        NUM_OFFSET_CHANNELS,
        kernel_sizes=od_params.offset_head_kernel_sizes,
        num_filters=od_params.offset_head_num_filters,
        name='box_offset')

  if self._kp_params_dict is not None:
    for task_name, kp_params in self._kp_params_dict.items():
      num_keypoints = len(kp_params.keypoint_indices)

      heatmap_key = get_keypoint_name(task_name, KEYPOINT_HEATMAP)
      heads[heatmap_key] = build_heads(
          num_keypoints,
          kernel_sizes=kp_params.heatmap_head_kernel_sizes,
          num_filters=kp_params.heatmap_head_num_filters,
          bias_fill=kp_params.heatmap_bias_init,
          name='kpt_heatmap')

      regression_key = get_keypoint_name(task_name, KEYPOINT_REGRESSION)
      heads[regression_key] = build_heads(
          NUM_OFFSET_CHANNELS * num_keypoints,
          kernel_sizes=kp_params.regress_head_kernel_sizes,
          num_filters=kp_params.regress_head_num_filters,
          name='kpt_regress')

      # The offset head either predicts one offset per keypoint type or a
      # single offset shared by all of them; only the channel count differs.
      if kp_params.per_keypoint_offset:
        offset_channels = NUM_OFFSET_CHANNELS * num_keypoints
      else:
        offset_channels = NUM_OFFSET_CHANNELS
      offset_key = get_keypoint_name(task_name, KEYPOINT_OFFSET)
      heads[offset_key] = build_heads(
          offset_channels,
          kernel_sizes=kp_params.offset_head_kernel_sizes,
          num_filters=kp_params.offset_head_num_filters,
          name='kpt_offset')

      if kp_params.predict_depth:
        depth_channels = num_keypoints if kp_params.per_keypoint_depth else 1
        depth_key = get_keypoint_name(task_name, KEYPOINT_DEPTH)
        # The depth head uses the default kernel sizes / filter numbers.
        heads[depth_key] = build_heads(depth_channels, name='kpt_depth')

  mask_params = self._mask_params
  if mask_params is not None:
    heads[SEGMENTATION_HEATMAP] = build_heads(
        num_classes,
        kernel_sizes=mask_params.mask_head_kernel_sizes,
        num_filters=mask_params.mask_head_num_filters,
        bias_fill=mask_params.heatmap_bias_init,
        name='seg_heatmap')

  densepose_params = self._densepose_params
  if densepose_params is not None:
    heads[DENSEPOSE_HEATMAP] = build_heads(
        densepose_params.num_parts,
        bias_fill=densepose_params.heatmap_bias_init,
        name='dense_pose_heatmap')
    heads[DENSEPOSE_REGRESSION] = build_heads(
        2 * densepose_params.num_parts, name='dense_pose_regress')

  track_params = self._track_params
  if track_params is not None:
    heads[TRACK_REID] = build_heads(
        track_params.reid_embed_size, name='track_reid')

    # Creates a classification network to train object embeddings by learning
    # a projection from embedding space to object track ID space.
    self.track_reid_classification_net = tf.keras.Sequential()
    add_layer = self.track_reid_classification_net.add
    # (num_fc_layers - 1) Dense/BatchNorm/ReLU groups, then the final Dense
    # that projects to the track ID space.
    for _ in range(track_params.num_fc_layers - 1):
      add_layer(tf.keras.layers.Dense(track_params.reid_embed_size))
      add_layer(tf.keras.layers.BatchNormalization())
      add_layer(tf.keras.layers.ReLU())
    add_layer(tf.keras.layers.Dense(track_params.num_track_ids))

  if self._temporal_offset_params is not None:
    heads[TEMPORAL_OFFSET] = build_heads(
        NUM_OFFSET_CHANNELS, name='temporal_offset')

  return heads
def _initialize_target_assigners(self, stride, min_box_overlap_iou):
  """Initializes the target assigners and puts them in a dictionary.

  Also sets `self._center_from_keypoints`, which records whether the object
  center targets are derived from keypoints (True) or from boxes (False).

  Args:
    stride: An integer indicating the stride of the image.
    min_box_overlap_iou: float, the minimum IOU overlap that predicted boxes
      need have with groundtruth boxes to not be penalized. This is used for
      computing the class specific center heatmaps.

  Returns:
    A dictionary of initialized target assigners for each task.

  Raises:
    AssertionError: if `keypoint_weights_for_center` is set but there is not
      exactly one keypoint estimation task, an object detection task is also
      configured, or the number of weights differs from the number of
      keypoints.
  """
  target_assigners = {}
  keypoint_weights_for_center = (
      self._center_params.keypoint_weights_for_center)
  if not keypoint_weights_for_center:
    target_assigners[OBJECT_CENTER] = (
        cn_assigner.CenterNetCenterHeatmapTargetAssigner(
            stride, min_box_overlap_iou, self._compute_heatmap_sparse))
    self._center_from_keypoints = False
  else:
    # Determining the object center location by keypoint location is only
    # supported when there is exactly one keypoint prediction task and no
    # object detection task is specified.
    kp_params_dict = self._kp_params_dict
    # Checking for None explicitly turns what used to be an opaque
    # `TypeError` from `len(None)` into a descriptive assertion failure.
    assert kp_params_dict is not None and len(kp_params_dict) == 1, (
        '`keypoint_weights_for_center` requires exactly one keypoint '
        'estimation task.')
    assert self._od_params is None, (
        '`keypoint_weights_for_center` cannot be combined with an object '
        'detection task.')
    kp_params = next(iter(kp_params_dict.values()))
    # The number of keypoint_weights_for_center needs to be the same as the
    # number of keypoints.
    assert len(keypoint_weights_for_center) == len(
        kp_params.keypoint_indices), (
            'The number of `keypoint_weights_for_center` must match the '
            'number of keypoint indices.')
    target_assigners[OBJECT_CENTER] = (
        cn_assigner.CenterNetCenterHeatmapTargetAssigner(
            stride,
            min_box_overlap_iou,
            self._compute_heatmap_sparse,
            keypoint_class_id=kp_params.class_id,
            keypoint_indices=kp_params.keypoint_indices,
            keypoint_weights_for_center=keypoint_weights_for_center))
    self._center_from_keypoints = True
  if self._od_params is not None:
    target_assigners[DETECTION_TASK] = (
        cn_assigner.CenterNetBoxTargetAssigner(stride))
  if self._kp_params_dict is not None:
    # One keypoint target assigner per task, keyed by the task name.
    for task_name, kp_params in self._kp_params_dict.items():
      target_assigners[task_name] = (
          cn_assigner.CenterNetKeypointTargetAssigner(
              stride=stride,
              class_id=kp_params.class_id,
              keypoint_indices=kp_params.keypoint_indices,
              keypoint_std_dev=kp_params.keypoint_std_dev,
              peak_radius=kp_params.offset_peak_radius,
              per_keypoint_offset=kp_params.per_keypoint_offset,
              compute_heatmap_sparse=self._compute_heatmap_sparse,
              per_keypoint_depth=kp_params.per_keypoint_depth))
  if self._mask_params is not None:
    target_assigners[SEGMENTATION_TASK] = (
        cn_assigner.CenterNetMaskTargetAssigner(stride, boxes_scale=1.05))
  if self._densepose_params is not None:
    # When the DensePose feature maps are upsampled to the input resolution,
    # the assigner is created with a stride of 1.
    dp_stride = 1 if self._densepose_params.upsample_to_input_res else stride
    target_assigners[DENSEPOSE_TASK] = (
        cn_assigner.CenterNetDensePoseTargetAssigner(dp_stride))
  if self._track_params is not None:
    target_assigners[TRACK_TASK] = (
        cn_assigner.CenterNetTrackTargetAssigner(
            stride, self._track_params.num_track_ids))
  if self._temporal_offset_params is not None:
    target_assigners[TEMPORALOFFSET_TASK] = (
        cn_assigner.CenterNetTemporalOffsetTargetAssigner(stride))

  return target_assigners
def _compute_object_center_loss(self, input_height, input_width,
                                object_center_predictions, per_pixel_weights,
                                maximum_normalized_coordinate=1.1):
  """Computes the object center heatmap loss, normalized per instance.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    object_center_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, num_classes] representing the object center
      feature maps.
    per_pixel_weights: A float tensor of shape [batch_size,
      out_height * out_width, 1] with 1s in locations where the spatial
      coordinates fall within the height and width in true_image_shapes.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1. This is used to check bounds during
      converting normalized coordinates to absolute coordinates.

  Returns:
    A float scalar tensor representing the object center loss per instance.
  """
  class_targets = self.groundtruth_lists(fields.BoxListFields.classes)
  instance_weights = self.groundtruth_lists(fields.BoxListFields.weights)

  if self._center_params.use_labeled_classes:
    labeled_classes = tf.stack(
        self.groundtruth_lists(
            fields.InputDataFields.groundtruth_labeled_classes),
        axis=0)
    labeled_shape = tf.shape(labeled_classes)
    # Insert a middle axis of size 1 so the labeled-class mask can be
    # multiplied into the per-pixel weights.
    labeled_classes = tf.reshape(
        labeled_classes, [labeled_shape[0], 1, labeled_shape[-1]])
    per_pixel_weights = per_pixel_weights * labeled_classes

  # Convert the groundtruth to heatmap targets, either from keypoints or from
  # boxes depending on how the center assigner was configured.
  center_assigner = self._target_assigner_dict[OBJECT_CENTER]
  if self._center_from_keypoints:
    keypoint_targets = self.groundtruth_lists(fields.BoxListFields.keypoints)
    heatmap_targets = center_assigner.assign_center_targets_from_keypoints(
        height=input_height,
        width=input_width,
        gt_classes_list=class_targets,
        gt_keypoints_list=keypoint_targets,
        gt_weights_list=instance_weights)
  else:
    box_targets = self.groundtruth_lists(fields.BoxListFields.boxes)
    heatmap_targets = center_assigner.assign_center_targets_from_boxes(
        height=input_height,
        width=input_width,
        gt_boxes_list=box_targets,
        gt_classes_list=class_targets,
        gt_weights_list=instance_weights,
        maximum_normalized_coordinate=maximum_normalized_coordinate)

  targets_flat = _flatten_spatial_dimensions(heatmap_targets)
  num_boxes = _to_float32(get_num_instances_from_weights(instance_weights))

  loss_fn = self._center_params.classification_loss
  summed_loss = 0.0
  # Accumulate the loss over every feature output head.
  for head_prediction in object_center_predictions:
    summed_loss += loss_fn(
        _flatten_spatial_dimensions(head_prediction),
        targets_flat,
        weights=per_pixel_weights)

  # Average over the heads and normalize by the number of instances.
  num_heads = float(len(object_center_predictions))
  return tf.reduce_sum(summed_loss) / (num_heads * num_boxes)
def _compute_object_detection_losses(self, input_height, input_width,
                                     prediction_dict, per_pixel_weights,
                                     maximum_normalized_coordinate=1.1):
  """Computes the weighted box scale and box offset losses.

  Thin wrapper around `_compute_box_scale_and_offset_loss` that multiplies
  each returned loss by its weight from the object detection parameters.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: A dictionary holding predicted tensors output by
      "predict" function. See "predict" function for more detailed
      description.
    per_pixel_weights: A float tensor of shape [batch_size,
      out_height * out_width, 1] with 1s in locations where the spatial
      coordinates fall within the height and width in true_image_shapes.
      (Not referenced by this method.)
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1. This is used to check bounds during
      converting normalized coordinates to absolute coordinates.

  Returns:
    A dictionary of scalar float tensors representing the weighted losses for
    object detection task:
       BOX_SCALE: the weighted scale (height/width) loss.
       BOX_OFFSET: the weighted object offset loss.
  """
  scale_predictions = prediction_dict[BOX_SCALE]
  offset_predictions = prediction_dict[BOX_OFFSET]
  scale_loss, offset_loss = self._compute_box_scale_and_offset_loss(
      scale_predictions=scale_predictions,
      offset_predictions=offset_predictions,
      input_height=input_height,
      input_width=input_width,
      maximum_normalized_coordinate=maximum_normalized_coordinate)

  od_params = self._od_params
  return {
      BOX_SCALE: od_params.scale_loss_weight * scale_loss,
      BOX_OFFSET: od_params.offset_loss_weight * offset_loss,
  }
def _compute_box_scale_and_offset_loss(self, input_height, input_width,
                                       scale_predictions, offset_predictions,
                                       maximum_normalized_coordinate=1.1):
  """Computes the scale and offset losses of the object detection task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    scale_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 2] representing the prediction heads of the model
      for object scale (i.e height and width).
    offset_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 2] representing the prediction heads of the model
      for object offset.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1. This is used to check bounds during
      converting normalized coordinates to absolute coordinates.

  Returns:
    A tuple of two losses:
      scale_loss: A float scalar tensor representing the object height/width
        loss normalized by total number of boxes.
      offset_loss: A float scalar tensor representing the object offset loss
        normalized by total number of boxes
  """
  # TODO(vighneshb) Explore a size invariant version of scale loss.
  boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
  weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
  num_boxes = _to_float32(get_num_instances_from_weights(weights_list))
  num_predictions = float(len(scale_predictions))

  box_assigner = self._target_assigner_dict[DETECTION_TASK]
  assigned = box_assigner.assign_size_and_offset_targets(
      height=input_height,
      width=input_width,
      gt_boxes_list=boxes_list,
      gt_weights_list=weights_list,
      maximum_normalized_coordinate=maximum_normalized_coordinate)
  indices, size_targets, offset_targets, target_weights = assigned
  target_weights = tf.expand_dims(target_weights, -1)

  loss_fn = self._od_params.localization_loss

  def _loss_at_targets(head_output, targets):
    # Pick the predictions at the assigned locations, then score them.
    gathered = cn_assigner.get_batch_predictions_from_indices(
        head_output, indices)
    return loss_fn(gathered, targets, weights=target_weights)

  scale_total = 0
  offset_total = 0
  # Accumulate both losses over every pair of feature output heads.
  for scale_head, offset_head in zip(scale_predictions, offset_predictions):
    scale_total += _loss_at_targets(scale_head, size_targets)
    offset_total += _loss_at_targets(offset_head, offset_targets)

  # Average over the heads and normalize by the number of boxes.
  scale_loss = tf.reduce_sum(scale_total) / (num_predictions * num_boxes)
  offset_loss = tf.reduce_sum(offset_total) / (num_predictions * num_boxes)
  return scale_loss, offset_loss
def _compute_keypoint_estimation_losses(self, task_name, input_height,
                                        input_width, prediction_dict,
                                        per_pixel_weights):
  """Computes the weighted losses of one keypoint estimation task.

  Args:
    task_name: A string naming the keypoint task. It selects the task
      parameters from `self._kp_params_dict` and is used to build the
      prediction/loss keys.
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: The dictionary of predictions. It is indexed with the
      task's keypoint heatmap, offset and regression keys, and with the depth
      key when the task parameters have `predict_depth` set.
    per_pixel_weights: Per-pixel weights, forwarded unchanged to the keypoint
      heatmap loss.

  Returns:
    A dictionary mapping the task's heatmap, offset and regression keys (and
    the depth key when `predict_depth` is set) to the corresponding loss
    multiplied by its per-head loss weight from the task parameters.
  """
  params = self._kp_params_dict[task_name]
  heatmap_key, offset_key, regression_key, depth_key = [
      get_keypoint_name(task_name, head_name)
      for head_name in (KEYPOINT_HEATMAP, KEYPOINT_OFFSET,
                        KEYPOINT_REGRESSION, KEYPOINT_DEPTH)
  ]
  # Keyword arguments accepted by every per-head loss helper below.
  shared_kwargs = dict(
      input_height=input_height,
      input_width=input_width,
      task_name=task_name)

  heatmap_loss = self._compute_kp_heatmap_loss(
      heatmap_predictions=prediction_dict[heatmap_key],
      classification_loss_fn=params.classification_loss,
      per_pixel_weights=per_pixel_weights,
      **shared_kwargs)
  offset_loss = self._compute_kp_offset_loss(
      offset_predictions=prediction_dict[offset_key],
      localization_loss_fn=params.localization_loss,
      **shared_kwargs)
  regression_loss = self._compute_kp_regression_loss(
      regression_predictions=prediction_dict[regression_key],
      localization_loss_fn=params.localization_loss,
      **shared_kwargs)

  weighted_losses = {
      heatmap_key: params.keypoint_heatmap_loss_weight * heatmap_loss,
      offset_key: params.keypoint_offset_loss_weight * offset_loss,
      regression_key: params.keypoint_regression_loss_weight * regression_loss,
  }
  if not params.predict_depth:
    return weighted_losses

  # The depth head is optional; its loss is only added when it is predicted.
  depth_loss = self._compute_kp_depth_loss(
      depth_predictions=prediction_dict[depth_key],
      localization_loss_fn=params.localization_loss,
      **shared_kwargs)
  weighted_losses[depth_key] = params.keypoint_depth_loss_weight * depth_loss
  return weighted_losses
def _compute_kp_heatmap_loss(self, input_height, input_width, task_name,
                             heatmap_predictions, classification_loss_fn,
                             per_pixel_weights):
  """Computes the heatmap loss of the keypoint estimation task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    task_name: A string representing the name of the keypoint task.
    heatmap_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, num_keypoints] representing the prediction heads
      of the model for keypoint heatmap.
    classification_loss_fn: An object_detection.core.losses.Loss object to
      compute the loss for the class predictions in CenterNet.
    per_pixel_weights: A float tensor of shape [batch_size,
      out_height * out_width, 1] with 1s in locations where the spatial
      coordinates fall within the height and width in true_image_shapes.

  Returns:
    loss: A float scalar tensor representing the object keypoint heatmap loss
      normalized by number of instances.
  """
  keypoints_gt = self.groundtruth_lists(fields.BoxListFields.keypoints)
  classes_gt = self.groundtruth_lists(fields.BoxListFields.classes)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)
  boxes_gt = self.groundtruth_lists(fields.BoxListFields.boxes)

  target_assigner = self._target_assigner_dict[task_name]
  heatmap_targets, instances_per_kp_type, valid_mask = (
      target_assigner.assign_keypoint_heatmap_targets(
          height=input_height,
          width=input_width,
          gt_keypoints_list=keypoints_gt,
          gt_weights_list=weights_gt,
          gt_classes_list=classes_gt,
          gt_boxes_list=boxes_gt))
  flat_valid_mask = _flatten_spatial_dimensions(valid_mask)
  flat_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)

  # The total keypoint count (summed over keypoint types) normalizes the loss.
  # It is floored at 1 so that a batch without any keypoint does not divide by
  # zero.
  total_keypoints = tf.cast(
      tf.reduce_sum(instances_per_kp_type), dtype=tf.float32)
  normalizer = tf.maximum(total_keypoints, 1.0)

  accumulated_loss = 0.0
  # Every feature output head contributes one term to the sum.
  for head_output in heatmap_predictions:
    flat_prediction = _flatten_spatial_dimensions(head_output)
    # The loss function receives all-ones weights; the real weights are
    # multiplied in afterwards to keep full control over the weighting.
    per_pixel_loss = classification_loss_fn(
        flat_prediction,
        flat_heatmap_targets,
        weights=tf.ones_like(per_pixel_weights))
    accumulated_loss += per_pixel_loss * per_pixel_weights * flat_valid_mask
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(heatmap_predictions))
  return summed_loss / (num_heads * normalizer)
def _compute_kp_offset_loss(self, input_height, input_width, task_name,
                            offset_predictions, localization_loss_fn):
  """Computes the offset loss of the keypoint estimation task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    task_name: A string representing the name of the keypoint task.
    offset_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 2] representing the prediction heads of the model
      for keypoint offset. Heads with more than 2 channels are reshaped so
      that the last dimension has size 2 before gathering.
    localization_loss_fn: An object_detection.core.losses.Loss object to
      compute the loss for the keypoint offset predictions in CenterNet.

  Returns:
    loss: A float scalar tensor representing the keypoint offset loss
      normalized by number of total keypoints.
  """
  keypoints_gt = self.groundtruth_lists(fields.BoxListFields.keypoints)
  classes_gt = self.groundtruth_lists(fields.BoxListFields.classes)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)

  target_assigner = self._target_assigner_dict[task_name]
  indices, offset_targets, target_weights = (
      target_assigner.assign_keypoints_offset_targets(
          height=input_height,
          width=input_width,
          gt_keypoints_list=keypoints_gt,
          gt_weights_list=weights_gt,
          gt_classes_list=classes_gt))

  accumulated_loss = 0.0
  for head_output in offset_predictions:
    batch_size, out_height, out_width, channels = _get_shape(head_output, 4)
    if channels > 2:
      # Regroup the channel axis into groups of two values.
      head_output = tf.reshape(
          head_output, shape=[batch_size, out_height, out_width, -1, 2])
    gathered = cn_assigner.get_batch_predictions_from_indices(
        head_output, indices)
    # The dimensions passed are not as per the loss doc string, but the loss
    # still computes the correct value. All-ones weights are used here; the
    # real weights are applied afterwards to keep full control over them.
    unit_weights = tf.expand_dims(tf.ones_like(target_weights), -1)
    per_coordinate_loss = localization_loss_fn(
        gathered, offset_targets, weights=unit_weights)
    accumulated_loss += target_weights * tf.reduce_sum(
        per_coordinate_loss, axis=1)
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(offset_predictions))
  # The weight sum is floored at 1 to avoid dividing by zero.
  return summed_loss / (
      num_heads * tf.maximum(tf.reduce_sum(target_weights), 1.0))
def _compute_kp_regression_loss(self, input_height, input_width, task_name,
                                regression_predictions, localization_loss_fn):
  """Computes the keypoint regression loss of the keypoint estimation task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    task_name: A string representing the name of the keypoint task.
    regression_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 2 * num_keypoints] representing the prediction
      heads of the model for keypoint regression offset.
    localization_loss_fn: An object_detection.core.losses.Loss object to
      compute the loss for the keypoint regression offset predictions in
      CenterNet.

  Returns:
    loss: A float scalar tensor representing the keypoint regression offset
      loss normalized by number of total keypoints.
  """
  boxes_gt = self.groundtruth_lists(fields.BoxListFields.boxes)
  keypoints_gt = self.groundtruth_lists(fields.BoxListFields.keypoints)
  classes_gt = self.groundtruth_lists(fields.BoxListFields.classes)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)

  # Build the keypoint regression offset targets.
  target_assigner = self._target_assigner_dict[task_name]
  indices, regression_targets, target_weights = (
      target_assigner.assign_joint_regression_targets(
          height=input_height,
          width=input_width,
          gt_keypoints_list=keypoints_gt,
          gt_classes_list=classes_gt,
          gt_weights_list=weights_gt,
          gt_boxes_list=boxes_gt))

  accumulated_loss = 0.0
  for head_output in regression_predictions:
    batch_size, out_height, out_width, _ = _get_shape(head_output, 4)
    # Regroup the channel axis into groups of two values before gathering.
    paired_output = tf.reshape(
        head_output, shape=[batch_size, out_height, out_width, -1, 2])
    gathered = cn_assigner.get_batch_predictions_from_indices(
        paired_output, indices)
    # All-ones weights go into the loss function; the real weights are applied
    # afterwards to keep full control over them.
    unit_weights = tf.expand_dims(tf.ones_like(target_weights), -1)
    per_coordinate_loss = localization_loss_fn(
        gathered, regression_targets, weights=unit_weights)
    accumulated_loss += target_weights * tf.reduce_sum(
        per_coordinate_loss, axis=1)
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(regression_predictions))
  # The weight sum is floored at 1 to avoid dividing by zero.
  return summed_loss / (
      num_heads * tf.maximum(tf.reduce_sum(target_weights), 1.0))
def _compute_kp_depth_loss(self, input_height, input_width, task_name,
                           depth_predictions, localization_loss_fn):
  """Computes the loss of the keypoint depth estimation.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    task_name: A string representing the name of the keypoint task.
    depth_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 1 (or num_keypoints)] representing the prediction
      heads of the model for keypoint depth.
    localization_loss_fn: An object_detection.core.losses.Loss object to
      compute the loss for the keypoint depth predictions in CenterNet.

  Returns:
    loss: A float scalar tensor representing the keypoint depth loss
      normalized by number of total keypoints.
  """
  params = self._kp_params_dict[task_name]
  keypoints_gt = self.groundtruth_lists(fields.BoxListFields.keypoints)
  classes_gt = self.groundtruth_lists(fields.BoxListFields.classes)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)
  depths_gt = self.groundtruth_lists(fields.BoxListFields.keypoint_depths)
  depth_weights_gt = self.groundtruth_lists(
      fields.BoxListFields.keypoint_depth_weights)

  target_assigner = self._target_assigner_dict[task_name]
  indices, depth_targets, target_weights = (
      target_assigner.assign_keypoints_depth_targets(
          height=input_height,
          width=input_width,
          gt_keypoints_list=keypoints_gt,
          gt_weights_list=weights_gt,
          gt_classes_list=classes_gt,
          gt_keypoint_depths_list=depths_gt,
          gt_keypoint_depth_weights_list=depth_weights_gt))

  # Keypoint depth loss.
  accumulated_loss = 0.0
  for head_output in depth_predictions:
    if params.per_keypoint_depth:
      # Per-keypoint depth heads get a trailing unit dimension added.
      head_output = tf.expand_dims(head_output, axis=-1)
    gathered_depths = cn_assigner.get_batch_predictions_from_indices(
        head_output, indices)
    # The dimensions passed are not as per the loss doc string, but the loss
    # still computes the correct value. All-ones weights are used here; the
    # real weights are applied afterwards to keep full control over them.
    unit_weights = tf.expand_dims(tf.ones_like(target_weights), -1)
    per_target_loss = localization_loss_fn(
        gathered_depths, depth_targets, weights=unit_weights)
    accumulated_loss += target_weights * tf.squeeze(per_target_loss, axis=1)
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(depth_predictions))
  # The weight sum is floored at 1 to avoid dividing by zero.
  return summed_loss / (
      num_heads * tf.maximum(tf.reduce_sum(target_weights), 1.0))
def _compute_segmentation_losses(self, prediction_dict, per_pixel_weights):
  """Computes all the losses associated with segmentation.

  Args:
    prediction_dict: The dictionary returned from the predict() method. Only
      its SEGMENTATION_HEATMAP entry is read here.
    per_pixel_weights: A float tensor of shape [batch_size,
      out_height * out_width, 1] with 1s in locations where the spatial
      coordinates fall within the height and width in true_image_shapes.

  Returns:
    A dictionary with a single entry that maps SEGMENTATION_HEATMAP to the
    mask loss.
  """
  heatmap_predictions = prediction_dict[SEGMENTATION_HEATMAP]
  return {
      SEGMENTATION_HEATMAP:
          self._compute_mask_loss(heatmap_predictions, per_pixel_weights)
  }
def _compute_mask_loss(self, segmentation_predictions,
                       per_pixel_weights):
  """Computes the mask loss.

  Args:
    segmentation_predictions: A list of float32 tensors of shape [batch_size,
      out_height, out_width, num_classes].
    per_pixel_weights: A float tensor of shape [batch_size,
      out_height * out_width, 1] with 1s in locations where the spatial
      coordinates fall within the height and width in true_image_shapes.

  Returns:
    A float scalar tensor representing the mask loss: the loss summed over
    all prediction heads, divided by the number of heads times the total
    per-pixel weight (floored at 1).
  """
  boxes_gt = self.groundtruth_lists(fields.BoxListFields.boxes)
  masks_gt = self.groundtruth_lists(fields.BoxListFields.masks)
  # Mask weights are optional groundtruth; None is passed on when absent.
  if self.groundtruth_has_field(fields.BoxListFields.mask_weights):
    mask_weights_gt = self.groundtruth_lists(
        fields.BoxListFields.mask_weights)
  else:
    mask_weights_gt = None
  classes_gt = self.groundtruth_lists(fields.BoxListFields.classes)

  # Convert the groundtruth to targets.
  target_assigner = self._target_assigner_dict[SEGMENTATION_TASK]
  mask_targets, mask_target_weights = (
      target_assigner.assign_segmentation_targets(
          gt_masks_list=masks_gt,
          gt_classes_list=classes_gt,
          gt_boxes_list=boxes_gt,
          gt_mask_weights_list=mask_weights_gt))
  flat_targets = _flatten_spatial_dimensions(mask_targets)
  flat_target_weights = _flatten_spatial_dimensions(
      mask_target_weights[:, :, :, tf.newaxis])
  # Combine the valid-pixel weights with the weights from the assigner.
  per_pixel_weights *= flat_target_weights

  accumulated_loss = 0.0
  loss_fn = self._mask_params.classification_loss
  weighted_pixel_count = tf.math.maximum(
      tf.reduce_sum(per_pixel_weights), 1)

  # Every feature output head contributes one term to the sum.
  for head_output in segmentation_predictions:
    flat_prediction = _flatten_spatial_dimensions(head_output)
    accumulated_loss += loss_fn(
        flat_prediction, flat_targets, weights=per_pixel_weights)
  # TODO(ronnyvotel): Consider other ways to normalize loss.
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(segmentation_predictions))
  return summed_loss / (num_heads * weighted_pixel_count)
def _compute_densepose_losses(self, input_height, input_width,
                              prediction_dict):
  """Computes the weighted DensePose losses.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: A dictionary holding predicted tensors output by the
      "predict" function. Its DENSEPOSE_HEATMAP and DENSEPOSE_REGRESSION
      entries are read here.

  Returns:
    A dictionary of scalar float tensors representing the weighted losses for
    the DensePose task:
       DENSEPOSE_HEATMAP: the part segmentation loss scaled by
         `part_loss_weight`.
       DENSEPOSE_REGRESSION: the part surface coordinate loss scaled by
         `coordinate_loss_weight`.
  """
  part_loss, coordinate_loss = (
      self._compute_densepose_part_and_coordinate_losses(
          input_height=input_height,
          input_width=input_width,
          part_predictions=prediction_dict[DENSEPOSE_HEATMAP],
          surface_coord_predictions=prediction_dict[DENSEPOSE_REGRESSION]))
  return {
      DENSEPOSE_HEATMAP:
          self._densepose_params.part_loss_weight * part_loss,
      DENSEPOSE_REGRESSION:
          self._densepose_params.coordinate_loss_weight * coordinate_loss,
  }
def _compute_densepose_part_and_coordinate_losses(
    self, input_height, input_width, part_predictions,
    surface_coord_predictions):
  """Computes the individual losses for the DensePose task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    part_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, num_parts].
    surface_coord_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, 2 * num_parts].

  Returns:
    A tuple with two scalar loss tensors: part_prediction_loss and
    surface_coord_loss. Both are normalized by the number of prediction heads
    times the number of non-zero target weights (floored at 1).
  """
  num_points_gt = self.groundtruth_lists(
      fields.BoxListFields.densepose_num_points)
  part_ids_gt = self.groundtruth_lists(
      fields.BoxListFields.densepose_part_ids)
  surface_coords_gt = self.groundtruth_lists(
      fields.BoxListFields.densepose_surface_coords)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)

  target_assigner = self._target_assigner_dict[DENSEPOSE_TASK]
  indices, part_id_targets, surface_coord_targets, target_weights = (
      target_assigner.assign_part_and_coordinate_targets(
          height=input_height,
          width=input_width,
          gt_dp_num_points_list=num_points_gt,
          gt_dp_part_ids_list=part_ids_gt,
          gt_dp_surface_coords_list=surface_coords_gt,
          gt_weights_list=weights_gt))

  part_loss_total = 0
  coord_loss_total = 0
  part_loss_fn = self._densepose_params.classification_loss
  coord_loss_fn = self._densepose_params.localization_loss
  num_heads = float(len(part_predictions))
  # Count the targets with non-zero weight; floor at 1 to avoid dividing by
  # zero.
  valid_point_count = tf.math.count_nonzero(target_weights)
  valid_point_count = tf.cast(
      tf.math.maximum(valid_point_count, 1), tf.float32)

  for part_output, coord_output in zip(part_predictions,
                                       surface_coord_predictions):
    # Potentially upsample the feature maps, so that better quality (i.e.
    # higher res) groundtruth can be applied.
    if self._densepose_params.upsample_to_input_res:
      part_output = tf.keras.layers.UpSampling2D(
          self._stride,
          interpolation=self._densepose_params.upsample_method)(part_output)
      coord_output = tf.keras.layers.UpSampling2D(
          self._stride,
          interpolation=self._densepose_params.upsample_method)(coord_output)

    # Part prediction loss: only the first three index columns are used to
    # gather the part predictions.
    gathered_parts = cn_assigner.get_batch_predictions_from_indices(
        part_output, indices[:, 0:3])
    part_loss_total += part_loss_fn(
        gathered_parts[:, tf.newaxis, :],
        part_id_targets[:, tf.newaxis, :],
        weights=target_weights[:, tf.newaxis, tf.newaxis])

    # Surface coordinate loss: regroup the channel axis into groups of two
    # values, then gather with the full indices.
    batch_size, out_height, out_width, _ = _get_shape(coord_output, 4)
    paired_coords = tf.reshape(
        coord_output, [batch_size, out_height, out_width, -1, 2])
    gathered_coords = cn_assigner.get_batch_predictions_from_indices(
        paired_coords, indices)
    coord_loss_total += coord_loss_fn(
        gathered_coords,
        surface_coord_targets,
        weights=target_weights[:, tf.newaxis])

  part_prediction_loss = tf.reduce_sum(part_loss_total) / (
      num_heads * valid_point_count)
  surface_coord_loss = tf.reduce_sum(coord_loss_total) / (
      num_heads * valid_point_count)
  return part_prediction_loss, surface_coord_loss
def _compute_track_losses(self, input_height, input_width, prediction_dict):
  """Computes all the losses associated with tracking.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: The dictionary returned from the predict() method. Only
      its TRACK_REID entry is read here.

  Returns:
    A dictionary with a single entry that maps TRACK_REID to the track
    embedding loss.
  """
  reid_predictions = prediction_dict[TRACK_REID]
  return {
      TRACK_REID:
          self._compute_track_embedding_loss(
              input_height=input_height,
              input_width=input_width,
              object_reid_predictions=reid_predictions)
  }
def _compute_track_embedding_loss(self, input_height, input_width,
                                  object_reid_predictions):
  """Computes the object ReID loss.

  The embedding is trained as a classification task where the target is the
  ID of each track among all tracks in the whole dataset.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    object_reid_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, reid_embed_size] representing the object
      embedding feature maps.

  Returns:
    A float scalar tensor representing the object ReID loss per instance.
  """
  track_ids_gt = self.groundtruth_lists(fields.BoxListFields.track_ids)
  boxes_gt = self.groundtruth_lists(fields.BoxListFields.boxes)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)
  instance_count = _to_float32(get_num_instances_from_weights(weights_gt))

  # Convert the groundtruth to targets.
  target_assigner = self._target_assigner_dict[TRACK_TASK]
  indices, target_weights, track_targets = (
      target_assigner.assign_track_targets(
          height=input_height,
          width=input_width,
          gt_track_ids_list=track_ids_gt,
          gt_boxes_list=boxes_gt,
          gt_weights_list=weights_gt))
  target_weights = tf.expand_dims(target_weights, -1)

  accumulated_loss = 0.0
  reid_loss_fn = self._track_params.classification_loss
  # Every feature output head contributes one term to the sum.
  for head_output in object_reid_predictions:
    gathered_embeddings = cn_assigner.get_batch_predictions_from_indices(
        head_output, indices)
    # The gathered embeddings are passed through the ReID classification
    # network before the classification loss is applied.
    reid_logits = self.track_reid_classification_net(gathered_embeddings)
    accumulated_loss += reid_loss_fn(
        reid_logits, track_targets, weights=target_weights)
  summed_loss = tf.reduce_sum(accumulated_loss)
  num_heads = float(len(object_reid_predictions))
  return summed_loss / (num_heads * instance_count)
def _compute_temporal_offset_loss(self, input_height,
                                  input_width, prediction_dict):
  """Computes the temporal offset loss for tracking.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: The dictionary returned from the predict() method. Only
      its TEMPORAL_OFFSET entry is read here.

  Returns:
    A dictionary with a single entry that maps TEMPORAL_OFFSET to the offset
    loss, normalized by the number of prediction heads times the number of
    instances.
  """
  boxes_gt = self.groundtruth_lists(fields.BoxListFields.boxes)
  offsets_gt = self.groundtruth_lists(fields.BoxListFields.temporal_offsets)
  match_flags_gt = self.groundtruth_lists(
      fields.BoxListFields.track_match_flags)
  weights_gt = self.groundtruth_lists(fields.BoxListFields.weights)
  instance_count = tf.cast(
      get_num_instances_from_weights(weights_gt), tf.float32)
  head_outputs = prediction_dict[TEMPORAL_OFFSET]
  num_heads = float(len(head_outputs))

  target_assigner = self._target_assigner_dict[TEMPORALOFFSET_TASK]
  indices, offset_targets, target_weights = (
      target_assigner.assign_temporal_offset_targets(
          height=input_height,
          width=input_width,
          gt_boxes_list=boxes_gt,
          gt_offsets_list=offsets_gt,
          gt_match_list=match_flags_gt,
          gt_weights_list=weights_gt))
  target_weights = tf.expand_dims(target_weights, -1)

  loss_fn = self._temporal_offset_params.localization_loss
  accumulated_loss = 0
  for head_output in head_outputs:
    gathered_offsets = cn_assigner.get_batch_predictions_from_indices(
        head_output, indices)
    # A unit axis is inserted at position 1 of both predictions and targets
    # before they are handed to the loss function.
    accumulated_loss += loss_fn(gathered_offsets[:, None],
                                offset_targets[:, None],
                                weights=target_weights)
  normalized_loss = tf.reduce_sum(accumulated_loss) / (
      num_heads * instance_count)
  return {TEMPORAL_OFFSET: normalized_loss}
def _should_clip_keypoints(self):
  """Returns a boolean indicating whether keypoint clipping should occur.

  Clipping is enabled only when every configured keypoint task agrees on the
  field `clip_out_of_frame_keypoints`; with a single keypoint task this is
  simply that task's setting. If the tasks disagree, or if no keypoint task
  is configured (`self._kp_params_dict` is None or empty), False is returned.

  Returns:
    The `clip_out_of_frame_keypoints` value shared by all keypoint tasks, or
    False when the tasks disagree or there are none.
  """
  if not self._kp_params_dict:
    # Without any keypoint task there is nothing to clip. Looking at the first
    # task unconditionally used to raise StopIteration for an empty dict (and
    # AttributeError for None).
    return False
  clip_flags = [
      kp_params.clip_out_of_frame_keypoints
      for kp_params in self._kp_params_dict.values()
  ]
  should_clip = clip_flags[0]
  # Any disagreement between tasks is ambiguous, so fall back to no clipping.
  for flag in clip_flags[1:]:
    if flag != should_clip:
      return False
  return should_clip
def _rescore_instances(self, classes, scores, keypoint_scores):
  """Rescores instances based on detection and keypoint scores.

  For every keypoint task that has `rescore_instances` set, detections of the
  task's class get a new score: the detection score times the mean score of
  the task's keypoints that exceed `rescoring_threshold` (the keypoint count
  is floored at 1). Detections of other classes keep their score.

  Note that keypoint scores at or below a task's threshold are zeroed and the
  zeroed tensor is carried over to the tasks processed afterwards.

  Args:
    classes: A [batch, max_detections] int32 tensor with detection classes.
    scores: A [batch, max_detections] float32 tensor with detection scores.
    keypoint_scores: A [batch, max_detections, total_num_keypoints] float32
      tensor with keypoint scores.

  Returns:
    A [batch, max_detections] float32 tensor with possibly altered detection
    scores.
  """
  batch, max_detections, total_num_keypoints = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_scores))
  # Repeat the class of every detection along the keypoint axis.
  tiled_classes = tf.tile(classes[:, :, tf.newaxis],
                          multiples=[1, 1, total_num_keypoints])
  # TODO(yuhuic): Investigate whether this function will create subgraphs in
  # tflite that will cause the model to run slower at inference.
  for task_params in self._kp_params_dict.values():
    if not task_params.rescore_instances:
      continue
    task_class_id = task_params.class_id
    task_keypoint_indices = task_params.keypoint_indices

    # Vector over all keypoints that is 1 at this task's keypoint indices.
    task_keypoint_mask = tf.reduce_sum(
        tf.one_hot(task_keypoint_indices, depth=total_num_keypoints), axis=0)
    tiled_keypoint_mask = tf.tile(
        task_keypoint_mask[tf.newaxis, tf.newaxis, :],
        multiples=[batch, max_detections, 1])
    # True where the detection has the task's class AND the keypoint belongs
    # to the task.
    task_mask = tf.math.logical_and(
        tiled_classes == task_class_id,
        tiled_keypoint_mask == 1.0)
    task_mask_float = tf.cast(task_mask, dtype=tf.float32)

    # Keypoints not above the threshold are treated as invisible and their
    # scores are zeroed.
    is_visible = tf.math.greater(
        keypoint_scores, task_params.rescoring_threshold)
    keypoint_scores = tf.where(
        is_visible, keypoint_scores, tf.zeros_like(keypoint_scores))
    visible_count = tf.reduce_sum(
        task_mask_float * tf.cast(is_visible, tf.float32), axis=-1)
    # Floor at 1 to avoid dividing by zero.
    visible_count = tf.math.maximum(visible_count, 1.0)

    rescored = (1./visible_count) * (
        tf.reduce_sum(task_mask_float *
                      scores[:, :, tf.newaxis] *
                      keypoint_scores, axis=-1))
    # Only detections of the task's class take the new score.
    scores = tf.where(classes == task_class_id,
                      rescored,
                      scores)
  return scores
def preprocess(self, inputs):
  """Resizes the inputs and applies the feature extractor's preprocessing.

  Args:
    inputs: The input images, passed to
      `shape_utils.resize_images_and_return_shapes` together with
      `self._image_resizer_fn`.

  Returns:
    A tuple of
      preprocessed_inputs: the resized inputs after
        `self._feature_extractor.preprocess`.
      true_image_shapes: the shapes returned by the resizing step.
  """
  resized_inputs, true_image_shapes = (
      shape_utils.resize_images_and_return_shapes(
          inputs, self._image_resizer_fn))
  preprocessed_inputs = self._feature_extractor.preprocess(resized_inputs)
  return (preprocessed_inputs, true_image_shapes)
def predict(self, preprocessed_inputs, _):
  """Predicts CenterNet prediction tensors given an input batch.

  Feature extractors are free to produce predictions from multiple feature
  maps and therefore we return a dictionary mapping strings to lists.
  E.g. the hourglass backbone produces two feature maps.

  As a side effect, the keys of the returned dictionary are stored in
  `self._batched_prediction_tensor_names`.

  Args:
    preprocessed_inputs: a [batch, height, width, channels] float32 tensor
      representing a batch of images.
    _: Unused second positional argument.

  Returns:
    prediction_dict: a dictionary holding predicted tensors with
      'preprocessed_inputs' - The input image after being resized and
        preprocessed by the feature extractor.
      'extracted_features' - The output of the feature extractor.
      'object_center' - A list of size num_feature_outputs containing
        float tensors of size [batch_size, output_height, output_width,
        num_classes] representing the predicted object center heatmap logits.
      'box/scale' - [optional] A list of size num_feature_outputs holding
        float tensors of size [batch_size, output_height, output_width, 2]
        representing the predicted box height and width at each output
        location. This field exists only when object detection task is
        specified.
      'box/offset' - [optional] A list of size num_feature_outputs holding
        float tensors of size [batch_size, output_height, output_width, 2]
        representing the predicted y and x offsets at each output location.
      '$TASK_NAME/keypoint_heatmap' - [optional] A list of size
        num_feature_outputs holding float tensors of size [batch_size,
        output_height, output_width, num_keypoints] representing the predicted
        keypoint heatmap logits.
      '$TASK_NAME/keypoint_offset' - [optional] A list of size
        num_feature_outputs holding float tensors of size [batch_size,
        output_height, output_width, 2] representing the predicted keypoint
        offsets at each output location.
      '$TASK_NAME/keypoint_regression' - [optional] A list of size
        num_feature_outputs holding float tensors of size [batch_size,
        output_height, output_width, 2 * num_keypoints] representing the
        predicted keypoint regression at each output location.
      'segmentation/heatmap' - [optional] A list of size num_feature_outputs
        holding float tensors of size [batch_size, output_height,
        output_width, num_classes] representing the mask logits.
      'densepose/heatmap' - [optional] A list of size num_feature_outputs
        holding float tensors of size [batch_size, output_height,
        output_width, num_parts] representing the mask logits for each part.
      'densepose/regression' - [optional] A list of size num_feature_outputs
        holding float tensors of size [batch_size, output_height,
        output_width, 2 * num_parts] representing the DensePose surface
        coordinate predictions.
      Note the $TASK_NAME is provided by the KeypointEstimation namedtuple
      used to differentiate between different keypoint tasks.
  """
  extracted_features = self._feature_extractor(preprocessed_inputs)
  # For each head name, the i-th head is applied to the i-th feature map.
  prediction_dict = {
      head_name: [
          head(feature_map)
          for feature_map, head in zip(extracted_features, heads)
      ]
      for head_name, heads in self._prediction_head_dict.items()
  }
  prediction_dict['extracted_features'] = extracted_features
  prediction_dict['preprocessed_inputs'] = preprocessed_inputs

  self._batched_prediction_tensor_names = prediction_dict.keys()
  return prediction_dict
def loss(
    self, prediction_dict, true_image_shapes, scope=None,
    maximum_normalized_coordinate=1.1):
  """Computes scalar loss tensors with respect to provided groundtruth.

  This function implements the various CenterNet losses. The object center
  loss is always computed; the loss of every other task is only computed when
  that task's parameters are configured (not None), and is multiplied by the
  task's `task_loss_weight`.

  Args:
    prediction_dict: a dictionary holding predicted tensors returned by
      "predict" function.
    true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
      the form [height, width, channels] indicating the shapes of true images
      in the resized images, as resized images can be padded with zeros.
    scope: Optional scope name. Not referenced by this implementation.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1. This is used to check bounds during
      converting normalized coordinates to absolute coordinates.

  Returns:
    A dictionary mapping the keys [
      'Loss/object_center',
      'Loss/box/scale',  (optional)
      'Loss/box/offset',  (optional)
      'Loss/$TASK_NAME/keypoint/heatmap',  (optional)
      'Loss/$TASK_NAME/keypoint/offset',  (optional)
      'Loss/$TASK_NAME/keypoint/regression',  (optional)
      'Loss/segmentation/heatmap',  (optional)
      'Loss/densepose/heatmap',  (optional)
      'Loss/densepose/regression',  (optional)
      'Loss/track/reid',  (optional)
      'Loss/track/offset']  (optional)
    to scalar tensors corresponding to the losses for different tasks. Note
    the $TASK_NAME is provided by the KeypointEstimation namedtuple used to
    differentiate between different keypoint tasks.
  """
  _, input_height, input_width, _ = _get_shape(
      prediction_dict['preprocessed_inputs'], 4)

  # Output resolution is the input resolution divided by the stride, but at
  # least 1 in each dimension.
  output_height = tf.maximum(input_height // self._stride, 1)
  output_width = tf.maximum(input_width // self._stride, 1)

  # TODO(vighneshb) Explore whether using floor here is safe.
  strided_true_shapes = tf.ceil(
      tf.cast(true_image_shapes, tf.float32) / self._stride)
  # Weights that mark which flattened output locations lie inside the true
  # (unpadded) image; a trailing unit axis is appended.
  valid_weights = get_valid_anchor_weights_in_flattened_image(
      strided_true_shapes, output_height, output_width)
  valid_weights = tf.expand_dims(valid_weights, 2)

  center_loss = self._compute_object_center_loss(
      object_center_predictions=prediction_dict[OBJECT_CENTER],
      input_height=input_height,
      input_width=input_width,
      per_pixel_weights=valid_weights,
      maximum_normalized_coordinate=maximum_normalized_coordinate)
  all_losses = {}
  all_losses[OBJECT_CENTER] = (
      self._center_params.object_center_loss_weight * center_loss)

  if self._od_params is not None:
    detection_losses = self._compute_object_detection_losses(
        input_height=input_height,
        input_width=input_width,
        prediction_dict=prediction_dict,
        per_pixel_weights=valid_weights,
        maximum_normalized_coordinate=maximum_normalized_coordinate)
    all_losses.update({
        name: value * self._od_params.task_loss_weight
        for name, value in detection_losses.items()
    })

  if self._kp_params_dict is not None:
    for task_name, task_params in self._kp_params_dict.items():
      keypoint_losses = self._compute_keypoint_estimation_losses(
          task_name=task_name,
          input_height=input_height,
          input_width=input_width,
          prediction_dict=prediction_dict,
          per_pixel_weights=valid_weights)
      all_losses.update({
          name: value * task_params.task_loss_weight
          for name, value in keypoint_losses.items()
      })

  if self._mask_params is not None:
    segmentation_losses = self._compute_segmentation_losses(
        prediction_dict=prediction_dict,
        per_pixel_weights=valid_weights)
    all_losses.update({
        name: value * self._mask_params.task_loss_weight
        for name, value in segmentation_losses.items()
    })

  if self._densepose_params is not None:
    dp_losses = self._compute_densepose_losses(
        input_height=input_height,
        input_width=input_width,
        prediction_dict=prediction_dict)
    all_losses.update({
        name: value * self._densepose_params.task_loss_weight
        for name, value in dp_losses.items()
    })

  if self._track_params is not None:
    tracking_losses = self._compute_track_losses(
        input_height=input_height,
        input_width=input_width,
        prediction_dict=prediction_dict)
    all_losses.update({
        name: value * self._track_params.task_loss_weight
        for name, value in tracking_losses.items()
    })

  if self._temporal_offset_params is not None:
    temporal_losses = self._compute_temporal_offset_loss(
        input_height=input_height,
        input_width=input_width,
        prediction_dict=prediction_dict)
    all_losses.update({
        name: value * self._temporal_offset_params.task_loss_weight
        for name, value in temporal_losses.items()
    })

  # Prepend the LOSS_KEY_PREFIX to the keys in the dictionary such that the
  # losses will be grouped together in Tensorboard.
  return {
      '%s/%s' % (LOSS_KEY_PREFIX, name): value
      for name, value in all_losses.items()
  }
def postprocess(self, prediction_dict, true_image_shapes, **params):
  """Produces boxes given a prediction dict returned by predict().

  Although predict returns a list of tensors, only the last tensor in
  each list is used for making box predictions.

  The method first picks the top-k object center peaks, then decodes every
  configured task (boxes, keypoints, masks/DensePose, track embeddings,
  temporal offsets) at those peaks, optionally runs the configured
  non-max-suppression function, and finally rescores instances from their
  keypoint scores when keypoint tasks are configured.

  Args:
    prediction_dict: a dictionary holding predicted tensors from "predict"
      function.
    true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
      the form [height, width, channels] indicating the shapes of true images
      in the resized images, as resized images can be padded with zeros. May
      be None, in which case the whole image is assumed valid and the shapes
      are inferred from the object center heatmap and the model stride.
    **params: Currently ignored.

  Returns:
    detections: a dictionary containing the following fields
      detection_boxes - A tensor of shape [batch, max_detections, 4]
        holding the predicted boxes.
      detection_boxes_strided: A tensor of shape [batch_size, num_detections,
        4] holding the predicted boxes in absolute coordinates of the
        feature extractor's final layer output.
      detection_scores: A tensor of shape [batch, max_detections] holding
        the predicted score for each box.
      detection_multiclass_scores: A tensor of shape [batch, max_detection,
        num_classes] holding multiclass score for each box.
      detection_classes: An integer tensor of shape [batch, max_detections]
        containing the detected class for each box.
      num_detections: An integer tensor of shape [batch] containing the
        number of detected boxes for each sample in the batch.
      detection_keypoints: (Optional) A float tensor of shape [batch,
        max_detections, num_keypoints, 2] with normalized keypoints. Any
        invalid keypoints have their coordinates and scores set to 0.0.
      detection_keypoint_scores: (Optional) A float tensor of shape [batch,
        max_detection, num_keypoints] with scores for each keypoint.
      detection_keypoint_depths: (Optional) Keypoint depths. Only emitted by
        the single-class keypoint path when it produces depth predictions.
      detection_masks: (Optional) A uint8 tensor of shape [batch,
        max_detections, mask_height, mask_width] with masks for each
        detection. Background is specified with 0, and foreground is specified
        with positive integers (1 for standard instance segmentation mask, and
        1-indexed parts for DensePose task).
      detection_surface_coords: (Optional) A float32 tensor of shape [batch,
        max_detection, mask_height, mask_width, 2] with DensePose surface
        coordinates, in (v, u) format.
      detection_embeddings: (Optional) A float tensor of shape [batch,
        max_detections, reid_embed_size] containing object embeddings.
      detection_offsets: (Optional) Temporal offsets gathered at the object
        centers. Only emitted when temporal offset parameters are configured.
      In addition, when the model is configured to output the prediction
      dictionary, every entry of `prediction_dict` and a 'true_image_shapes'
      entry are included as well.
  """
  object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])

  if true_image_shapes is None:
    # If true_image_shapes is not provided, we assume the whole image is valid
    # and infer the true_image_shapes from the object_center_prob shape.
    batch_size, strided_height, strided_width, _ = _get_shape(
        object_center_prob, 4)
    # The channel count is taken from the number of channel means configured
    # on the feature extractor.
    true_image_shapes = tf.stack(
        [strided_height * self._stride, strided_width * self._stride,
         tf.constant(len(self._feature_extractor._channel_means))])  # pylint: disable=protected-access
    # Replicate the single [height, width, channels] row for every example.
    true_image_shapes = tf.stack([true_image_shapes] * batch_size, axis=0)
  else:
    # Mask object centers by true_image_shape. [batch, h, w, 1]
    object_center_mask = mask_from_true_image_shape(
        _get_shape(object_center_prob, 4), true_image_shapes)
    object_center_prob *= object_center_mask

  # Get x, y and channel indices corresponding to the top indices in the class
  # center predictions.
  detection_scores, y_indices, x_indices, channel_indices = (
      top_k_feature_map_locations(
          object_center_prob,
          max_pool_kernel_size=self._center_params.peak_max_pool_kernel_size,
          k=self._center_params.max_box_predictions))
  # Per-class center probabilities at each selected peak location.
  multiclass_scores = tf.gather_nd(
      object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
  # A detection counts as valid only if its score is strictly positive.
  num_detections = tf.reduce_sum(
      tf.cast(detection_scores > 0, tf.int32), axis=1)
  postprocess_dict = {
      fields.DetectionResultFields.detection_scores: detection_scores,
      fields.DetectionResultFields.detection_multiclass_scores:
          multiclass_scores,
      fields.DetectionResultFields.detection_classes: channel_indices,
      fields.DetectionResultFields.num_detections: num_detections,
  }

  if self._output_prediction_dict:
    postprocess_dict.update(prediction_dict)
    postprocess_dict['true_image_shapes'] = true_image_shapes

  # Stays None when no object detection task is configured; the keypoint
  # postprocessing below receives it as-is.
  boxes_strided = None
  if self._od_params:
    boxes_strided = (
        prediction_tensors_to_boxes(y_indices, x_indices,
                                    prediction_dict[BOX_SCALE][-1],
                                    prediction_dict[BOX_OFFSET][-1]))

    boxes = convert_strided_predictions_to_normalized_boxes(
        boxes_strided, self._stride, true_image_shapes)

    postprocess_dict.update({
        fields.DetectionResultFields.detection_boxes: boxes,
        'detection_boxes_strided': boxes_strided,
    })

  if self._kp_params_dict:
    # If the model is trained to predict only one class of object and its
    # keypoint, we fall back to a simpler postprocessing function which uses
    # the ops that are supported by tf.lite on GPU.
    clip_keypoints = self._should_clip_keypoints()
    if len(self._kp_params_dict) == 1 and self._num_classes == 1:
      task_name, kp_params = next(iter(self._kp_params_dict.items()))
      # Only the non-argmax single-class path below can set keypoint depths.
      keypoint_depths = None
      if kp_params.argmax_postprocessing:
        keypoints, keypoint_scores = (
            prediction_to_keypoints_argmax(
                prediction_dict,
                object_y_indices=y_indices,
                object_x_indices=x_indices,
                boxes=boxes_strided,
                task_name=task_name,
                kp_params=kp_params))
      else:
        (keypoints, keypoint_scores,
         keypoint_depths) = self._postprocess_keypoints_single_class(
             prediction_dict, channel_indices, y_indices, x_indices,
             boxes_strided, num_detections)
      keypoints, keypoint_scores = (
          convert_strided_predictions_to_normalized_keypoints(
              keypoints, keypoint_scores, self._stride, true_image_shapes,
              clip_out_of_frame_keypoints=clip_keypoints))
      if keypoint_depths is not None:
        postprocess_dict.update({
            fields.DetectionResultFields.detection_keypoint_depths:
                keypoint_depths
        })
    else:
      # Multi-class keypoint estimation task does not support depth
      # estimation.
      assert all([
          not kp_dict.predict_depth
          for kp_dict in self._kp_params_dict.values()
      ])
      keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
          prediction_dict, channel_indices, y_indices, x_indices,
          boxes_strided, num_detections)
      keypoints, keypoint_scores = (
          convert_strided_predictions_to_normalized_keypoints(
              keypoints, keypoint_scores, self._stride, true_image_shapes,
              clip_out_of_frame_keypoints=clip_keypoints))

    postprocess_dict.update({
        fields.DetectionResultFields.detection_keypoints: keypoints,
        fields.DetectionResultFields.detection_keypoint_scores:
            keypoint_scores
    })
    if self._od_params is None:
      # Still output the box prediction by enclosing the keypoints for
      # evaluation purpose.
      boxes = keypoint_ops.keypoints_to_enclosing_bounding_boxes(
          keypoints, keypoints_axis=2)
      postprocess_dict.update({
          fields.DetectionResultFields.detection_boxes: boxes,
      })

  if self._mask_params:
    # NOTE(review): `boxes` is only assigned in the box and keypoint branches
    # above; confirm masks are never configured without one of those tasks.
    masks = tf.nn.sigmoid(prediction_dict[SEGMENTATION_HEATMAP][-1])
    densepose_part_heatmap, densepose_surface_coords = None, None
    densepose_class_index = 0
    if self._densepose_params:
      densepose_part_heatmap = prediction_dict[DENSEPOSE_HEATMAP][-1]
      densepose_surface_coords = prediction_dict[DENSEPOSE_REGRESSION][-1]
      densepose_class_index = self._densepose_params.class_id
    instance_masks, surface_coords = (
        convert_strided_predictions_to_instance_masks(
            boxes, channel_indices, masks, true_image_shapes,
            densepose_part_heatmap, densepose_surface_coords,
            stride=self._stride, mask_height=self._mask_params.mask_height,
            mask_width=self._mask_params.mask_width,
            score_threshold=self._mask_params.score_threshold,
            densepose_class_index=densepose_class_index))
    postprocess_dict[
        fields.DetectionResultFields.detection_masks] = instance_masks
    # Surface coordinates are only exposed when DensePose is configured.
    if self._densepose_params:
      postprocess_dict[
          fields.DetectionResultFields.detection_surface_coords] = (
              surface_coords)

  if self._track_params:
    embeddings = self._postprocess_embeddings(prediction_dict,
                                              y_indices, x_indices)
    postprocess_dict.update({
        fields.DetectionResultFields.detection_embeddings: embeddings
    })

  if self._temporal_offset_params:
    offsets = prediction_tensors_to_temporal_offsets(
        y_indices, x_indices,
        prediction_dict[TEMPORAL_OFFSET][-1])
    postprocess_dict[fields.DetectionResultFields.detection_offsets] = offsets

  if self._non_max_suppression_fn:
    # Boxes are removed from the dict and get an extra axis at position -2
    # before being handed to the NMS function.
    boxes = tf.expand_dims(
        postprocess_dict.pop(fields.DetectionResultFields.detection_boxes),
        axis=-2)
    multiclass_scores = postprocess_dict[
        fields.DetectionResultFields.detection_multiclass_scores]
    num_classes = tf.shape(multiclass_scores)[2]
    class_mask = tf.cast(
        tf.one_hot(
            postprocess_dict[fields.DetectionResultFields.detection_classes],
            depth=num_classes), tf.bool)
    # Suppress the scores of those unselected classes to be zeros. Otherwise,
    # the downstream NMS ops might be confused and introduce issues.
    multiclass_scores = tf.where(
        class_mask, multiclass_scores, tf.zeros_like(multiclass_scores))
    num_valid_boxes = postprocess_dict.pop(
        fields.DetectionResultFields.num_detections)
    # Remove scores and classes as NMS will compute these from multiclass
    # scores.
    postprocess_dict.pop(fields.DetectionResultFields.detection_scores)
    postprocess_dict.pop(fields.DetectionResultFields.detection_classes)
    # Everything still left in postprocess_dict is passed through NMS as
    # additional fields.
    (nmsed_boxes, nmsed_scores, nmsed_classes, _, nmsed_additional_fields,
     num_detections) = self._non_max_suppression_fn(
         boxes,
         multiclass_scores,
         additional_fields=postprocess_dict,
         num_valid_boxes=num_valid_boxes)
    postprocess_dict = nmsed_additional_fields
    postprocess_dict[
        fields.DetectionResultFields.detection_boxes] = nmsed_boxes
    postprocess_dict[
        fields.DetectionResultFields.detection_scores] = nmsed_scores
    postprocess_dict[
        fields.DetectionResultFields.detection_classes] = nmsed_classes
    postprocess_dict[
        fields.DetectionResultFields.num_detections] = num_detections
    # NOTE(review): postprocess_dict already is nmsed_additional_fields at
    # this point, so this update looks redundant -- confirm before removing.
    postprocess_dict.update(nmsed_additional_fields)

  # Perform the rescoring once the NMS is applied to make sure the rescored
  # scores won't be washed out by the NMS function.
  if self._kp_params_dict:
    channel_indices = postprocess_dict[
        fields.DetectionResultFields.detection_classes]
    detection_scores = postprocess_dict[
        fields.DetectionResultFields.detection_scores]
    keypoint_scores = postprocess_dict[
        fields.DetectionResultFields.detection_keypoint_scores]
    # Update instance scores based on keypoints.
    scores = self._rescore_instances(
        channel_indices, detection_scores, keypoint_scores)
    postprocess_dict.update({
        fields.DetectionResultFields.detection_scores: scores,
    })
  return postprocess_dict
def postprocess_single_instance_keypoints(
    self,
    prediction_dict,
    true_image_shapes):
  """Decodes the keypoints of a single instance from the predictions.

  This is a fast path for the special case where the image contains exactly
  one instance (the general CenterNet postprocess supports many instances).
  Compared to the general postprocessing it makes these simplifications:
    1) The model predicts keypoints for a single class only.
    2) There is only one instance in the image. If several instances appear,
       the model tends to predict the one closer to the image center (the
       others are considered background and rejected by the model).
    3) top_k ops are avoided in favor of argmax, which is faster.
    4) Predictions other than keypoints (e.g. boxes) are ignored.
    5) The input batch size is assumed to be 1.

  Args:
    prediction_dict: a dictionary holding predicted tensors from "predict"
      function.
    true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
      the form [height, width, channels] indicating the shapes of true images
      in the resized images, as resized images can be padded with zeros.

  Returns:
    detections: a dictionary containing the following fields
      detection_keypoints: A float tensor of shape
        [1, 1, num_keypoints, 2] with normalized keypoints. Any invalid
        keypoints have their coordinates and scores set to 0.0.
      detection_keypoint_scores: A float tensor of shape
        [1, 1, num_keypoints] with scores for each keypoint.
      detection_keypoint_depths: Only present when the keypoint task is
        configured to predict depth; the depths returned alongside the
        keypoints.
  """
  # Exactly one keypoint task must be configured for this fast path.
  assert len(self._kp_params_dict) == 1
  task_name, kp_params = next(iter(self._kp_params_dict.items()))

  def last_head_output(head_name):
    # Only the last tensor of each prediction list is used.
    return prediction_dict[get_keypoint_name(task_name, head_name)][-1]

  heatmap_probs = tf.nn.sigmoid(last_head_output(KEYPOINT_HEATMAP))
  offsets = last_head_output(KEYPOINT_OFFSET)
  regression = last_head_output(KEYPOINT_REGRESSION)
  center_probs = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
  depth_predictions = (
      last_head_output(KEYPOINT_DEPTH) if kp_params.predict_depth else None)

  strided_keypoints, strided_scores, depths = (
      prediction_to_single_instance_keypoints(
          object_heatmap=center_probs,
          keypoint_heatmap=heatmap_probs,
          keypoint_offset=offsets,
          keypoint_regression=regression,
          kp_params=kp_params,
          keypoint_depths=depth_predictions))

  # Out-of-frame keypoints are deliberately not clipped on this path.
  normalized_keypoints, normalized_scores = (
      convert_strided_predictions_to_normalized_keypoints(
          strided_keypoints,
          strided_scores,
          self._stride,
          true_image_shapes,
          clip_out_of_frame_keypoints=False))

  detections = {
      fields.DetectionResultFields.detection_keypoints: normalized_keypoints,
      fields.DetectionResultFields.detection_keypoint_scores:
          normalized_scores,
  }
  if kp_params.predict_depth:
    detections[fields.DetectionResultFields.detection_keypoint_depths] = (
        depths)
  return detections
def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
  """Extracts L2-normalized re-identification embeddings at object centers.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. It must contain the embedding prediction feature maps
      of the tracking task.
    y_indices: A [batch_size, max_detections] int tensor with y indices for
      all object centers.
    x_indices: A [batch_size, max_detections] int tensor with x indices for
      all object centers.

  Returns:
    embeddings: A [batch_size, max_detection, reid_embed_size] float32
      tensor with L2 normalized embeddings extracted from detection box
      centers.
  """
  # Only the last tensor in the prediction list is used.
  reid_feature_map = prediction_dict[TRACK_REID][-1]
  center_embeddings = predicted_embeddings_at_object_centers(
      reid_feature_map, y_indices, x_indices)
  # tf.linalg.normalize also returns the norms, which are not needed here.
  normalized_embeddings, _ = tf.linalg.normalize(center_embeddings, axis=-1)
  return normalized_embeddings
def _scatter_keypoints_to_batch(self, num_ind, kpt_coords_for_example,
                                kpt_scores_for_example,
                                instance_inds_for_example, max_detections,
                                total_num_keypoints):
  """Aligns per-class keypoints with the full set of detections.

  When `num_ind[0]` is positive, the keypoint coordinates and scores are
  padded to the full instance dimension via `_pad_to_full_instance_dim`;
  otherwise all-zero tensors are returned.

  Args:
    num_ind: Indexable whose first element is the number of instances;
      it selects which branch of the conditional runs.
    kpt_coords_for_example: Keypoint coordinates of the matched instances.
    kpt_scores_for_example: Keypoint scores of the matched instances.
    instance_inds_for_example: Indices of the matched instances among the
      original detections.
    max_detections: Size of the detection dimension of the all-zero fallback.
    total_num_keypoints: Size of the keypoint dimension of the all-zero
      fallback.

  Returns:
    A (keypoint coordinates, keypoint scores) pair. The fallback has shapes
    [1, max_detections, total_num_keypoints, 2] and
    [1, max_detections, total_num_keypoints].
  """
  def pad_matched_instances():
    # Scatter into tensors whose instances align with the original detection
    # instances. Note that the padded size comes from
    # self._center_params.max_box_predictions, not from `max_detections`.
    return _pad_to_full_instance_dim(
        kpt_coords_for_example, kpt_scores_for_example,
        instance_inds_for_example,
        self._center_params.max_box_predictions)

  def all_zero_keypoints():
    coords_shape = [1, max_detections, total_num_keypoints, 2]
    scores_shape = [1, max_detections, total_num_keypoints]
    return (tf.zeros(coords_shape, dtype=tf.float32),
            tf.zeros(scores_shape, dtype=tf.float32))

  # Use dimension values instead of tf.size for tf.lite compatibility.
  has_instances = num_ind[0] > 0
  return tf.cond(has_instances, pad_matched_instances, all_zero_keypoints)
def _postprocess_keypoints_multi_class(self, prediction_dict, classes,
                                       y_indices, x_indices, boxes,
                                       num_detections):
  """Performs postprocessing on keypoint predictions.

  This is the most general keypoint postprocessing function which supports
  multiple keypoint tasks (e.g. human and dog keypoints) and multiple object
  detection classes. Note that it is the most expensive postprocessing logic
  and is currently not tf.lite/tf.js compatible. See
  _postprocess_keypoints_single_class if you plan to export the model in a
  more portable format.

  For every example and every keypoint task, the detections whose class
  matches the task's class id are postprocessed, and the resulting keypoints
  are scattered into zero-initialized tensors covering all detections and all
  keypoints.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. This dictionary should contain keypoint prediction
      feature maps for each keypoint task.
    classes: A [batch_size, max_detections] int tensor with class indices for
      all detected objects.
    y_indices: A [batch_size, max_detections] int tensor with y indices for
      all object centers.
    x_indices: A [batch_size, max_detections] int tensor with x indices for
      all object centers.
    boxes: A [batch_size, max_detections, 4] float32 tensor with bounding
      boxes in (un-normalized) output space. May be None, in which case no
      boxes are passed on to the per-class postprocessing.
    num_detections: A [batch_size] int tensor with the number of valid
      detections for each image.

  Returns:
    A tuple of
    keypoints: a [batch_size, max_detection, num_total_keypoints, 2] float32
      tensor with keypoints in the output (strided) coordinate frame.
    keypoint_scores: a [batch_size, max_detections, num_total_keypoints]
      float32 tensor with keypoint scores.
  """
  total_num_keypoints = sum(len(kp_dict.keypoint_indices) for kp_dict
                            in self._kp_params_dict.values())
  batch_size, max_detections = _get_shape(classes, 2)
  kpt_coords_for_example_list = []
  kpt_scores_for_example_list = []
  # batch_size is used as a Python loop bound, so it has to be a Python
  # integer here.
  for ex_ind in range(batch_size):
    # The tensors that host the keypoint coordinates and scores for all
    # instances and all keypoints. They will be updated by scatter_nd_add for
    # each keypoint task.
    kpt_coords_for_example_all_det = tf.zeros(
        [max_detections, total_num_keypoints, 2])
    kpt_scores_for_example_all_det = tf.zeros(
        [max_detections, total_num_keypoints])
    for task_name, kp_params in self._kp_params_dict.items():
      # Only the last tensor of each prediction list is used.
      keypoint_heatmap = prediction_dict[
          get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
      keypoint_offsets = prediction_dict[
          get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
      keypoint_regression = prediction_dict[
          get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
      # Indices of the valid detections of this example that belong to the
      # class of the current keypoint task.
      instance_inds = self._get_instance_indices(
          classes, num_detections, ex_ind, kp_params.class_id)

      # Gather the feature map locations corresponding to the object class.
      # The gather runs along the detection axis for every batch row; the
      # per-class postprocessing below slices out the row for `ex_ind`.
      y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
      x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
      if boxes is None:
        boxes_for_kpt_class = None
      else:
        boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)

      # Postprocess keypoints and scores for class and single image. Shapes
      # are [1, num_instances_i, num_keypoints_i, 2] and
      # [1, num_instances_i, num_keypoints_i], respectively. Note that
      # num_instances_i and num_keypoints_i refers to the number of
      # instances and keypoints for class i, respectively.
      # The third return value (keypoint depths) is discarded here.
      (kpt_coords_for_class, kpt_scores_for_class, _) = (
          self._postprocess_keypoints_for_class_and_image(
              keypoint_heatmap,
              keypoint_offsets,
              keypoint_regression,
              classes,
              y_indices_for_kpt_class,
              x_indices_for_kpt_class,
              boxes_for_kpt_class,
              ex_ind,
              kp_params,
          ))

      # Prepare the indices for scatter_nd. The resulting combined_inds has
      # the shape of [num_instances_i * num_keypoints_i, 2], where the first
      # column corresponds to the instance IDs and the second column
      # corresponds to the keypoint IDs.
      kpt_inds = tf.constant(kp_params.keypoint_indices, dtype=tf.int32)
      kpt_inds = tf.expand_dims(kpt_inds, axis=0)
      instance_inds_expand = tf.expand_dims(instance_inds, axis=-1)
      # Multiplying by ones broadcasts both index tensors to the common
      # [num_instances_i, num_keypoints_i] shape.
      kpt_inds_expand = kpt_inds * tf.ones_like(instance_inds_expand)
      instance_inds_expand = instance_inds_expand * tf.ones_like(kpt_inds)
      combined_inds = tf.stack(
          [instance_inds_expand, kpt_inds_expand], axis=2)
      combined_inds = tf.reshape(combined_inds, [-1, 2])

      # Reshape the keypoint coordinates/scores to [num_instances_i *
      # num_keypoints_i, 2]/[num_instances_i * num_keypoints_i] to be used
      # by scatter_nd_add.
      kpt_coords_for_class = tf.reshape(kpt_coords_for_class, [-1, 2])
      kpt_scores_for_class = tf.reshape(kpt_scores_for_class, [-1])
      kpt_coords_for_example_all_det = tf.tensor_scatter_nd_add(
          kpt_coords_for_example_all_det,
          combined_inds, kpt_coords_for_class)
      kpt_scores_for_example_all_det = tf.tensor_scatter_nd_add(
          kpt_scores_for_example_all_det,
          combined_inds, kpt_scores_for_class)

    # Add a leading batch axis so the per-example results can be concatenated.
    kpt_coords_for_example_list.append(
        tf.expand_dims(kpt_coords_for_example_all_det, axis=0))
    kpt_scores_for_example_list.append(
        tf.expand_dims(kpt_scores_for_example_all_det, axis=0))

  # Concatenate all keypoints and scores from all examples in the batch.
  # Shapes are [batch_size, max_detections, num_total_keypoints, 2] and
  # [batch_size, max_detections, num_total_keypoints], respectively.
  keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
  keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)

  return keypoints, keypoint_scores
def _postprocess_keypoints_single_class(self, prediction_dict, classes,
                                        y_indices, x_indices, boxes,
                                        num_detections):
  """Performs postprocessing on keypoint predictions (single class only).

  Handles the special case where the model predicts a single class of
  bounding box/keypoint (e.g. person). Under that assumption the function
  uses only tf.lite supported ops and should run faster than the
  multi-class variant.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. This dictionary should contain keypoint prediction
      feature maps for each keypoint task.
    classes: A [batch_size, max_detections] int tensor with class indices for
      all detected objects.
    y_indices: A [batch_size, max_detections] int tensor with y indices for
      all object centers.
    x_indices: A [batch_size, max_detections] int tensor with x indices for
      all object centers.
    boxes: A [batch_size, max_detections, 4] float32 tensor with bounding
      boxes in (un-normalized) output space.
    num_detections: A [batch_size] int tensor with the number of valid
      detections for each image. Not used by this function.

  Returns:
    A tuple of
    keypoints: a [batch_size, max_detection, num_total_keypoints, 2] float32
      tensor with keypoints in the output (strided) coordinate frame.
    keypoint_scores: a [batch_size, max_detections, num_total_keypoints]
      float32 tensor with keypoint scores.
    keypoint_depths: the per-example keypoint depths concatenated along the
      batch axis, or None when the keypoint task does not predict depth.
  """
  # Only valid for exactly one keypoint task and one object class. For more
  # general use cases see _postprocess_keypoints_multi_class.
  assert len(self._kp_params_dict) == 1 and self._num_classes == 1
  task_name, kp_params = next(iter(self._kp_params_dict.items()))

  def last_head_output(head_name):
    # Only the last tensor of each prediction list is used.
    return prediction_dict[get_keypoint_name(task_name, head_name)][-1]

  heatmap = last_head_output(KEYPOINT_HEATMAP)
  offsets = last_head_output(KEYPOINT_OFFSET)
  regression = last_head_output(KEYPOINT_REGRESSION)
  depth_predictions = (
      last_head_output(KEYPOINT_DEPTH) if kp_params.predict_depth else None)

  batch_size, _ = _get_shape(classes, 2)
  # One (coordinates, scores, depths) triple per example, with shapes
  # [1, max_detections, num_keypoints, 2] and
  # [1, max_detections, num_keypoints] for the first two entries.
  per_example_results = [
      self._postprocess_keypoints_for_class_and_image(
          heatmap,
          offsets,
          regression,
          classes,
          y_indices,
          x_indices,
          boxes,
          example_index,
          kp_params,
          keypoint_depth_predictions=depth_predictions)
      for example_index in range(batch_size)
  ]

  # Stitch the per-example results back together along the batch axis.
  keypoints = tf.concat(
      [result[0] for result in per_example_results], axis=0)
  keypoint_scores = tf.concat(
      [result[1] for result in per_example_results], axis=0)
  keypoint_depths = None
  if kp_params.predict_depth:
    keypoint_depths = tf.concat(
        [result[2] for result in per_example_results], axis=0)

  return keypoints, keypoint_scores, keypoint_depths
def _get_instance_indices(self, classes, num_detections, batch_index,
                          class_id):
  """Finds the valid detections of one example that have the given class.

  Args:
    classes: A [batch_size, max_detections] int tensor with class indices for
      all detected objects.
    num_detections: A [batch_size] int tensor with the number of valid
      detections for each image.
    batch_index: An integer specifying the index for an example in the batch.
    class_id: The class id to match.

  Returns:
    instance_inds: A [num_instances] int32 tensor where each element is the
      position of a matching detection within the `classes` tensor. This is
      useful to associate the refined keypoints with the original detections
      (i.e. boxes).
  """
  example_classes = classes[batch_index:batch_index+1, ...]
  _, max_detections = shape_utils.combined_static_and_dynamic_shape(
      example_classes)
  class_row = example_classes[0]

  # Detections at or beyond num_detections are padding and must be ignored.
  is_valid_detection = (
      tf.range(max_detections) < num_detections[batch_index])
  # Compare against a tensor of matching shape (rather than a scalar) to
  # stay tf.lite compatible.
  has_target_class = tf.math.equal(
      class_row, tf.fill(class_row.shape, class_id))
  selected = tf.math.logical_and(is_valid_detection, has_target_class)

  # tf.where yields one row of coordinates per True element; keep the single
  # coordinate column and cast to int32 for tf.lite compatibility.
  return tf.cast(tf.where(selected)[:, 0], tf.int32)
def _postprocess_keypoints_for_class_and_image(
    self,
    keypoint_heatmap,
    keypoint_offsets,
    keypoint_regression,
    classes,
    y_indices,
    x_indices,
    boxes,
    batch_index,
    kp_params,
    keypoint_depth_predictions=None):
  """Postprocess keypoints for a single image and class.

  Slices every input down to the example at `batch_index`, gathers the
  regressed keypoints at the object centers, extracts keypoint candidates
  from the heatmap, and refines the regressed keypoints with the candidates.

  Args:
    keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
      tensor with keypoint heatmaps.
    keypoint_offsets: A [batch_size, height, width, 2] float32 tensor with
      local offsets to keypoint centers.
    keypoint_regression: A [batch_size, height, width, 2 * num_keypoints]
      float32 tensor with regressed offsets to all keypoints.
    classes: A [batch_size, max_detections] int tensor with class indices for
      all detected objects. Not used by this function.
    y_indices: A [batch_size, max_detections] int tensor with y indices for
      all object centers.
    x_indices: A [batch_size, max_detections] int tensor with x indices for
      all object centers.
    boxes: A [batch_size, max_detections, 4] float32 tensor with detected
      boxes in the output (strided) frame. May be None.
    batch_index: An integer specifying the index for an example in the batch.
    kp_params: A `KeypointEstimationParams` object with parameters for a
      single keypoint class.
    keypoint_depth_predictions: (optional) A [batch_size, height, width, 1]
      float32 tensor representing the keypoint depth prediction.

  Returns:
    A tuple of
    refined_keypoints: A [1, num_instances, num_keypoints, 2] float32 tensor
      with refined keypoints for a single class in a single image, expressed
      in the output (strided) coordinate frame. Note that `num_instances` is a
      dynamic dimension, and corresponds to the number of valid detections
      for the specific class.
    refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
      keypoint scores.
    refined_depths: A [1, num_instances, num_keypoints] float32 tensor with
      keypoint depths. Return None if the input keypoint_depth_predictions is
      None.
  """
  example = slice(batch_index, batch_index + 1)

  def take_example(batched):
    # Keeps the leading batch axis (of size 1) after slicing.
    return batched[example, ...]

  heatmap_probs = tf.nn.sigmoid(take_example(keypoint_heatmap))
  offsets = take_example(keypoint_offsets)
  regression = take_example(keypoint_regression)
  depths = (
      None if keypoint_depth_predictions is None
      else take_example(keypoint_depth_predictions))
  center_y = take_example(y_indices)
  center_x = take_example(x_indices)
  example_boxes = None if boxes is None else take_example(boxes)

  # Regressed keypoints gathered at the object centers, reshaped to
  # [1, num_instances, num_keypoints, 2].
  keypoints_per_task = len(kp_params.keypoint_indices)
  regressed_keypoints = tf.reshape(
      regressed_keypoints_at_object_centers(regression, center_y, center_x),
      [1, -1, keypoints_per_task, 2])

  # Candidate keypoints and scores, of shape
  # [1, num_candidates_per_keypoint, num_keypoints, 2] and
  # [1, num_candidates_per_keypoint, num_keypoints], respectively.
  candidate_threshold = kp_params.keypoint_candidate_score_threshold
  (candidates, candidate_scores, num_candidates, depth_candidates) = (
      prediction_tensors_to_keypoint_candidates(
          heatmap_probs,
          offsets,
          keypoint_score_threshold=candidate_threshold,
          max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
          max_candidates=kp_params.num_candidates_per_keypoint,
          keypoint_depths=depths))

  # Per-keypoint standard deviations scaled by the configured multiplier.
  scaled_std_devs = [
      std_dev * kp_params.std_dev_multiplier
      for std_dev in kp_params.keypoint_std_dev
  ]

  # Refined keypoints and scores, of shape
  # [1, num_instances, num_keypoints, 2] and
  # [1, num_instances, num_keypoints], respectively.
  (final_keypoints, final_scores, final_depths) = refine_keypoints(
      regressed_keypoints,
      candidates,
      candidate_scores,
      num_candidates,
      bboxes=example_boxes,
      unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
      box_scale=kp_params.box_scale,
      candidate_search_scale=kp_params.candidate_search_scale,
      candidate_ranking_mode=kp_params.candidate_ranking_mode,
      score_distance_offset=kp_params.score_distance_offset,
      keypoint_depth_candidates=depth_candidates,
      keypoint_score_threshold=candidate_threshold,
      score_distance_multiplier=kp_params.score_distance_multiplier,
      keypoint_std_dev=scaled_std_devs)

  return final_keypoints, final_scores, final_depths
def regularization_losses(self):
  """Returns the regularization losses, which is always an empty list."""
  no_losses = []
  return no_losses
def restore_map(self,
                fine_tune_checkpoint_type='detection',
                load_all_detection_checkpoint_vars=False):
  """Not implemented for this model; always raises.

  Args:
    fine_tune_checkpoint_type: Ignored.
    load_all_detection_checkpoint_vars: Ignored.

  Raises:
    RuntimeError: always.
  """
  del fine_tune_checkpoint_type, load_all_detection_checkpoint_vars  # Unused.
  raise RuntimeError('CenterNetMetaArch not supported under TF1.x.')
def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
  """Returns a map of Trackable objects to load from a foreign checkpoint.

  Returns a dictionary of Tensorflow 2 Trackable objects (e.g. tf.Module
  or Checkpoint). This enables the model to initialize based on weights from
  another task. For example, the feature extractor variables from a
  classification model can be used to bootstrap training of an object
  detector.

  Note that this function is intended to be used to restore Keras-based
  models when running Tensorflow 2, whereas restore_map (not implemented
  in CenterNet) is intended to be used to restore Slim-based models when
  running Tensorflow 1.x.

  TODO(jonathanhuang): Make this function consistent with other
    meta-architectures.

  Args:
    fine_tune_checkpoint_type: which part of the model to restore.
      Valid values: `detection`, `classification`, `full`.
      Default 'detection'.
      'detection': restores only the feature extractor. The returned map
        places it under the attribute path 'model._feature_extractor'.
      'classification': restores only the classification backbone of the
        feature extractor, placed under the key 'feature_extractor'.
      'full': restores the whole model, placed under the key 'model'.
      'fine_tune': no longer supported; raises ValueError pointing to
        'detection', which has the same functionality.

      For more details, see the tensorflow section on Loading mechanics.
      https://www.tensorflow.org/guide/checkpoint#loading_mechanics

  Returns:
    A dict mapping keys to Trackable objects (tf.Module or Checkpoint).

  Raises:
    ValueError: if `fine_tune_checkpoint_type` is 'fine_tune' or is not one
      of the supported values.
  """
  checkpoint_type = fine_tune_checkpoint_type

  if checkpoint_type == 'detection':
    # Wrap the feature extractor so that it is matched against
    # 'model._feature_extractor' in the checkpoint.
    extractor_only = tf.train.Checkpoint(
        _feature_extractor=self._feature_extractor)
    return {'model': extractor_only}

  if checkpoint_type == 'classification':
    backbone = self._feature_extractor.classification_backbone
    return {'feature_extractor': backbone}

  if checkpoint_type == 'full':
    return {'model': self}

  if checkpoint_type == 'fine_tune':
    raise ValueError(('"fine_tune" is no longer supported for CenterNet. '
                      'Please set fine_tune_checkpoint_type to "detection"'
                      ' which has the same functionality. If you are using'
                      ' the ExtremeNet checkpoint, download the new version'
                      ' from the model zoo.'))

  raise ValueError('Unknown fine tune checkpoint type {}'.format(
      checkpoint_type))
def updates(self):
  """Returns the update ops registered in the graph's UPDATE_OPS collection.

  Returns:
    A new list holding the ops found in the `tf.GraphKeys.UPDATE_OPS`
    collection (empty when there are none).

  Raises:
    RuntimeError: when running under TF2, where this model is intended to be
      used with model_lib_v2, which does not support updates().
  """
  if tf_version.is_tf2():
    raise RuntimeError('This model is intended to be used with model_lib_v2 '
                       'which does not support updates()')

  slim_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Hand back a copy of the slim ops to avoid modifying the collection.
  return list(slim_update_ops) if slim_update_ops else []