official/projects/panoptic/modeling/panoptic_deeplab_model.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build Panoptic Deeplab model."""
from typing import Any, Mapping, Optional, Union

import tensorflow as tf
import tf_keras
from official.projects.panoptic.modeling.layers import panoptic_deeplab_merge


@tf_keras.utils.register_keras_serializable(package='Vision')
class PanopticDeeplabModel(tf_keras.Model):
  """Panoptic Deeplab model."""

  def __init__(
      self,
      backbone: tf_keras.Model,
      semantic_decoder: tf_keras.Model,
      semantic_head: tf_keras.layers.Layer,
      instance_head: tf_keras.layers.Layer,
      instance_decoder: Optional[tf_keras.Model] = None,
      post_processor: Optional[panoptic_deeplab_merge.PostProcessor] = None,
      **kwargs):
    """Panoptic deeplab model initializer.

    Args:
      backbone: a backbone network.
      semantic_decoder: a decoder network for semantic segmentation, e.g. FPN.
      semantic_head: a semantic segmentation head.
      instance_head: an instance center head that predicts center heatmaps and
        center offsets.
      instance_decoder: Optional decoder network for instance predictions. If
        not provided, the semantic decoder features are shared with the
        instance head.
      post_processor: Optional post-processing layer that merges the semantic
        and instance predictions into panoptic outputs.
      **kwargs: keyword arguments to be passed.
    """
    super(PanopticDeeplabModel, self).__init__(**kwargs)

    self._config_dict = {
        'backbone': backbone,
        'semantic_decoder': semantic_decoder,
        'instance_decoder': instance_decoder,
        'semantic_head': semantic_head,
        'instance_head': instance_head,
        'post_processor': post_processor
    }
    self.backbone = backbone
    self.semantic_decoder = semantic_decoder
    self.instance_decoder = instance_decoder
    self.semantic_head = semantic_head
    self.instance_head = instance_head
    self.post_processor = post_processor

  def call(  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      self, inputs: tf.Tensor,
      image_info: tf.Tensor,
      training: Optional[bool] = None):
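    """Forward pass of the Panoptic Deeplab model.

    Args:
      inputs: a `tf.Tensor` of the input images.
      image_info: a `tf.Tensor` carrying per-image resize information; it is
        forwarded to the post-processor at inference time.
      training: whether the model runs in training mode. If None, the Keras
        learning phase is used.

    Returns:
      A dict containing `segmentation_outputs`, `instance_centers_heatmap` and
      `instance_centers_offset`. At inference time, post-processed panoptic
      outputs are added when a `post_processor` is configured.
    """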
    if training is None:
      training = tf_keras.backend.learning_phase()

    backbone_features = self.backbone(inputs, training=training)

    semantic_features = self.semantic_decoder(
        backbone_features, training=training)

    if self.instance_decoder is None:
      # Without a dedicated instance decoder, the semantic decoder features
      # are shared with the instance head.
      instance_features = semantic_features
    else:
      instance_features = self.instance_decoder(
          backbone_features, training=training)

    segmentation_outputs = self.semantic_head(
        (backbone_features, semantic_features),
        training=training)
    instance_outputs = self.instance_head(
        (backbone_features, instance_features),
        training=training)

    outputs = {
        'segmentation_outputs': segmentation_outputs,
        'instance_centers_heatmap':
            instance_outputs['instance_centers_heatmap'],
        'instance_centers_offset':
            instance_outputs['instance_centers_offset'],
    }
    if training:
      return outputs

    if self.post_processor is not None:
      panoptic_masks = self.post_processor(outputs, image_info)
      outputs.update(panoptic_masks)
    return outputs

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(
        backbone=self.backbone,
        semantic_decoder=self.semantic_decoder,
        semantic_head=self.semantic_head,
        instance_head=self.instance_head)
    if self.instance_decoder is not None:
      items.update(instance_decoder=self.instance_decoder)

    return items

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
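

# Illustrative usage sketch (not part of the library): it shows how the model
# is typically wired together and called. `my_backbone`, `my_semantic_decoder`,
# `my_semantic_head` and `my_instance_head` are hypothetical placeholders; in
# the Model Garden these components are usually built from experiment configs
# by factory functions rather than instantiated by hand.
#
#   model = PanopticDeeplabModel(
#       backbone=my_backbone,
#       semantic_decoder=my_semantic_decoder,
#       semantic_head=my_semantic_head,
#       instance_head=my_instance_head)
#   images = tf.zeros([1, 512, 512, 3], dtype=tf.float32)
#   image_info = ...  # per-image resize metadata consumed by the post-processor
#   outputs = model(images, image_info, training=False)
#   # `segmentation_outputs`, `instance_centers_heatmap` and
#   # `instance_centers_offset` are always returned; panoptic outputs are added
#   # at inference time when a `post_processor` is configured.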