tensorflow/tensorflow

View on GitHub
tensorflow/python/framework/meta_graph.py

Summary

Maintainability
F
4 days
Test Coverage
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""MetaGraph and related functions."""
import copy
from packaging import version as packaging_version  # pylint: disable=g-bad-import-order
import os.path
import re
import sys

from google.protobuf.any_pb2 import Any
from google.protobuf import text_format

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import byte_swap_tensor as bst
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# Prefix added to the names of unbound inputs (see `_node_def`) so that they
# are easily identifiable in an exported graph.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that didn't register proto functions; as a result, in a
# previously exported meta_graph their items are of a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
                           ops.GraphKeys.MODEL_VARIABLES,
                           ops.GraphKeys.METRIC_VARIABLES]


def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  The node name, its inputs, its `_class` entries and (for `Enter`/`RefEnter`
  nodes) its `frame_name` have `export_scope` stripped. Inputs that do not
  start with `export_scope` are renamed with the `$unbound_inputs_` prefix
  instead.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is not
      modified; a deep copy is edited and returned.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list that is appended to, in place, with the prefixed
      name of every unbound input encountered.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable. The regex keeps a leading "^" (control dependency marker)
      # in front of the prefix.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in from_node_def.attr.items():
    if k == "_class":
      # Keep only the class entries that live inside the export scope.
      kept_classes = [
          compat.as_bytes(ops.strip_name_scope(s, export_scope))
          for s in v.list.s
          if not export_scope or
          compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=kept_classes)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        frame_name = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
      else:
        # A frame outside the export scope keeps its name unchanged, so the
        # value written below is always defined.
        frame_name = v.s
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=frame_name))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def


def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto; if that fails it is parsed
  as a text-format proto.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  if not file_io.file_exists(filename):
    raise IOError(f"File {filename} does not exist.")
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  graph_def = graph_pb2.GraphDef()
  # First try to read it as a binary file.
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: any failure means "not a binary proto", so fall
    # through to the text-format attempt.
    pass

  # Next try to read it as a text file. text_format.Merge does not clear its
  # target, so drop anything the failed binary parse may have left behind.
  graph_def.Clear()
  try:
    text_format.Merge(file_content, graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Undecodable bytes are reported as IOError too, as documented above.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return graph_def


def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Functions in the graph's library are followed transitively, but only
  non-function op names are returned. Does not validate that the ops are all
  registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph.
  """
  functions_by_name = {
      fn.signature.name: fn for fn in graph_def.library.function}

  seen = set()  # Every op name encountered, primitive or function.
  pending = []  # Function definitions whose bodies still need scanning.

  def visit(op_name):
    # Queue a function body the first time its name shows up.
    if op_name in functions_by_name and op_name not in seen:
      pending.append(functions_by_name[op_name])
    seen.add(op_name)

  def visit_all(node_defs):
    for node_def in node_defs:
      visit(node_def.op)
      # Call ops name the function they invoke in their "f" attribute.
      if node_def.op in ["PartitionedCall", "StatefulPartitionedCall"]:
        visit(node_def.attr["f"].func.name)

  visit_all(graph_def.node)
  # Functions can reference other functions, so keep draining the work list.
  while pending:
    visit_all(pending.pop().node_def)

  return [name for name in seen if name not in functions_by_name]


