tensorflow/models

View on GitHub
official/projects/mosaic/mosaic_tasks.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Task definition for image semantic segmentation with MOSAIC models."""

from absl import logging
import tensorflow as tf, tf_keras

from official.core import task_factory
from official.projects.mosaic.configs import mosaic_config
from official.projects.mosaic.modeling import mosaic_model
from official.vision.tasks import semantic_segmentation as seg_tasks


@task_factory.register_task_cls(mosaic_config.MosaicSemanticSegmentationTask)
class MosaicSemanticSegmentationTask(seg_tasks.SemanticSegmentationTask):
  """A task for semantic segmentation using MOSAIC model."""

  # Note: `build_model` is overridden to add an additional `training` flag
  # that indicates whether the model is built for training or for eval. This
  # makes sure the model is initialized with the proper `input_shape` when it
  # is trained and evaluated at different input shapes; for example, when the
  # model is trained with cropping but evaluated at the original shape.
  def build_model(self, training: bool = True) -> tf_keras.Model:
    """Builds MOSAIC segmentation model.

    Args:
      training: If True, the model is called once on a dummy input sized by
        `train_data` (`crop_size` when set, else `output_size`); otherwise
        the dummy input is sized by `validation_data.output_size`.

    Returns:
      The MOSAIC segmentation model, already called once on a dummy input so
      that all of its layers are built.
    """
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + list(self.task_config.model.input_size))

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = mosaic_model.build_mosaic_segmentation_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    # Note: Create a dummy input and call the model instance to initialize it.
    # This ensures all the layers are built; otherwise some layers may be
    # missing from the model and cannot be associated with variables from a
    # loaded checkpoint.
    dummy_input = tf.ones(shape=self._dummy_input_shape(training))
    model(dummy_input)

    return model

  def _dummy_input_shape(self, training: bool) -> list:
    """Returns the dummy-input shape used to build the model's layers.

    The shape is a batch of 1, followed by the spatial size taken from the
    data config, followed by the channel count.

    Args:
      training: Whether to size the input for training (`train_data.crop_size`
        when set, else `train_data.output_size`) or for eval
        (`validation_data.output_size`).
    """
    if training:
      train_data = self.task_config.train_data
      # When training with cropping, the model sees crop-sized inputs.
      spatial_size = train_data.crop_size or train_data.output_size
    else:
      spatial_size = self.task_config.validation_data.output_size

    model_input_size = self.task_config.model.input_size
    # Use the configured channel count when `model.input_size` has three
    # entries; otherwise fall back to 3 channels.
    input_channel = model_input_size[-1] if len(model_input_size) == 3 else 3
    return [1] + list(spatial_size) + [input_channel]

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `task_config.init_checkpoint` is empty. If it names a
    directory, the latest checkpoint in that directory is used. If 'all' is in
    `task_config.init_checkpoint_modules`, every entry of
    `model.checkpoint_items` is restored. Otherwise only the listed 'backbone'
    and/or 'neck' sub-modules are restored.

    Args:
      model: The model whose variables are restored from the checkpoint.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    init_modules = self.task_config.init_checkpoint_modules
    if 'all' in init_modules:
      ckpt_items = model.checkpoint_items
    else:
      ckpt_items = {}
      if 'backbone' in init_modules:
        ckpt_items['backbone'] = model.backbone
      if 'neck' in init_modules:
        ckpt_items['neck'] = model.neck
      if not ckpt_items:
        # Without this warning, an unsupported selection (e.g. 'decoder')
        # restores nothing yet still logs a successful load below.
        logging.warning(
            'init_checkpoint_modules=%s selects none of "all", "backbone" or '
            '"neck"; no weights will be restored from %s.', init_modules,
            ckpt_dir_or_file)

    # Restoring checkpoint. `expect_partial` silences warnings about checkpoint
    # values that are not consumed, while `assert_existing_objects_matched`
    # checks that the requested objects were matched in the checkpoint.
    ckpt = tf.train.Checkpoint(**ckpt_items)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)