tensorflow/models

View on GitHub
official/nlp/serving/export_savedmodel_util.py

Summary

Maintainability
A
50 mins
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Union, Any

import tensorflow as tf, tf_keras

from official.core import export_base

# Module-level alias for export_base.get_timestamped_export_dir, so the same
# function is reachable under this module's namespace (presumably kept for
# callers that import it from here -- confirm before removing).
get_timestamped_export_dir = export_base.get_timestamped_export_dir


def export(export_module: export_base.ExportModule,
           function_keys: Union[List[str], Dict[str, str]],
           export_savedmodel_dir: str,
           checkpoint_path: Optional[str] = None,
           timestamped: bool = True,
           module_key: Optional[str] = None,
           checkpoint_kwargs: Optional[Dict[str, Any]] = None) -> str:
  """Exports the given export module as a SavedModel.

  Thin wrapper around `export_base.export` that additionally registers the
  module's `serve` function under the 'tpu_candidate' function alias and
  optionally builds the `tf.train.Checkpoint` object passed along for
  restoring.

  Args:
    export_module: an ExportModule holding the keras Model and the serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a
      dictionary is provided, the values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory.
    checkpoint_path: Object-based checkpoint path or directory.
    timestamped: Whether to export the savedmodel to a timestamped directory.
    module_key: Optional string naming the checkpoint attribute under which
      `export_module.model` is tracked. Takes precedence over
      `checkpoint_kwargs`.
    checkpoint_kwargs: Optional dict used as keyword args to create the
      checkpoint object. Ignored when `module_key` is set. When both are
      empty, `checkpoint=None` is forwarded to `export_base.export`.

  Returns:
    The savedmodel directory path, as returned by `export_base.export`.
  """
  # Build the save options first so `export_module.serve` is resolved before
  # any checkpoint object is constructed.
  function_aliases = {'tpu_candidate': export_module.serve}
  save_options = tf.saved_model.SaveOptions(function_aliases=function_aliases)

  # Only one of the two checkpoint sources is honoured; `module_key` wins.
  checkpoint = None
  if module_key:
    checkpoint = tf.train.Checkpoint(**{module_key: export_module.model})
  elif checkpoint_kwargs:
    checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)

  return export_base.export(
      export_module, function_keys, export_savedmodel_dir, checkpoint_path,
      timestamped, save_options, checkpoint=checkpoint)