def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph, ordered by op name. Ops that are
    not found in the op registry are left out.
  """
  # This is similar to StrippedOpListForGraph in C++, but unlike its
  # C++ counterpart, this version does not require all ops to be registered.
  # This is done to support Prelu fusion in tfjs.
  looked_up = (
      op_def_registry.get(op_name)
      for op_name in sorted(ops_used_by_graph_def(graph_def)))
  return op_def_pb2.OpList(
      op=[op_def for op_def in looked_up if op_def is not None])


def _get_kind_name(item):
  """Returns the kind name in CollectionDef.

  Args:
    item: A data item.

  Returns:
    The string representation of the kind in CollectionDef.
  """
  if isinstance(item, (str, bytes)):
    kind = "bytes_list"
  elif isinstance(item, int):
    kind = "int64_list"
  elif isinstance(item, float):
    kind = "float_list"
  elif isinstance(item, Any):
    kind = "any_list"
  else:
    kind = "node_list"
  return kind


# Op types treated as Save/Restore ops; `_find_extraneous_saver_nodes` uses
# this list to locate the name scopes that belong to Savers.
SAVE_AND_RESTORE_OPS = ["SaveV2",
                        "Save", "SaveSlice",
                        "LegacySave", "LegacySaveSlice",
                        "RestoreV2",
                        "Restore", "RestoreSlice",
                        "LegacyRestore", "LegacyRestoreSlice"]


def _get_scope(node_name):
  """Extract the scope name from a node name.

  The scope name is everything before the final slash,
  not including any ^ prefix denoting a control dependency.

  Args:
    node_name: the full name of an Op or a Tensor in the graph.
  Returns:
    The deepest named scope containing the node.
  Raises:
    ValueError: if tensor_name is None or empty
  """
  if not node_name:
    raise ValueError(
        f"Node name cannot be empty or None. Received: {node_name}.")

  # Control dependency inputs start with ^.
  if node_name.startswith("^"):
    node_name = node_name[1:]
  if "/" in node_name:
    scope, _ = node_name.rsplit("/", 1)
    return scope

  return ""


def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None if there is no saver.
  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Load the graph DAG in minimal form, without initializing a full Graph
  # object: node name -> (names of the ops feeding it, op type).
  graph_nodes = {}
  for node_def in graph_def.node:
    input_op_names = {tensor.get_op_name(name) for name in node_def.input}
    graph_nodes[node_def.name] = (input_op_names, node_def.op)

  # Scope prefixes belonging to the Saver we must keep. It's possible to have
  # no saver if the graph has no Variables, in which case nothing is retained.
  retained_prefixes = set()
  if saver_def is not None:
    save_op_name = tensor.get_op_name(saver_def.save_tensor_name)
    restore_op_name = tensor.get_op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retained_prefixes.add(_get_scope(restore_op_name) + "/")
    retained_prefixes.add(_get_scope(save_op_name) + "/")

  saver_node_names = {
      name for name, (_, op_type) in graph_nodes.items()
      if op_type in SAVE_AND_RESTORE_OPS}

  # A scope that is itself the name of a Save/Restore node is not a Saver scope.
  saver_scopes = {_get_scope(name) for name in saver_node_names}
  saver_scopes -= saver_node_names
  saver_prefixes = {scope + "/" for scope in saver_scopes}

  # str.startswith accepts a tuple and returns False for an empty one.
  extraneous_prefixes = tuple(saver_prefixes - retained_prefixes)
  return {
      name for name in graph_nodes if name.startswith(extraneous_prefixes)}


def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if not isinstance(node_or_node_name, str):
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True
  else:
    node_name = node_or_node_name

  if exclude_nodes and (node_or_node_name in exclude_nodes
                        or node_name in exclude_nodes):
    return False

  return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or
          (not export_scope or node_name.startswith(export_scope)))


def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Nothing is added when `key` is not a string, or when no item of the
  collection is left after filtering. Serialization problems are logged as a
  warning and the partially written collection is removed again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, (str, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  candidates = (
      override_contents if override_contents else graph.get_collection(key))

  # Remove nodes that should not be exported from the collection list.
  exported_items = [
      item for item in candidates
      if _should_include_node(item, export_scope, exclude_nodes)]
  if not exported_items:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      # Collections with a registered to_proto function are stored as
      # serialized protos in the bytes_list.
      for item in exported_items:
        proto = to_proto(item, export_scope=export_scope)
        if proto:
          # Additional type check to make sure the returned proto is indeed
          # what we expect.
          assert isinstance(proto, proto_type)
          col_def.bytes_list.value.append(proto.SerializeToString())
    else:
      # The first item decides which CollectionDef field is used for all.
      kind = _get_kind_name(exported_items[0])
      if kind == "node_list":
        for item in exported_items:
          if not export_scope or item.name.startswith(export_scope):
            col_def.node_list.value.append(
                ops.strip_name_scope(item.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        col_def.bytes_list.value.extend(
            [compat.as_bytes(item) for item in exported_items])
      else:
        getattr(col_def, kind).value.extend(list(exported_items))
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Do not leave a half-filled collection behind.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]


def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def."""
  for attr_def in op_def.attr:
    if attr_def.name == attr_name:
      if not attr_def.HasField("default_value"):
        return False
      # c_api.EqualAttrValueWrapper returns an empty string
      # if both arguments represent an equivalent AttrValue instance.
      return not c_api.EqualAttrValueWrapper(
          attr_value.SerializeToString(),
          attr_def.default_value.SerializeToString())
  return False


