tensorflow/models

View on GitHub
research/slim/nets/post_training_quantization.py

Summary

Maintainability
A
3 hrs
Test Coverage
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export quantized tflite model from a trained checkpoint."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
from absl import app
from absl import flags
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from nets import nets_factory
from preprocessing import preprocessing_factory

flags.DEFINE_string("model_name", None,
                    "The name of the architecture to quantize.")
flags.DEFINE_string("checkpoint_path", None, "Path to the training checkpoint.")
flags.DEFINE_string("dataset_name", "imagenet2012",
                    "Name of the dataset to use for quantization calibration.")
flags.DEFINE_string("dataset_dir", None, "Dataset location.")
flags.DEFINE_string(
    "dataset_split", "train",
    "The dataset split (train, validation etc.) to use for calibration.")
flags.DEFINE_string("output_tflite", None, "Path to output tflite file.")
flags.DEFINE_boolean(
    "use_model_specific_preprocessing", False,
    "When true, uses the preprocessing corresponding to the model as specified "
    "in preprocessing factory.")
flags.DEFINE_boolean("enable_ema", True,
                     "Load exponential moving average version of variables.")
flags.DEFINE_integer(
    "num_steps", 1000,
    "Number of post-training quantization calibration steps to run.")
flags.DEFINE_integer("image_size", 224, "Size of the input image.")
flags.DEFINE_integer("num_classes", 1001,
                     "Number of output classes for the model.")

FLAGS = flags.FLAGS

# Mean and standard deviation used for normalizing the image tensor.
_MEAN_RGB = 127.5
_STD_RGB = 127.5


def _preprocess_for_quantization(image_data, image_size, crop_padding=32):
  """Crops to center of image with padding then scales, normalizes image_size.

  Args:
    image_data: A 3D Tensor representing the RGB image data. Image can be of
      arbitrary height and width.
    image_size: image height/width dimension.
    crop_padding: the padding size to use when centering the crop.

  Returns:
    A decoded and cropped image Tensor. Image is normalized to [-1,1].

  """

  shape = tf.shape(image_data)
  image_height = shape[0]
  image_width = shape[1]

  padded_center_crop_size = tf.cast(
      (image_size * 1.0 / (image_size + crop_padding)) *
      tf.cast(tf.minimum(image_height, image_width), tf.float32), tf.int32)

  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2

  image = tf.image.crop_to_bounding_box(
      image_data,
      offset_height=offset_height,
      offset_width=offset_width,
      target_height=padded_center_crop_size,
      target_width=padded_center_crop_size)

  image = tf.image.resize([image], [image_size, image_size],
                          method=tf.image.ResizeMethod.BICUBIC)[0]
  image = tf.cast(image, tf.float32)
  image -= tf.constant(_MEAN_RGB)
  image /= tf.constant(_STD_RGB)
  return image


def restore_model(sess, checkpoint_path, enable_ema=True):
  """Restore variables from the checkpoint into the provided session.

  Args:
    sess: A tensorflow session where the checkpoint will be loaded.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) Whether to load the exponential moving average (ema)
      version of the tensorflow variables. Defaults to True.
  """
  if enable_ema:
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    ema_vars = tf.trainable_variables() + tf.get_collection("moving_vars")
    for v in tf.global_variables():
      if "moving_mean" in v.name or "moving_variance" in v.name:
        ema_vars.append(v)
    ema_vars = list(set(ema_vars))
    var_dict = ema.variables_to_restore(ema_vars)
  else:
    var_dict = None

  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(var_dict, max_to_keep=1)
  saver.restore(sess, checkpoint_path)


def _representative_dataset_gen():
  """Gets a python generator of numpy arrays for the given dataset."""
  image_size = FLAGS.image_size
  dataset = tfds.builder(FLAGS.dataset_name, data_dir=FLAGS.dataset_dir)
  dataset.download_and_prepare()
  data = dataset.as_dataset()[FLAGS.dataset_split]
  iterator = tf.data.make_one_shot_iterator(data)
  if FLAGS.use_model_specific_preprocessing:
    preprocess_fn = functools.partial(
        preprocessing_factory.get_preprocessing(name=FLAGS.model_name),
        output_height=image_size,
        output_width=image_size)
  else:
    preprocess_fn = functools.partial(
        _preprocess_for_quantization, image_size=image_size)
  features = iterator.get_next()
  image = features["image"]
  image = preprocess_fn(image)
  image = tf.reshape(image, [1, image_size, image_size, 3])
  for _ in range(FLAGS.num_steps):
    yield [image.eval()]


def main(_):
  with tf.Graph().as_default(), tf.Session() as sess:
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name, num_classes=FLAGS.num_classes, is_training=False)
    image_size = FLAGS.image_size
    images = tf.placeholder(
        tf.float32, shape=(1, image_size, image_size, 3), name="images")

    logits, _ = network_fn(images)

    output_tensor = tf.nn.softmax(logits)
    restore_model(sess, FLAGS.checkpoint_path, enable_ema=FLAGS.enable_ema)

    converter = tf.lite.TFLiteConverter.from_session(sess, [images],
                                                     [output_tensor])

    converter.representative_dataset = tf.lite.RepresentativeDataset(
        _representative_dataset_gen)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    tflite_buffer = converter.convert()
    with tf.gfile.GFile(FLAGS.output_tflite, "wb") as output_tflite:
      output_tflite.write(tflite_buffer)
  print("tflite model written to %s" % FLAGS.output_tflite)


if __name__ == "__main__":
  flags.mark_flag_as_required("model_name")
  flags.mark_flag_as_required("checkpoint_path")
  flags.mark_flag_as_required("dataset_dir")
  flags.mark_flag_as_required("output_tflite")
  app.run(main)