research/object_detection/meta_architectures/context_rcnn_lib_tf2.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library functions for Context R-CNN."""
import tensorflow as tf
from object_detection.core import freezable_batch_norm
# The negative value used in padding the invalid weights.
_NEGATIVE_PADDING_VALUE = -100000
class ContextProjection(tf.keras.layers.Layer):
"""Custom layer to do batch normalization and projection."""
def __init__(self, projection_dimension, **kwargs):
self.batch_norm = freezable_batch_norm.FreezableBatchNorm(
epsilon=0.001,
center=True,
scale=True,
momentum=0.97,
trainable=True)
self.projection = tf.keras.layers.Dense(units=projection_dimension,
use_bias=True)
self.projection_dimension = projection_dimension
super(ContextProjection, self).__init__(**kwargs)
def build(self, input_shape):
self.projection.build(input_shape)
self.batch_norm.build(input_shape[:1] + [self.projection_dimension])
def call(self, input_features, is_training=False):
return tf.nn.relu6(self.batch_norm(self.projection(input_features),
is_training))
class AttentionBlock(tf.keras.layers.Layer):
"""Custom layer to perform all attention."""
def __init__(self, bottleneck_dimension, attention_temperature,
output_dimension=None, is_training=False,
name='AttentionBlock', max_num_proposals=100,
**kwargs):
"""Constructs an attention block.
Args:
bottleneck_dimension: A int32 Tensor representing the bottleneck dimension
for intermediate projections.
attention_temperature: A float Tensor. It controls the temperature of the
softmax for weights calculation. The formula for calculation as follows:
weights = exp(weights / temperature) / sum(exp(weights / temperature))
output_dimension: A int32 Tensor representing the last dimension of the
output feature.
is_training: A boolean Tensor (affecting batch normalization).
name: A string describing what to name the variables in this block.
max_num_proposals: The number of box proposals for each image
**kwargs: Additional keyword arguments.
"""
self._key_proj = ContextProjection(bottleneck_dimension)
self._val_proj = ContextProjection(bottleneck_dimension)
self._query_proj = ContextProjection(bottleneck_dimension)
self._feature_proj = None
self._attention_temperature = attention_temperature
self._bottleneck_dimension = bottleneck_dimension
self._is_training = is_training
self._output_dimension = output_dimension
self._max_num_proposals = max_num_proposals
if self._output_dimension:
self._feature_proj = ContextProjection(self._output_dimension)
super(AttentionBlock, self).__init__(name=name, **kwargs)
def build(self, input_shapes):
"""Finishes building the attention block.
Args:
input_shapes: the shape of the primary input box features.
"""
if not self._feature_proj:
self._output_dimension = input_shapes[-1]
self._feature_proj = ContextProjection(self._output_dimension)
def call(self, box_features, context_features, valid_context_size,
num_proposals):
"""Handles a call by performing attention.
Args:
box_features: A float Tensor of shape [batch_size * input_size, height,
width, num_input_features].
context_features: A float Tensor of shape [batch_size, context_size,
num_context_features].
valid_context_size: A int32 Tensor of shape [batch_size].
num_proposals: A [batch_size] int32 Tensor specifying the number of valid
proposals per image in the batch.
Returns:
A float Tensor with shape [batch_size, input_size, num_input_features]
containing output features after attention with context features.
"""
_, context_size, _ = context_features.shape
keys_values_valid_mask = compute_valid_mask(
valid_context_size, context_size)
total_proposals, height, width, channels = box_features.shape
batch_size = total_proposals // self._max_num_proposals
box_features = tf.reshape(
box_features,
[batch_size,
self._max_num_proposals,
height,
width,
channels])
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
box_features = tf.reduce_mean(box_features, [2, 3])
queries_valid_mask = compute_valid_mask(num_proposals,
box_features.shape[1])
queries = project_features(
box_features, self._bottleneck_dimension, self._is_training,
self._query_proj, normalize=True)
keys = project_features(
context_features, self._bottleneck_dimension, self._is_training,
self._key_proj, normalize=True)
values = project_features(
context_features, self._bottleneck_dimension, self._is_training,
self._val_proj, normalize=True)
# masking out any keys which are padding
keys *= tf.cast(keys_values_valid_mask[..., tf.newaxis], keys.dtype)
queries *= tf.cast(queries_valid_mask[..., tf.newaxis], queries.dtype)
weights = tf.matmul(queries, keys, transpose_b=True)
weights, values = filter_weight_value(weights, values,
keys_values_valid_mask)
weights = tf.nn.softmax(weights / self._attention_temperature)
features = tf.matmul(weights, values)
output_features = project_features(
features, self._output_dimension, self._is_training,
self._feature_proj, normalize=False)
output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
return output_features
def filter_weight_value(weights, values, valid_mask):
"""Filters weights and values based on valid_mask.
_NEGATIVE_PADDING_VALUE will be added to invalid elements in the weights to
avoid their contribution in softmax. 0 will be set for the invalid elements in
the values.
Args:
weights: A float Tensor of shape [batch_size, input_size, context_size].
values: A float Tensor of shape [batch_size, context_size,
projected_dimension].
valid_mask: A boolean Tensor of shape [batch_size, context_size]. True means
valid and False means invalid.
Returns:
weights: A float Tensor of shape [batch_size, input_size, context_size].
values: A float Tensor of shape [batch_size, context_size,
projected_dimension].
Raises:
ValueError: If shape of doesn't match.
"""
w_batch_size, _, w_context_size = weights.shape
v_batch_size, v_context_size, _ = values.shape
m_batch_size, m_context_size = valid_mask.shape
if w_batch_size != v_batch_size or v_batch_size != m_batch_size:
raise ValueError('Please make sure the first dimension of the input'
' tensors are the same.')
if w_context_size != v_context_size:
raise ValueError('Please make sure the third dimension of weights matches'
' the second dimension of values.')
if w_context_size != m_context_size:
raise ValueError('Please make sure the third dimension of the weights'
' matches the second dimension of the valid_mask.')
valid_mask = valid_mask[..., tf.newaxis]
# Force the invalid weights to be very negative so it won't contribute to
# the softmax.
weights += tf.transpose(
tf.cast(tf.math.logical_not(valid_mask), weights.dtype) *
_NEGATIVE_PADDING_VALUE,
perm=[0, 2, 1])
# Force the invalid values to be 0.
values *= tf.cast(valid_mask, values.dtype)
return weights, values
def project_features(features, bottleneck_dimension, is_training,
layer, normalize=True):
"""Projects features to another feature space.
Args:
features: A float Tensor of shape [batch_size, features_size,
num_features].
bottleneck_dimension: A int32 Tensor.
is_training: A boolean Tensor (affecting batch normalization).
layer: Contains a custom layer specific to the particular operation
being performed (key, value, query, features)
normalize: A boolean Tensor. If true, the output features will be l2
normalized on the last dimension.
Returns:
A float Tensor of shape [batch, features_size, projection_dimension].
"""
shape_arr = features.shape
batch_size, _, num_features = shape_arr
features = tf.reshape(features, [-1, num_features])
projected_features = layer(features, is_training)
projected_features = tf.reshape(projected_features,
[batch_size, -1, bottleneck_dimension])
if normalize:
projected_features = tf.keras.backend.l2_normalize(projected_features,
axis=-1)
return projected_features
def compute_valid_mask(num_valid_elements, num_elements):
"""Computes mask of valid entries within padded context feature.
Args:
num_valid_elements: A int32 Tensor of shape [batch_size].
num_elements: An int32 Tensor.
Returns:
A boolean Tensor of the shape [batch_size, num_elements]. True means
valid and False means invalid.
"""
batch_size = num_valid_elements.shape[0]
element_idxs = tf.range(num_elements, dtype=tf.int32)
batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
num_valid_elements = num_valid_elements[..., tf.newaxis]
valid_mask = tf.less(batch_element_idxs, num_valid_elements)
return valid_mask