def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Both the top-level nodes and the nodes inside library functions are
  processed. Nodes whose op is a library function, or whose op is not found in
  the op registry, are left untouched.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer, modified in place.

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  # Names of the ops that are really calls to library functions.
  function_op_names = {
      function_def.signature.name
      for function_def in graph_def.library.function}

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in function_op_names:
      return

    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return

    # Collect first, delete afterwards: the attr map must not change size
    # while it is being iterated.
    default_valued = [
        attr_name for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)]
    for attr_name in default_valued:
      del node_def.attr[attr_name]

  # Process all NodeDef instances in graph_def.
  for node_def in graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True


def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. When given, its
      `tensorflow_version` and `tensorflow_git_version` fields are overwritten
      in place with the versions of the current build.
    graph_def: `GraphDef` protocol buffer. Defaults to the `GraphDef` of
      `graph`, with shapes added.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. Defaults to all
      collection keys of `graph`.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")
  if meta_info_def and not isinstance(
      meta_info_def, meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError(
        "meta_info_def must be of type MetaInfoDef. "
        f"Received type: {type(meta_info_def)}.")
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError(
        "graph_def must be of type GraphDef. "
        f"Received type: {type(graph_def)}.")
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError(
        "saver_def must be of type SaverDef. "
        f"Received type: {type(saver_def)}.")

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Start from the caller's meta_info_def (or an empty one) and stamp it with
  # the tf version strings of the current tf build.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__

  # Creates a MetaGraphDef proto and adds meta_info_def.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  source_graph_def = (
      graph_def if graph_def else graph.as_graph_def(add_shapes=True))
  meta_graph_def.graph_def.MergeFrom(source_graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def,
  # unless the caller already supplied one.
  if not meta_graph_def.meta_info_def.stripped_op_list.op:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  collection_keys = (
      collection_list if collection_list is not None
      else graph.get_all_collection_keys())

  for collection_key in collection_keys:
    extra_kwargs = {}
    if clear_extraneous_savers and collection_key == ops.GraphKeys.SAVERS:
      # The SAVERS collection is replaced by the single Saver built from
      # saver_def. Avoid importing Saver here.
      from_proto = ops.get_from_proto_function(collection_key)
      extra_kwargs["override_contents"] = [from_proto(saver_def)]
    add_collection_def(meta_graph_def, collection_key,
                       graph=graph,
                       export_scope=export_scope,
                       exclude_nodes=exclude_nodes,
                       **extra_kwargs)
  return meta_graph_def


def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto; if that fails it is decoded
  as UTF-8 and parsed as a text-format proto. On big-endian hosts the tensor
  content of the parsed proto is byte-swapped from little to big endian.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # First try to read it as a binary file.
  try:
    meta_graph_def.ParseFromString(file_content)
    if sys.byteorder == "big":
      bst.swap_tensor_content_in_graph_function(meta_graph_def, "little", "big")
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: any failure means "not a binary proto", so fall
    # through to the text-format attempt.
    pass

  # Next try to read it as a text file. text_format.Merge does not clear its
  # target, so drop anything the failed binary attempt may have left behind.
  meta_graph_def.Clear()
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    if sys.byteorder == "big":
      bst.swap_tensor_content_in_graph_function(meta_graph_def, "little", "big")
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Content that is not valid UTF-8 is an unparsable file as well; report it
    # as the documented IOError instead of leaking UnicodeDecodeError.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return meta_graph_def


def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  This is `import_scoped_meta_graph_with_return_elements()` without
  `return_elements`; only the dictionary of variables is returned.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  variables_and_elements = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # The first element is the variable dictionary; the return elements are
  # not needed here.
  return variables_and_elements[0]


def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`. When a proto is passed and
      `clear_devices` is true, the device fields of its nodes are cleared in
      place.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements:  A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.

  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  # Every name recorded in the unbound-inputs collection must be supplied
  # through input_map; otherwise the import is refused.
  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          # Collection values may be bytes while input_map keys are strings,
          # so normalize before the membership test; that way only the inputs
          # that are really missing are listed.
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." % ",".join(
                               compat.as_str(v)
                               for v in field.value
                               if not input_map or
                               compat.as_str(v) not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Ask the graph which scope name the import will use, without reserving
    # that name.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    # variable_objects maps a serialized variable proto to the object created
    # for it, so a variable listed in several collections is built only once.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          # Node names are resolved to graph elements inside the import scope.
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    # Collect the global variables that live under the import scope, keyed by
    # their name with that scope stripped.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements


def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Builds a `MetaGraphDef` proto and optionally writes it to `filename`.

  The graph, saver and collection objects are packaged into a `MetaGraphDef`
  protocol buffer so that they can be imported at a later time or location to
  restart training, run inference, or be used as a subgraph.

  When any of `export_scope`, `clear_extraneous_savers` or `clear_devices` is
  set, the nodes are filtered and rewritten into a fresh `GraphDef` (taken from
  `graph_def` if given, otherwise from `graph`). In that case, and if
  `unbound_inputs_col_name` is set, the collection of that name on `graph` is
  cleared and refilled with the unbound inputs found while rewriting.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` (and the debug info file, if
      any) as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is set, also save the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` but with its extension replaced by `.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary mapping variable
    names (with `export_scope` stripped) to the `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  # Under eager execution both `graph_def` and `graph` must be supplied.
  if context.executing_eagerly() and (graph_def is None or graph is None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  needs_rewrite = export_scope or clear_extraneous_savers or clear_devices
  if needs_rewrite:
    if graph_def:
      # Filter the caller-supplied GraphDef into a fresh proto.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for src_node in graph_def.node:
        if not _should_include_node(
            src_node.name, export_scope, exclude_nodes):
          continue
        new_graph_def.node.extend([
            _node_def(src_node, export_scope, unbound_inputs,
                      clear_devices=clear_devices)])
      graph_def = new_graph_def
    else:
      # No GraphDef was given: rebuild one node by node from `graph`, keeping
      # a running byte count so the 2GB proto limit can be enforced.
      graph_def = graph_pb2.GraphDef()
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # pylint: disable=protected-access
      for node_id in sorted(graph._nodes_by_id):
        op = graph._nodes_by_id[node_id]
        # pylint: enable=protected-access
        if not _should_include_node(op.name, export_scope, exclude_nodes):
          continue
        stripped_node = _node_def(op.node_def, export_scope, unbound_inputs,
                                  clear_devices=clear_devices)
        graph_def.node.extend([stripped_node])
        if op.outputs:
          # Record the static output shapes on the node just appended.
          exported_node = graph_def.node[-1]
          assert "_output_shapes" not in exported_node.attr
          exported_node.attr["_output_shapes"].list.shape.extend(
              [out.get_shape().as_proto() for out in op.outputs])
        bytesize += op.node_def.ByteSize()
        if bytesize >= (1 << 31) or bytesize < 0:
          raise ValueError(
              "GraphDef cannot be larger than 2GB. "
              f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # Some inputs may live outside `export_scope`. When requested, publish
    # them through a dedicated collection so the importer can remap them.
    if unbound_inputs_col_name:
      # Start from an empty collection before refilling it.
      graph.clear_collection(unbound_inputs_col_name)
      for unbound_name in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, unbound_name)

  # Map scope-stripped names to the global variables that are being exported.
  var_list = {
      ops.strip_name_scope(var.name, export_scope): var
      for var in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                      scope=export_scope)
      if _should_include_node(var, export_scope, exclude_nodes)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file shares the base name but uses a ".debug" extension.
      base_name, _ = os.path.splitext(filename)
      debug_filename = f"{base_name}.debug"

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = [
          ("", graph.get_operation_by_name(
              ops.prepend_name_scope(node.name, export_scope)))
          for node in scoped_meta_graph_def.graph_def.node
      ]

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list


def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` is exported from `from_graph` as a scoped
  `MetaGraphDef` and then imported into `to_graph` under `to_scope`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  source_graph = from_graph or ops.get_default_graph()
  target_graph = to_graph or ops.get_default_graph()

  # Copying a scope onto itself within a single graph is rejected.
  same_graph = source_graph == target_graph
  if same_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph. "
                     f"Received: 'from_graph': {source_graph}, "
                     f"'to_graph': {target_graph}, "
                     f"'from_scope': {from_scope}, 'to_scope': {to_scope}.")

  # Only the exported proto is needed here; the variables returned to the
  # caller are the ones produced by the import into `to_scope`.
  exported_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=source_graph)
  return import_scoped_meta_graph(
      exported_meta_graph, graph=target_graph, import_scope=to_scope)