research/slim/nets/post_training_quantization.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export quantized tflite model from a trained checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import app
from absl import flags
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from nets import nets_factory
from preprocessing import preprocessing_factory
flags.DEFINE_string("model_name", None,
"The name of the architecture to quantize.")
flags.DEFINE_string("checkpoint_path", None, "Path to the training checkpoint.")
flags.DEFINE_string("dataset_name", "imagenet2012",
"Name of the dataset to use for quantization calibration.")
flags.DEFINE_string("dataset_dir", None, "Dataset location.")
flags.DEFINE_string(
"dataset_split", "train",
"The dataset split (train, validation etc.) to use for calibration.")
flags.DEFINE_string("output_tflite", None, "Path to output tflite file.")
flags.DEFINE_boolean(
"use_model_specific_preprocessing", False,
"When true, uses the preprocessing corresponding to the model as specified "
"in preprocessing factory.")
flags.DEFINE_boolean("enable_ema", True,
"Load exponential moving average version of variables.")
flags.DEFINE_integer(
"num_steps", 1000,
"Number of post-training quantization calibration steps to run.")
flags.DEFINE_integer("image_size", 224, "Size of the input image.")
flags.DEFINE_integer("num_classes", 1001,
"Number of output classes for the model.")
FLAGS = flags.FLAGS
# Mean and standard deviation used for normalizing the image tensor.
_MEAN_RGB = 127.5
_STD_RGB = 127.5
def _preprocess_for_quantization(image_data, image_size, crop_padding=32):
"""Crops to center of image with padding then scales, normalizes image_size.
Args:
image_data: A 3D Tensor representing the RGB image data. Image can be of
arbitrary height and width.
image_size: image height/width dimension.
crop_padding: the padding size to use when centering the crop.
Returns:
A decoded and cropped image Tensor. Image is normalized to [-1,1].
"""
shape = tf.shape(image_data)
image_height = shape[0]
image_width = shape[1]
padded_center_crop_size = tf.cast(
(image_size * 1.0 / (image_size + crop_padding)) *
tf.cast(tf.minimum(image_height, image_width), tf.float32), tf.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
image = tf.image.crop_to_bounding_box(
image_data,
offset_height=offset_height,
offset_width=offset_width,
target_height=padded_center_crop_size,
target_width=padded_center_crop_size)
image = tf.image.resize([image], [image_size, image_size],
method=tf.image.ResizeMethod.BICUBIC)[0]
image = tf.cast(image, tf.float32)
image -= tf.constant(_MEAN_RGB)
image /= tf.constant(_STD_RGB)
return image
def restore_model(sess, checkpoint_path, enable_ema=True):
"""Restore variables from the checkpoint into the provided session.
Args:
sess: A tensorflow session where the checkpoint will be loaded.
checkpoint_path: Path to the trained checkpoint.
enable_ema: (optional) Whether to load the exponential moving average (ema)
version of the tensorflow variables. Defaults to True.
"""
if enable_ema:
ema = tf.train.ExponentialMovingAverage(decay=0.0)
ema_vars = tf.trainable_variables() + tf.get_collection("moving_vars")
for v in tf.global_variables():
if "moving_mean" in v.name or "moving_variance" in v.name:
ema_vars.append(v)
ema_vars = list(set(ema_vars))
var_dict = ema.variables_to_restore(ema_vars)
else:
var_dict = None
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_dict, max_to_keep=1)
saver.restore(sess, checkpoint_path)
def _representative_dataset_gen():
"""Gets a python generator of numpy arrays for the given dataset."""
image_size = FLAGS.image_size
dataset = tfds.builder(FLAGS.dataset_name, data_dir=FLAGS.dataset_dir)
dataset.download_and_prepare()
data = dataset.as_dataset()[FLAGS.dataset_split]
iterator = tf.data.make_one_shot_iterator(data)
if FLAGS.use_model_specific_preprocessing:
preprocess_fn = functools.partial(
preprocessing_factory.get_preprocessing(name=FLAGS.model_name),
output_height=image_size,
output_width=image_size)
else:
preprocess_fn = functools.partial(
_preprocess_for_quantization, image_size=image_size)
features = iterator.get_next()
image = features["image"]
image = preprocess_fn(image)
image = tf.reshape(image, [1, image_size, image_size, 3])
for _ in range(FLAGS.num_steps):
yield [image.eval()]
def main(_):
with tf.Graph().as_default(), tf.Session() as sess:
network_fn = nets_factory.get_network_fn(
FLAGS.model_name, num_classes=FLAGS.num_classes, is_training=False)
image_size = FLAGS.image_size
images = tf.placeholder(
tf.float32, shape=(1, image_size, image_size, 3), name="images")
logits, _ = network_fn(images)
output_tensor = tf.nn.softmax(logits)
restore_model(sess, FLAGS.checkpoint_path, enable_ema=FLAGS.enable_ema)
converter = tf.lite.TFLiteConverter.from_session(sess, [images],
[output_tensor])
converter.representative_dataset = tf.lite.RepresentativeDataset(
_representative_dataset_gen)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_buffer = converter.convert()
with tf.gfile.GFile(FLAGS.output_tflite, "wb") as output_tflite:
output_tflite.write(tflite_buffer)
print("tflite model written to %s" % FLAGS.output_tflite)
if __name__ == "__main__":
flags.mark_flag_as_required("model_name")
flags.mark_flag_as_required("checkpoint_path")
flags.mark_flag_as_required("dataset_dir")
flags.mark_flag_as_required("output_tflite")
app.run(main)