tensorflow/models

View on GitHub
research/attention_ocr/python/model_export.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts existing checkpoint into a SavedModel.

Usage example:
python model_export.py \
  --logtostderr --checkpoint=model.ckpt-399731 \
  --export_dir=/tmp/attention_ocr_export
"""
import os

import tensorflow as tf
from tensorflow import app
from tensorflow.contrib import slim
from tensorflow.compat.v1 import flags

import common_flags
import model_export_lib

# Handle to the global flag values; every FLAGS.<name> read in this file goes
# through it.
FLAGS = flags.FLAGS
# Registers the flags declared in common_flags. This file reads checkpoint,
# train_log_dir and batch_size and marks dataset_name as required without
# defining them itself, so they are presumably declared by this call.
common_flags.define()

# Destination of the SavedModel; main() refuses to run if the path exists.
flags.DEFINE_string('export_dir', None, 'Directory to export model files to.')
# Optional input size overrides, forwarded by main() to export_model() as
# crop_image_width / crop_image_height.
flags.DEFINE_integer(
    'image_width', None,
    'Image width used during training (or crop width if used)'
    ' If not set, the dataset default is used instead.')
flags.DEFINE_integer(
    'image_height', None,
    'Image height used during training(or crop height if used)'
    ' If not set, the dataset default is used instead.')
# NOTE(review): work_dir and version_number are defined here but are not read
# anywhere in this file as shown; confirm whether other modules rely on them.
flags.DEFINE_string('work_dir', '/tmp',
                    'A directory to store temporary files.')
flags.DEFINE_integer('version_number', 1, 'Version number of the model')
# Selects which input placeholder export_model() builds: a string placeholder
# named 'tf_example' when True, a uint8 image placeholder otherwise.
flags.DEFINE_bool(
    'export_for_serving', True,
    'Whether the exported model accepts serialized tf.Example '
    'protos as input')


def get_checkpoint_path():
  """Resolves the checkpoint to export from the commandline flags.

  A non-empty --checkpoint value takes precedence and is returned unchanged.
  Otherwise the most recent checkpoint found in --train_log_dir is returned.

  Returns:
    A string with the checkpoint path.

  Raises:
    ValueError: if --checkpoint is not set and no checkpoint can be found in
      --train_log_dir.
  """
  explicit_checkpoint = FLAGS.checkpoint
  if explicit_checkpoint:
    return explicit_checkpoint

  # No explicit checkpoint: fall back to the newest one in the training dir.
  train_log_dir = FLAGS.train_log_dir
  latest_checkpoint = tf.train.latest_checkpoint(train_log_dir)
  if latest_checkpoint:
    return latest_checkpoint
  raise ValueError('Can\'t find a checkpoint in: %s' % train_log_dir)


def export_model(export_dir,
                 export_for_serving,
                 batch_size=None,
                 crop_image_width=None,
                 crop_image_height=None):
  """Exports a model to the named directory.

  The dataset and the checkpoint are not passed in as arguments: the dataset
  is created through common_flags and the checkpoint is resolved by
  get_checkpoint_path() from --checkpoint or --train_log_dir.

  Args:
    export_dir: The output dir where model is exported to.
    export_for_serving: If True, expects a serialized image as input and attach
      image normalization as part of exported graph.
    batch_size: For non-serving export, the input batch_size needs to be
      specified.
    crop_image_width: Width of the input image. Uses the dataset default if
      None.
    crop_image_height: Height of the input image. Uses the dataset default if
      None.

  Returns:
    Returns the model signature_def.

  Raises:
    ValueError: if the dataset's charset file does not exist, or if no
      checkpoint can be found (raised by get_checkpoint_path()).
  """
  # Dataset object used only to get all parameters for the model.
  dataset = common_flags.create_dataset(split_name='test')
  model = common_flags.create_model(
      dataset.num_char_classes,
      dataset.max_sequence_length,
      dataset.num_of_views,
      dataset.null_code,
      charset=dataset.charset)
  dataset_image_height, dataset_image_width, image_depth = dataset.image_shape

  # Fail early with the offending path if the charmap file is missing.
  if not os.path.exists(dataset.charset_file):
    raise ValueError('No charset defined at {}: export will fail'.format(
        dataset.charset_file))

  # Default to dataset dimensions, otherwise use provided dimensions.
  image_width = crop_image_width or dataset_image_width
  image_height = crop_image_height or dataset_image_height

  if export_for_serving:
    images_orig = tf.compat.v1.placeholder(
        tf.string, shape=[batch_size], name='tf_example')
    images_orig_float = model_export_lib.generate_tfexample_image(
        images_orig,
        image_height,
        image_width,
        image_depth,
        name='float_images')
  else:
    images_shape = (batch_size, image_height, image_width, image_depth)
    images_orig = tf.compat.v1.placeholder(
        tf.uint8, shape=images_shape, name='original_image')
    images_orig_float = tf.image.convert_image_dtype(
        images_orig, dtype=tf.float32, name='float_images')

  endpoints = model.create_base(images_orig_float, labels_one_hot=None)

  # The session is only needed until the SavedModel is written; the context
  # manager closes it on success and on any exception raised below.
  with tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.Saver(
        slim.get_variables_to_restore(), sharded=True)
    saver.restore(sess, get_checkpoint_path())
    tf.compat.v1.logging.info('Model restored successfully.')

    # Create model signature.
    if export_for_serving:
      input_tensors = {
          tf.saved_model.CLASSIFY_INPUTS: images_orig
      }
    else:
      input_tensors = {'images': images_orig}
    signature_inputs = model_export_lib.build_tensor_info(input_tensors)
    # NOTE: Tensors 'images_float' and 'chars_logit' are used by the inference
    # or to compute saliency maps.
    output_tensors = {
        'images_float': images_orig_float,
        'predictions': endpoints.predicted_chars,
        'scores': endpoints.predicted_scores,
        'chars_logit': endpoints.chars_logit,
        'predicted_length': endpoints.predicted_length,
        'predicted_text': endpoints.predicted_text,
        'predicted_conf': endpoints.predicted_conf,
        'normalized_seq_conf': endpoints.normalized_seq_conf
    }
    # One extra output per attention mask returned by the helper.
    for i, t in enumerate(
        model_export_lib.attention_ocr_attention_masks(
            dataset.max_sequence_length)):
      output_tensors['attention_mask_%d' % i] = t
    signature_outputs = model_export_lib.build_tensor_info(output_tensors)
    signature_def = (
        tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            signature_inputs, signature_outputs,
            tf.saved_model.CLASSIFY_METHOD_NAME))
    # Save model.
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.SERVING],
        signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_def
        },
        main_op=tf.compat.v1.tables_initializer(),
        strip_default_attrs=True)
    builder.save()
  tf.compat.v1.logging.info('Model has been exported to %s', export_dir)

  return signature_def


def main(unused_argv):
  """Exports the model described by the commandline flags.

  Args:
    unused_argv: Positional commandline arguments; ignored.

  Raises:
    ValueError: if the --export_dir path already exists.
  """
  target_dir = FLAGS.export_dir
  if os.path.exists(target_dir):
    raise ValueError('export_dir already exists: exporting will fail')

  export_model(
      export_dir=target_dir,
      export_for_serving=FLAGS.export_for_serving,
      batch_size=FLAGS.batch_size,
      crop_image_width=FLAGS.image_width,
      crop_image_height=FLAGS.image_height)


if __name__ == '__main__':
  # Script entry point. Both flags are marked as required only here, so
  # importing this module does not impose the requirement.
  flags.mark_flag_as_required('dataset_name')
  flags.mark_flag_as_required('export_dir')
  # Hands control to main() through the TensorFlow app runner.
  app.run(main)