# research/object_detection/predictors/mask_rcnn_keras_box_predictor.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mask R-CNN Box Predictor."""
from object_detection.core import box_predictor
# Module-level aliases for the prediction keys defined in `box_predictor`, so
# that code in this module can refer to them without the module prefix.
BOX_ENCODINGS = box_predictor.BOX_ENCODINGS
MASK_PREDICTIONS = box_predictor.MASK_PREDICTIONS
CLASS_PREDICTIONS_WITH_BACKGROUND = (
    box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND
)
class MaskRCNNKerasBoxPredictor(box_predictor.KerasBoxPredictor):
  """Keras box predictor for the later stages of a Mask R-CNN detector.

  Reference: He, K., Gkioxari, G., Dollar, P., & Girshick, R. (2017).
  Mask R-CNN. arXiv preprint arXiv:1703.06870.

  The predictor is meant for the second stage of Mask R-CNN, where proposals
  cropped from an image are laid out along the batch dimension of the input
  image_features tensor. Locations are *not* shared across classes: each
  anchor produces a separate box prediction for every class.

  On top of boxes and classes, additional "third stage" heads may be supplied
  to predict masks and/or keypoints inside detection boxes.
  """

  def __init__(self,
               is_training,
               num_classes,
               freeze_batchnorm,
               box_prediction_head,
               class_prediction_head,
               third_stage_heads,
               name=None):
    """Stores the prediction heads and initializes the Keras base predictor.

    Args:
      is_training: Indicates whether the BoxPredictor is in training mode.
      num_classes: Number of classes, *excluding* the background category.
        If groundtruth labels take values in {0, 1, .., K-1} then
        num_classes=K (and not K+1, even though the assigned classification
        targets can range from {0,... K}).
      freeze_batchnorm: Whether to freeze batch norm parameters during
        training or not. With a small training batch size (e.g. 1) it is
        desirable to freeze batch norm updates and rely on pretrained batch
        norm params.
      box_prediction_head: Callable head producing the second stage box
        predictions.
      class_prediction_head: Callable head producing the second stage class
        predictions.
      third_stage_heads: Dictionary mapping head names to the Mask R-CNN
        heads that are run in the third prediction stage.
      name: Optional string name scope for the model. When `None`, Keras
        auto-generates one from the class name.
    """
    # In-place batch norm updates are always disabled for this predictor.
    super(MaskRCNNKerasBoxPredictor, self).__init__(
        is_training,
        num_classes,
        freeze_batchnorm=freeze_batchnorm,
        inplace_batchnorm_update=False,
        name=name)
    self._box_prediction_head = box_prediction_head
    self._class_prediction_head = class_prediction_head
    self._third_stage_heads = third_stage_heads

  @property
  def num_classes(self):
    """Number of classes handled by this predictor (background excluded)."""
    return self._num_classes

  def get_second_stage_prediction_heads(self):
    """Returns the keys of the predictions produced in the second stage."""
    return (BOX_ENCODINGS, CLASS_PREDICTIONS_WITH_BACKGROUND)

  def get_third_stage_prediction_heads(self):
    """Returns the third stage head names as a sorted list."""
    head_names = list(self._third_stage_heads.keys())
    head_names.sort()
    return head_names

  def _predict(self,
               image_features,
               prediction_stage=2,
               **kwargs):
    """Runs the heads that belong to the requested prediction stage.

    Stage 2 runs the box and class heads; stage 3 runs every third stage
    head, visiting them in sorted name order.

    Args:
      image_features: A list holding exactly one float tensor of shape
        [batch_size, height_i, width_i, channels_i] with the roi pooled
        features. Any other list length raises a ValueError.
      prediction_stage: Prediction stage. Acceptable values are 2 and 3.
      **kwargs: Unused keyword args.

    Returns:
      A dictionary with the predicted tensors of the requested stage. The
      shapes below are those documented for the heads; this method returns
      whatever the heads produce without reshaping.
        Stage 2 keys:
          BOX_ENCODINGS: A float tensor of shape
            [batch_size, 1, num_classes, code_size] representing the
            location of the objects.
          CLASS_PREDICTIONS_WITH_BACKGROUND: A float tensor of shape
            [batch_size, 1, num_classes + 1] representing the class
            predictions for the proposals.
        Stage 3 keys: one entry per name in the `third_stage_heads`
          dictionary, e.g.
          MASK_PREDICTIONS: A float tensor of shape
            [batch_size, 1, num_classes, image_height, image_width]

    Raises:
      ValueError: If len(image_features) is not 1.
      ValueError: If prediction_stage is not 2 or 3.
    """
    num_feature_maps = len(image_features)
    if num_feature_maps != 1:
      raise ValueError('length of `image_features` must be 1. Found {}'.format(
          num_feature_maps))
    roi_features = image_features[0]

    if prediction_stage == 2:
      # The box head is invoked before the class head.
      box_encodings = self._box_prediction_head(roi_features)
      class_predictions = self._class_prediction_head(roi_features)
      return {
          BOX_ENCODINGS: box_encodings,
          CLASS_PREDICTIONS_WITH_BACKGROUND: class_predictions,
      }

    if prediction_stage == 3:
      return {
          head_name: self._third_stage_heads[head_name](roi_features)
          for head_name in self.get_third_stage_prediction_heads()
      }

    raise ValueError('prediction_stage should be either 2 or 3.')