tensorflow/models

View on GitHub
official/vision/serving/export_tflite_lib.py

Summary

Maintainability
B
5 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks


def create_representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
  """Builds the tf.data pipeline that supplies representative images.

  If no `task` is given, one is instantiated from `params.task`. This works
  only for the image classification, RetinaNet, Mask R-CNN and semantic
  segmentation task configs. Before the inputs are built, the training data
  config is forced to a global batch size of 1 and a `float32` dtype.

  Note: `params.task.train_data` is modified in place.

  Args:
    params: An ExperimentConfig whose `task.train_data` describes the input
      pipeline to build.
    task: An optional task instance. When None, a task is created from the
      type of `params.task`.

  Returns:
    The dataset returned by `task.build_inputs` for the adjusted training data
    config.

  Raises:
    ValueError: If `task` is None and `params.task` is not one of the
      supported task configs.
  """
  if task is None:
    # Ordered (config class, task class) pairs; the first isinstance match wins.
    known_tasks = (
        (configs.image_classification.ImageClassificationTask,
         tasks.image_classification.ImageClassificationTask),
        (configs.retinanet.RetinaNetTask, tasks.retinanet.RetinaNetTask),
        (configs.maskrcnn.MaskRCNNTask, tasks.maskrcnn.MaskRCNNTask),
        (configs.semantic_segmentation.SemanticSegmentationTask,
         tasks.semantic_segmentation.SemanticSegmentationTask),
    )
    task_cls = next(
        (task_type for config_type, task_type in known_tasks
         if isinstance(params.task, config_type)),
        None)
    if task_cls is None:
      raise ValueError('Task {} not supported.'.format(type(params.task)))
    task = task_cls(params.task)

  data_config = params.task.train_data
  # The TFLite model is fed one float32 example at a time.
  data_config.global_batch_size = 1
  data_config.dtype = 'float32'
  logging.info('Task config: %s', params.task.as_dict())
  return task.build_inputs(params=data_config)


def representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Creates representative dataset for input calibration.

  The input pipeline is built with `create_representative_dataset`. At most
  `calibration_steps` elements are drawn from it, and each kept image is
  yielded wrapped in a single-element list, the form assigned to
  `TFLiteConverter.representative_dataset` in this module.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The maximum number of dataset elements to draw for
      calibration.

  Yields:
    A single-element list containing an input image tensor whose last
    dimension is 3.
  """
  dataset = create_representative_dataset(params=params, task=task)
  for image, _ in dataset.take(calibration_steps):
    # Skip images that do not have 3 channels. Skipped elements still count
    # toward `calibration_steps`, so fewer samples may be yielded.
    if image.shape[-1] != 3:
      continue
    yield [image]


def convert_tflite_model(
    saved_model_dir: Optional[str] = None,
    concrete_function: Optional[tf.types.experimental.ConcreteFunction] = None,
    model: Optional[tf.Module] = None,
    quant_type: Optional[str] = None,
    params: Optional[cfg.ExperimentConfig] = None,
    task: Optional[base_task.Task] = None,
    calibration_steps: Optional[int] = 2000,
    denylisted_ops: Optional[List[str]] = None,
) -> 'bytes':
  """Converts and returns a TFLite model.

  The converter source is chosen in priority order: `saved_model_dir`, then
  `concrete_function`, then `model`.

  Args:
    saved_model_dir: The directory to the SavedModel.
    concrete_function: An optional concrete function to be exported.
    model: An optional tf_keras.Model instance. If both `saved_model_dir` and
      `concrete_function` are not available, convert this model to TFLite.
    quant_type: The post training quantization (PTQ) method. It can be one of:
      - `default` or `qat_fp32_io`: dynamic range quantization.
      - `fp16`: float16 quantization.
      - `int8`: integer with float fallback.
      - `int8_full`: integer only, with uint8 input/output.
      - `int8_full_int8_io`: integer only, with int8 input/output.
      - `uint8`: uint8 inference with fixed default ranges.
      - `qat`: quantization-aware-trained model with uint8 input/output.
      - None: no quantization.
    params: An ExperimentConfig used to load and preprocess input images for
      calibration. Required when `quant_type` starts with `int8`.
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The steps to do calibration.
    denylisted_ops: A list of strings containing ops that are excluded from
      integer quantization. Only used for `int8*` quantization types.

  Returns:
    A converted TFLite model with optional PTQ.

  Raises:
    ValueError: If none of `saved_model_dir`, `concrete_function` or `model`
      is provided, if integer (`int8*`) quantization is requested without
      `params`, or if `quant_type` is not supported.
  """
  if saved_model_dir:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  elif concrete_function is not None:
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function]
    )
  elif model is not None:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
  else:
    raise ValueError(
        '`saved_model_dir`, `model` or `concrete_function` must be specified.'
    )

  if quant_type:
    if quant_type.startswith('int8'):
      if params is None:
        # Calibration inputs are always read from `params.task.train_data`,
        # even when a `task` instance is supplied, so fail fast here rather
        # than deep inside the converter.
        raise ValueError(
            '`params` must be specified for integer quantization '
            f'(quant_type={quant_type}).')
      # One calibration source shared by the converter and the debugger, so
      # both use the same caller-supplied `task`.
      calibration_dataset = functools.partial(
          representative_dataset,
          params=params,
          task=task,
          calibration_steps=calibration_steps)
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.representative_dataset = calibration_dataset
      if quant_type.startswith('int8_full'):
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
      if quant_type == 'int8_full':
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
      if quant_type == 'int8_full_int8_io':
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

      if denylisted_ops:
        debug_options = tf.lite.experimental.QuantizationDebugOptions(
            denylisted_ops=denylisted_ops)
        debugger = tf.lite.experimental.QuantizationDebugger(
            converter=converter,
            debug_dataset=calibration_dataset,
            debug_options=debug_options)
        debugger.run()
        return debugger.get_nondebug_quantized_model()

    elif quant_type == 'uint8':
      # NOTE(review): `default_ranges_stats`, `inference_type` and
      # `quantized_input_stats` look like TF1-converter options; confirm they
      # take effect with the TF2 `tf.lite.TFLiteConverter` used here.
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.default_ranges_stats = (-10, 10)
      converter.inference_type = tf.uint8
      converter.quantized_input_stats = {'input_placeholder': (0., 1.)}
    elif quant_type == 'fp16':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.target_spec.supported_types = [tf.float16]
    elif quant_type in ('default', 'qat_fp32_io'):
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif quant_type == 'qat':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.inference_input_type = tf.uint8  # or tf.int8
      converter.inference_output_type = tf.uint8  # or tf.int8
    else:
      raise ValueError(f'quantization type {quant_type} is not supported.')

  return converter.convert()