# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of 3D UNet Model decoder part.

[1] Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf
Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
Annotation. arXiv:1606.06650.
"""

from typing import Any, Dict, Mapping, Optional, Sequence

import tensorflow as tf
import tf_keras

from official.modeling import hyperparams
from official.projects.volumetric_models.modeling import nn_blocks_3d
from official.projects.volumetric_models.modeling.decoders import factory

layers = tf_keras.layers


@tf_keras.utils.register_keras_serializable(package='Vision')
class UNet3DDecoder(tf_keras.Model):
  """Class to build 3D UNet decoder."""

  def __init__(self,
               model_id: int,
               input_specs: Mapping[str, tf.TensorShape],
               pool_size: Sequence[int] = (2, 2, 2),
               kernel_size: Sequence[int] = (3, 3, 3),
               kernel_regularizer: Optional[
                   tf_keras.regularizers.Regularizer] = None,
               activation: str = 'relu',
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               use_sync_bn: bool = False,
               use_batch_normalization: bool = False,
               use_deconvolution: bool = False,
               **kwargs):
    """3D UNet decoder initialization function.

    Args:
      model_id: The depth of UNet3D backbone model. The greater the depth, the
        more max pooling layers will be added to the model. Lowering the depth
        may reduce the amount of memory required for training.
      input_specs: The input specifications, a dictionary of {level:
        TensorShape} pairs from a backbone.
      pool_size: The pooling size for the max pooling operations.
      kernel_size: The kernel size for 3D convolution.
      kernel_regularizer: A tf_keras.regularizers.Regularizer object for
        Conv3D. Defaults to None.
      activation: The name of the activation function.
      norm_momentum: The normalization momentum for the moving average.
      norm_epsilon: A float added to variance to avoid dividing by zero.
      use_sync_bn: If True, use synchronized batch normalization.
      use_batch_normalization: If set to True, use batch normalization after
        convolution and before activation. Defaults to False.
      use_deconvolution: If set to True, the model will use transpose
        convolution (deconvolution) instead of up-sampling. This increases the
        amount of memory required during training. Defaults to False.
      **kwargs: Keyword arguments to be passed.
    """
    self._config_dict = {
        'model_id': model_id,
        'input_specs': input_specs,
        'pool_size': pool_size,
        'kernel_size': kernel_size,
        'kernel_regularizer': kernel_regularizer,
        'activation': activation,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'use_sync_bn': use_sync_bn,
        'use_batch_normalization': use_batch_normalization,
        'use_deconvolution': use_deconvolution
    }
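    # SyncBatchNormalization computes batch statistics across replicas, which
    # helps when per-replica batch sizes are small, as is common when training
    # on large 3D volumes.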
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._use_batch_normalization = use_batch_normalization

    if tf_keras.backend.image_data_format() == 'channels_last':
      channel_dim = -1
    else:
      channel_dim = 1

    # Build 3D UNet.
    inputs = self._build_input_pyramid(input_specs, model_id)

    # Add levels with up-convolution or up-sampling.
    x = inputs[str(model_id)]
    for layer_depth in range(model_id - 1, 0, -1):
      # Apply deconvolution or upsampling.
      if use_deconvolution:
        x = layers.Conv3DTranspose(
            filters=x.get_shape().as_list()[channel_dim],
            kernel_size=pool_size,
            strides=(2, 2, 2))(x)
      else:
        x = layers.UpSampling3D(size=pool_size)(x)

      # Concatenate upsampled features with input features from one layer up.
      x = tf.concat([x, tf.cast(inputs[str(layer_depth)], dtype=x.dtype)],
                    axis=channel_dim)
      filter_num = inputs[str(layer_depth)].get_shape().as_list()[channel_dim]
      x = nn_blocks_3d.BasicBlock3DVolume(
          filters=[filter_num, filter_num],
          strides=(1, 1, 1),
          kernel_size=kernel_size,
          kernel_regularizer=kernel_regularizer,
          activation=activation,
          use_sync_bn=use_sync_bn,
          norm_momentum=norm_momentum,
          norm_epsilon=norm_epsilon,
          use_batch_normalization=use_batch_normalization)(x)

    feats = {'1': x}
    self._output_specs = {level: feats[level].get_shape() for level in feats}

    super().__init__(inputs=inputs, outputs=feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
                           depth: int) -> Dict[str, tf.Tensor]:
    """Builds input pyramid features."""
    assert isinstance(input_specs, dict)
    # The decoder consumes one backbone feature level per depth step, so the
    # backbone must provide exactly `depth` levels.
    if len(input_specs) != depth:
      raise ValueError(
          "Backbone depth should be equal to 3D UNet decoder's depth.")

    inputs = {}
    for level, spec in input_specs.items():
      inputs[level] = tf_keras.Input(shape=spec[1:])
    return inputs

  def get_config(self) -> Mapping[str, Any]:
    """Returns the config of this decoder for serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config: Mapping[str, Any], custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_decoder_builder('unet_3d_decoder')
def build_unet_3d_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds UNet3D decoder from a config.

  Args:
    input_specs: A `dict` of input specifications: {level: TensorShape} pairs
      from a backbone.
    model_config: A `OneOfConfig` instance specifying the model configuration,
      including the decoder.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` instance. Defaults
      to None.

  Returns:
    A `tf_keras.Model` instance of the UNet3D decoder.
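
  Example (a hedged sketch; `backbone` and `model_config` stand in for an
  assumed volumetric segmentation backbone and experiment config whose
  `decoder.type` is 'unet_3d_decoder'):

  >>> decoder = build_unet_3d_decoder(
  ...     input_specs=backbone.output_specs,
  ...     model_config=model_config,
  ...     l2_regularizer=tf_keras.regularizers.L2(1e-4))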
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  assert decoder_type == 'unet_3d_decoder', (
      f'Inconsistent decoder type {decoder_type}')
  norm_activation_config = model_config.norm_activation
  return UNet3DDecoder(
      model_id=decoder_cfg.model_id,
      input_specs=input_specs,
      pool_size=decoder_cfg.pool_size,
      kernel_regularizer=l2_regularizer,
      activation=norm_activation_config.activation,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_batch_normalization=decoder_cfg.use_batch_normalization,
      use_deconvolution=decoder_cfg.use_deconvolution)