# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory methods to build models."""
# Import libraries

import tensorflow as tf
import tf_keras

import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model
from official.projects.qat.vision.modeling.heads import dense_prediction_heads as dense_prediction_heads_qat
from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
from official.projects.qat.vision.n_bit import schemes as n_bit_schemes
from official.projects.qat.vision.quantization import configs as qat_configs
from official.projects.qat.vision.quantization import helper
from official.projects.qat.vision.quantization import schemes
from official.vision import configs
from official.vision.modeling import classification_model
from official.vision.modeling import retinanet_model
from official.vision.modeling.decoders import aspp
from official.vision.modeling.decoders import fpn
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import segmentation_heads
from official.vision.modeling.layers import nn_layers


def build_qat_classification_model(
    model: tf_keras.Model,
    quantization: common.Quantization,
    input_specs: tf_keras.layers.InputSpec,
    model_config: configs.image_classification.ImageClassificationModel,
    l2_regularizer: tf_keras.regularizers.Regularizer = None
) -> tf_keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Apply model optimization techniques.

  Args:
    model: The model applying model optimization techniques.
    quantization: The Quantization config.
    input_specs: `tf_keras.layers.InputSpec` specs of the input tensor.
    model_config: The model config.
    l2_regularizer: tf_keras.regularizers.Regularizer object. Default to None.

  Returns:
    model: The model that applied optimization techniques.
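
  Example:
    A minimal sketch of typical usage; `base_model`, `quantization_config`,
    and `classification_config` are illustrative names and are assumed to be
    built by the vision task factories::

      qat_model = build_qat_classification_model(
          model=base_model,
          quantization=quantization_config,
          input_specs=tf_keras.layers.InputSpec(shape=[None, 224, 224, 3]),
          model_config=classification_config)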
  """
  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint:
    ckpt = tf.train.Checkpoint(
        model=model,
        **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

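  # Register custom objects with `quantize_scope` so TFMOT can deserialize
  # them while cloning and annotating the model.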
  scope_dict = {
      'L2': tf_keras.regularizers.l2,
  }
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
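    # `change_num_bits` selects a custom n-bit quantization scheme; otherwise
    # the default TFMOT 8-bit scheme is applied to the backbone.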
    if quantization.change_num_bits:
      backbone = tfmot.quantization.keras.quantize_apply(
          annotated_backbone,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))
    else:
      backbone = tfmot.quantization.keras.quantize_apply(
          annotated_backbone,
          scheme=schemes.Default8BitQuantizeScheme())

  norm_activation_config = model_config.norm_activation
  backbone_optimized_model = classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon)
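  # The rebuilt model shares the quantized backbone, so only the remaining
  # (non-backbone) layers need their weights copied from the original model.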
  for from_layer, to_layer in zip(
      model.layers, backbone_optimized_model.layers):
    if from_layer != model.backbone:
      to_layer.set_weights(from_layer.get_weights())

  with tfmot.quantization.keras.quantize_scope(scope_dict):
    def apply_quantization_to_dense(layer):
      if isinstance(layer, (tf_keras.layers.Dense,
                            tf_keras.layers.Dropout,
                            tf_keras.layers.GlobalAveragePooling2D)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

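    # Ensures the old Keras serialization format, as in the segmentation
    # builder below.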
    backbone_optimized_model.use_legacy_config = True
    annotated_model = tf_keras.models.clone_model(
        backbone_optimized_model,
        clone_function=apply_quantization_to_dense,
    )

    annotated_model.use_legacy_config = True
    if quantization.change_num_bits:
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))

    else:
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model)

  return optimized_model


def _clone_function_for_fpn(layer):
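  """Annotates FPN layers for quantization.

  Batch normalization layers are wrapped in `BatchNormalizationWrapper` so
  their quantized output range is tracked explicitly; `UpSampling2D` layers
  are left unquantized.
  """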
  if isinstance(layer, (
      tf_keras.layers.BatchNormalization,
      tf_keras.layers.experimental.SyncBatchNormalization)):
    return tfmot.quantization.keras.quantize_annotate_layer(
        qat_nn_layers.BatchNormalizationWrapper(layer),
        qat_configs.Default8BitOutputQuantizeConfig())
  if isinstance(layer, tf_keras.layers.UpSampling2D):
    return layer
  return tfmot.quantization.keras.quantize_annotate_layer(layer)


def build_qat_retinanet(
    model: tf_keras.Model, quantization: common.Quantization,
    model_config: configs.retinanet.RetinaNet) -> tf_keras.Model:
  """Applies quantization aware training for RetinaNet model.

  Args:
    model: The model applying quantization aware training.
    quantization: The Quantization config.
    model_config: The model config.

  Returns:
    The model that applied optimization techniques.
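
  Example:
    A sketch with illustrative names; `retinanet` and the configs are assumed
    to come from the standard detection factory::

      qat_model = build_qat_retinanet(
          model=retinanet,
          quantization=quantization_config,
          model_config=retinanet_config)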
  """

  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint is not None:
    ckpt = tf.train.Checkpoint(
        model=model,
        **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  scope_dict = {
      'L2': tf_keras.regularizers.l2,
      'BatchNormalizationWrapper': qat_nn_layers.BatchNormalizationWrapper,
  }
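  # `BatchNormalizationWrapper` is registered because `_clone_function_for_fpn`
  # wraps batch norm layers with it when the decoder is cloned below.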
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    optimized_backbone = tfmot.quantization.keras.quantize_apply(
        annotated_backbone,
        scheme=schemes.Default8BitQuantizeScheme())
    decoder = model.decoder
    if quantization.quantize_detection_decoder:
      if not isinstance(decoder, fpn.FPN):
        raise ValueError('Currently only supports FPN.')

      decoder = tf_keras.models.clone_model(
          decoder,
          clone_function=_clone_function_for_fpn,
      )
      decoder = tfmot.quantization.keras.quantize_apply(decoder)
      decoder = tfmot.quantization.keras.remove_input_range(decoder)

    head = model.head
    if quantization.quantize_detection_head:
      if not isinstance(head, dense_prediction_heads.RetinaNetHead):
        raise ValueError('Currently only supports RetinaNetHead.')
      head = (
          dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config(
              head.get_config()))

  optimized_model = retinanet_model.RetinaNetModel(
      backbone=optimized_backbone,
      decoder=decoder,
      head=head,
      detection_generator=model.detection_generator,
      anchor_boxes=model.anchor_boxes,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)

  if quantization.quantize_detection_head:
    # Call the model with dummy input to build the head part.
    dummy_input = tf.zeros([1] + model_config.input_size)
    height, width, _ = model_config.input_size
    image_shape = [[height, width]]
    optimized_model.call(dummy_input, image_shape=image_shape, training=False)
    helper.copy_original_weights(model.head, optimized_model.head)
  return optimized_model


def build_qat_segmentation_model(
    model: tf_keras.Model, quantization: common.Quantization,
    input_specs: tf_keras.layers.InputSpec) -> tf_keras.Model:
  """Applies quantization aware training for segmentation model.

  Args:
    model: The model applying quantization aware training.
    quantization: The Quantization config.
    input_specs: The shape specifications of input tensor.

  Returns:
    The model that applied optimization techniques.
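
  Example:
    A sketch with illustrative names; `seg_model` is assumed to be a built
    segmentation model exposing `backbone`, `decoder`, and `head`::

      qat_model = build_qat_segmentation_model(
          model=seg_model,
          quantization=quantization_config,
          input_specs=tf_keras.layers.InputSpec(shape=[None, 512, 512, 3]))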
  """

  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint is not None:
    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  # Build quantization compatible model.
  model = qat_segmentation_model.SegmentationModelQuantized(
      model.backbone, model.decoder, model.head, input_specs)

  scope_dict = {
      'L2': tf_keras.regularizers.l2,
  }

  model.use_legacy_config = True  # Ensures old Keras serialization format
  # Apply QAT to backbone (a tf_keras.Model) first.
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    optimized_backbone = tfmot.quantization.keras.quantize_apply(
        annotated_backbone, scheme=schemes.Default8BitQuantizeScheme())
  backbone_optimized_model = qat_segmentation_model.SegmentationModelQuantized(
      optimized_backbone, model.decoder, model.head, input_specs)

  # Copy over all remaining layers.
  for from_layer, to_layer in zip(model.layers,
                                  backbone_optimized_model.layers):
    if from_layer != model.backbone:
      to_layer.set_weights(from_layer.get_weights())

  with tfmot.quantization.keras.quantize_scope(scope_dict):

    def apply_quantization_to_layers(layer):
      if isinstance(layer, (segmentation_heads.SegmentationHead,
                            nn_layers.SpatialPyramidPooling, aspp.ASPP)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

    backbone_optimized_model.use_legacy_config = True
    annotated_model = tf_keras.models.clone_model(
        backbone_optimized_model,
        clone_function=apply_quantization_to_layers,
    )
    annotated_model.use_legacy_config = True
    optimized_model = tfmot.quantization.keras.quantize_apply(
        annotated_model, scheme=schemes.Default8BitQuantizeScheme())

  return optimized_model