# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mask R-CNN model."""

from typing import List, Mapping, Optional, Union

# Import libraries

from absl import logging
import tensorflow as tf, tf_keras

from official.vision.modeling import maskrcnn_model
from official.vision.ops import box_ops


def resize_as(source, size):
  """Resizes a [batch, channels, height, width] tensor to `size` x `size`."""
  # `tf.image.resize` expects channels-last, so transpose around the resize.
  source = tf.transpose(source, (0, 2, 3, 1))
  source = tf.image.resize(source, (size, size))
  return tf.transpose(source, (0, 3, 1, 2))
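
# Example (illustrative, not part of the model): `resize_as` only changes the
# two trailing spatial dimensions of a channels-first tensor.
#   masks = tf.zeros((2, 10, 32, 32))   # [batch, num_instances, h, w]
#   resize_as(masks, size=64).shape     # -> (2, 10, 64, 64)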


class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
  """The Mask R-CNN model."""

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: tf_keras.Model,
               rpn_head: tf_keras.layers.Layer,
               detection_head: Union[tf_keras.layers.Layer,
                                     List[tf_keras.layers.Layer]],
               roi_generator: tf_keras.layers.Layer,
               roi_sampler: Union[tf_keras.layers.Layer,
                                  List[tf_keras.layers.Layer]],
               roi_aligner: tf_keras.layers.Layer,
               detection_generator: tf_keras.layers.Layer,
               mask_head: Optional[tf_keras.layers.Layer] = None,
               mask_sampler: Optional[tf_keras.layers.Layer] = None,
               mask_roi_aligner: Optional[tf_keras.layers.Layer] = None,
               class_agnostic_bbox_pred: bool = False,
               cascade_class_ensemble: bool = False,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               outer_boxes_scale: float = 1.0,
               use_gt_boxes_for_masks: bool = False,
               **kwargs):
    """Initializes the Mask R-CNN model.

    Args:
      backbone: `tf_keras.Model`, the backbone network.
      decoder: `tf_keras.Model`, the decoder network.
      rpn_head: the RPN head.
      detection_head: the detection head or a list of heads.
      roi_generator: the ROI generator.
      roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
        detection heads.
      roi_aligner: the ROI aligner.
      detection_generator: the detection generator.
      mask_head: the mask head.
      mask_sampler: the mask sampler.
      mask_roi_aligner: the ROI aligner for mask prediction.
      class_agnostic_bbox_pred: if True, perform class agnostic bounding box
        prediction. Needs to be `True` for Cascade RCNN models.
      cascade_class_ensemble: if True, ensemble classification scores over all
        detection heads.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added on each level.
        For instance, num_scales=2 adds one additional intermediate anchor
        scale on each level, yielding the scales [2^0, 2^0.5].
      aspect_ratios: A list representing the aspect ratio anchors added on each
        level. Each number indicates the ratio of width to height. For
        instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
        scale level.
      anchor_size: A number representing the ratio of the base anchor size to
        the feature stride 2^level.
      outer_boxes_scale: a float to scale up the bounding boxes to generate
        more inclusive masks. The scale is expected to be >=1.0.
      use_gt_boxes_for_masks: bool, if True, crop mask features using
        groundtruth boxes instead of proposals when training the mask head.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=detection_head,
        roi_generator=roi_generator,
        roi_sampler=roi_sampler,
        roi_aligner=roi_aligner,
        detection_generator=detection_generator,
        mask_head=mask_head,
        mask_sampler=mask_sampler,
        mask_roi_aligner=mask_roi_aligner,
        class_agnostic_bbox_pred=class_agnostic_bbox_pred,
        cascade_class_ensemble=cascade_class_ensemble,
        min_level=min_level,
        max_level=max_level,
        num_scales=num_scales,
        aspect_ratios=aspect_ratios,
        anchor_size=anchor_size,
        outer_boxes_scale=outer_boxes_scale,
        **kwargs)

    self._config_dict['use_gt_boxes_for_masks'] = use_gt_boxes_for_masks
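
  # Example (hedged sketch; layer construction elided): the only new flag over
  # the parent `MaskRCNNModel` is `use_gt_boxes_for_masks`.
  #   model = DeepMaskRCNNModel(
  #       backbone=backbone,
  #       decoder=decoder,
  #       ...,  # remaining heads, samplers, and aligners as for MaskRCNNModel
  #       use_gt_boxes_for_masks=True)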

  def call(self,
           images: tf.Tensor,
           image_shape: tf.Tensor,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           gt_boxes: Optional[tf.Tensor] = None,
           gt_classes: Optional[tf.Tensor] = None,
           gt_masks: Optional[tf.Tensor] = None,
           gt_outer_boxes: Optional[tf.Tensor] = None,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Runs the box branch, then the mask branch when masks are enabled."""
    call_box_outputs_kwargs = {
        'images': images,
        'image_shape': image_shape,
        'anchor_boxes': anchor_boxes,
        'gt_boxes': gt_boxes,
        'gt_classes': gt_classes,
        'training': training
    }
    # `gt_outer_boxes` is only consumed when outer boxes are enabled.
    if self.outer_boxes_scale > 1.0:
      call_box_outputs_kwargs['gt_outer_boxes'] = gt_outer_boxes
    model_outputs, intermediate_outputs = self._call_box_outputs(
        **call_box_outputs_kwargs)
    if not self._include_mask:
      return model_outputs

    if self.outer_boxes_scale == 1.0:
      current_rois = intermediate_outputs['current_rois']
      matched_gt_boxes = intermediate_outputs['matched_gt_boxes']
      mask_head_gt_boxes = gt_boxes
    else:
      # Scale up the RoIs and use the matched outer groundtruth boxes so the
      # mask head sees more inclusive crops.
      current_rois = box_ops.compute_outer_boxes(
          intermediate_outputs['current_rois'],
          tf.expand_dims(image_shape, axis=1), self.outer_boxes_scale)
      matched_gt_boxes = intermediate_outputs['matched_gt_outer_boxes']
      mask_head_gt_boxes = gt_outer_boxes

    model_mask_outputs = self._call_mask_outputs(
        model_box_outputs=model_outputs,
        features=model_outputs['decoder_features'],
        current_rois=current_rois,
        matched_gt_indices=intermediate_outputs['matched_gt_indices'],
        matched_gt_boxes=matched_gt_boxes,
        matched_gt_classes=intermediate_outputs['matched_gt_classes'],
        gt_masks=gt_masks,
        gt_classes=gt_classes,
        gt_boxes=mask_head_gt_boxes,
        training=training)
    model_outputs.update(model_mask_outputs)
    return model_outputs
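
  # Example (hedged sketch): a forward pass at training time; the tensors
  # named here stand in for what the input pipeline provides.
  #   outputs = model(
  #       images=images, image_shape=image_shape, anchor_boxes=anchor_boxes,
  #       gt_boxes=gt_boxes, gt_classes=gt_classes, gt_masks=gt_masks,
  #       training=True)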

  def call_images_and_boxes(self, images, boxes):
    """Predict masks given an image and bounding boxes."""

    _, decoder_features = self._get_backbone_and_decoder_features(images)
    boxes_shape = tf.shape(boxes)
    batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
    # Use a dummy class id of 0 for every box; the mask head uses these ids to
    # select its per-class mask output.
    classes = tf.zeros((batch_size, num_boxes), dtype=tf.int32)

    _, mask_probs = self._features_to_mask_outputs(
        decoder_features, boxes, classes)
    return {
        'detection_masks': mask_probs
    }
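
  # Example (hedged sketch): mask-only inference for user-supplied boxes. The
  # expected box coordinate convention is an assumption here (typically pixel
  # coordinates in the padded input image).
  #   images = tf.zeros((1, 640, 640, 3))
  #   boxes = tf.constant([[[100., 100., 300., 300.]]])  # [batch, boxes, 4]
  #   out = model.call_images_and_boxes(images, boxes)['detection_masks']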

  def _call_mask_outputs(
      self,
      model_box_outputs: Mapping[str, tf.Tensor],
      features: tf.Tensor,
      current_rois: tf.Tensor,
      matched_gt_indices: tf.Tensor,
      matched_gt_boxes: tf.Tensor,
      matched_gt_classes: tf.Tensor,
      gt_masks: tf.Tensor,
      gt_classes: tf.Tensor,
      gt_boxes: tf.Tensor,
      training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Computes mask targets (training) and mask outputs for the mask branch."""
    model_outputs = dict(model_box_outputs)
    if training:
      if self._config_dict['use_gt_boxes_for_masks']:
        # Resize groundtruth masks to the mask head's output resolution so
        # they can serve directly as training targets.
        mask_size = (
            self.mask_roi_aligner._config_dict['crop_size'] *  # pylint:disable=protected-access
            self.mask_head._config_dict['upsample_factor']  # pylint:disable=protected-access
        )
        gt_masks = resize_as(source=gt_masks, size=mask_size)

        logging.info('Using GT class and mask targets.')
        model_outputs.update({
            'mask_class_targets': gt_classes,
            'mask_targets': gt_masks,
        })
      else:
        rois, roi_classes, roi_masks = self.mask_sampler(
            current_rois, matched_gt_boxes, matched_gt_classes,
            matched_gt_indices, gt_masks)
        # The sampled masks are training targets, not activations, so block
        # gradients through them.
        roi_masks = tf.stop_gradient(roi_masks)
        model_outputs.update({
            'mask_class_targets': roi_classes,
            'mask_targets': roi_masks,
        })
    else:
      # At inference time, crop masks from the final detection boxes (outer
      # boxes when enabled).
      if self.outer_boxes_scale == 1.0:
        rois = model_outputs['detection_boxes']
      else:
        rois = model_outputs['detection_outer_boxes']
      roi_classes = model_outputs['detection_classes']

    # Mask RoI align.
    if training and self._config_dict['use_gt_boxes_for_masks']:
      logging.info('Using GT mask roi features.')
      roi_aligner_boxes = gt_boxes
      mask_head_classes = gt_classes

    else:
      roi_aligner_boxes = rois
      mask_head_classes = roi_classes

    mask_logits, mask_probs = self._features_to_mask_outputs(
        features, roi_aligner_boxes, mask_head_classes)

    if training:
      model_outputs.update({
          'mask_outputs': mask_logits,
      })
    else:
      model_outputs.update({
          'detection_masks': mask_probs,
      })
    return model_outputs