tensorflow/tensorflow

View on GitHub
tensorflow/python/training/saver_test.py

Summary

Maintainability
F
1 mo
Test Coverage
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.training.saver.py."""

import glob
import math
import os
import random
import time

import numpy as np

from google.protobuf.any_pb2 import Any

from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.summary import summary
from tensorflow.python.trackable import base as trackable_base
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat


class SaverTest(test.TestCase):

  def basicSaveRestore(self, variable_op):
    """Round-trips two variables and a custom saveable through a Saver.

    Three fresh graphs are used in turn: the first initializes and saves the
    values, the second restores them into objects that were never
    initialized, and the third restores them over objects that were
    initialized to different values.

    Args:
      variable_op: Variable constructor, called as
        `variable_op(initial_value, name=...)`; e.g. `variables.Variable` or
        `resource_variable_ops.ResourceVariable`.
    """
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)

      # Initialize all variables. Initializers only have to be run explicitly
      # outside of eager execution.
      if not context.executing_eagerly():
        self.evaluate([variables.global_variables_initializer(), v2_init])

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          }, restore_sequentially=True)
      val = save.save(sess, save_path)
      self.assertIsInstance(val, str)
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the parameter nodes
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")

      # Assert that the variables are not initialized. Only v0 and v1 are
      # counted as uninitialized variables; the v2 table is checked
      # separately through its (empty) keys and values.
      if not context.executing_eagerly():
        self.assertEqual(
            len(variables.report_uninitialized_variables().eval()), 2)
        self.assertEqual(0, len(self.evaluate(v2.keys())))
        self.assertEqual(0, len(self.evaluate(v2.values())))
      # Restore the saved values in the parameter nodes.
      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variable_op(1000.0, name="v0")
      v1_2 = variable_op(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2_2.insert("k1000", 3000.0)

      # Check that the parameter nodes have been initialized.
      if not context.executing_eagerly():
        init_all_op = [variables.global_variables_initializer(), v2_init]
        self.evaluate(init_all_op)
        # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty
        # table as it claims in eager mode?
        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
        self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))

      # Restore the values saved earlier in the parameter nodes; they must
      # overwrite the different initial values set above.
      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))

  def testBasic(self):
    """Runs the basic save/restore round trip with `variables.Variable`."""
    variable_factory = variables.Variable
    self.basicSaveRestore(variable_factory)

  @test_util.run_in_graph_and_eager_modes
  def testResourceBasic(self):
    """Runs the basic save/restore round trip with resource variables."""
    self.basicSaveRestore(variable_op=resource_variable_ops.ResourceVariable)

  def testResourceColocation(self):
    """Checks SaveV2 tensor inputs are placed on the variable's host CPU."""
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default():
      partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
      with ops_lib.device("/job:ps/device:GPU:0"):
        v = variable_scope.get_variable(
            "v0", shape=[10, 2], partitioner=partitioner, use_resource=True)
      saver_module.Saver({"v0": v}).build()
      save_op = next(
          (op for op in ops_lib.get_default_graph().get_operations()
           if op.type == "SaveV2"),
          None)
      # A TestCase assertion (unlike a bare `assert`) is not stripped under
      # `python -O` and reports a readable failure message.
      self.assertIsNotNone(save_op, "No SaveV2 op was added to the graph.")
      # SaveV2's first three inputs are prefix, tensor_names and
      # shape_and_slices; everything after them is a tensor being saved.
      tensor_inputs = save_op.inputs[3:]
      # Guard against the loop below passing vacuously.
      self.assertNotEmpty(tensor_inputs)
      for save_inp in tensor_inputs:
        # Input to SaveV2 op is placed on CPU of the same device as
        # the Variable.
        self.assertEqual("/job:ps/device:CPU:0", save_inp.device)

  def testResourceVariableReadOpsAddedDeterministically(self):
    """Building the same graph plus a Saver repeatedly yields equal GraphDefs."""
    num_graphs = 10

    def _build_graph_def():
      # Twenty resource variables plus a default Saver, serialized.
      with ops_lib.Graph().as_default() as graph:
        for idx in range(20):
          resource_variable_ops.ResourceVariable(idx, name="var%s" % idx)
        saver_module.Saver()
        return graph.as_graph_def()

    graph_defs = [_build_graph_def() for _ in range(num_graphs)]
    for earlier, later in zip(graph_defs, graph_defs[1:]):
      self.assertEqual(earlier, later)

  def testEagerBasic(self):
    """Saves and restores resource variables with a Saver in eager mode."""
    with context.eager_mode():
      prefix = os.path.join(self.get_temp_dir(), "ckpt")

      scalar_var = resource_variable_ops.ResourceVariable(3.14, name="v1")
      vector_var = resource_variable_ops.ResourceVariable([1, 2], name="v2")
      saver = saver_module.Saver([scalar_var, vector_var])
      # No session exists in eager mode, hence the `None` argument.
      saver.save(None, prefix)

      # Clobber both variables so that the restore below is observable.
      scalar_var.assign(0.0)
      vector_var.assign([0, 0])
      self.assertNear(0.0, self.evaluate(scalar_var), 1e-5)
      self.assertAllEqual([0, 0], self.evaluate(vector_var))

      saver.restore(None, prefix)
      self.assertNear(3.14, self.evaluate(scalar_var), 1e-5)
      self.assertAllEqual([1, 2], self.evaluate(vector_var))

  def testEagerGraphCompatibility(self):
    """Graph-mode checkpoints restore in eager mode and vice versa."""
    # Save from graph mode and restore from eager mode.
    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        # Create a graph model and save the checkpoint.
        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
        graph_saver = saver_module.Saver([w1, w2])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.save(sess, graph_ckpt_prefix)

    with context.eager_mode():
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")

      graph_saver = saver_module.Saver([w1, w2])
      graph_saver.restore(None, graph_ckpt_prefix)

      self.assertAllEqual(1.0, self.evaluate(w1))
      self.assertAllEqual(2.0, self.evaluate(w2))

    # Save from eager mode and restore from graph mode.
    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
    with context.eager_mode():
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")

      graph_saver = saver_module.Saver([w3, w4])
      graph_saver.save(None, eager_ckpt_prefix)

    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
        graph_saver = saver_module.Saver([w3, w4])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.restore(sess, eager_ckpt_prefix)
        # Read the variables explicitly, as in the eager half above, rather
        # than relying on assertAllEqual to evaluate a variable implicitly.
        self.assertAllEqual(3.0, self.evaluate(w3))
        self.assertAllEqual(4.0, self.evaluate(w4))

  @test_util.run_in_graph_and_eager_modes
  def testResourceSaveRestoreCachingDevice(self):
    """Round-trips a resource variable that has a caching device."""
    save_path = os.path.join(self.get_temp_dir(), "resource_cache")
    with self.session(graph=ops_lib.Graph()) as sess:
      v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
                                                 name="v")
      if context.executing_eagerly():
        sess = None
      else:
        self.evaluate(variables.global_variables_initializer())
      save = saver_module.Saver([v])
      save.save(sess, save_path)

      # Overwrite the saved value so the restore below is observable;
      # otherwise the final check would pass even if restore were a no-op.
      self.evaluate(v.assign([0]))
      self.assertAllEqual([0], self.evaluate(v))

      save2 = saver_module.Saver([v])
      save2.restore(sess, save_path)
      # assertAllEqual compares array contents element-wise; assertEqual on a
      # NumPy array only happens to work for single-element arrays.
      self.assertAllEqual([1], self.evaluate(v))

  def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
    """All ops a Saver creates live under its `save/` sub-scope."""
    with ops_lib.Graph().as_default() as graph:
      var = resource_variable_ops.ResourceVariable(1.0, name="v")
      with ops_lib.name_scope("saver1"):
        saver_module.Saver()
      with ops_lib.name_scope("saver2"):
        saver_module.Saver({"name": var})
    for scope in ("saver1", "saver2"):
      stray_ops = [
          op for op in graph.get_operations()
          if op.name.startswith(scope + "/")
          and not op.name.startswith(scope + "/save/")
      ]
      self.assertEqual(stray_ops, [])

  def testSaveCopyRestoreWithSaveRelativePaths(self):
    """Save, move the checkpoint dir, and restore from its new location.

    This only works for save_relative_paths=True. The test checks that after
    the directory is renamed, `latest_checkpoint` on the new directory
    resolves to a path inside that new directory, and that restoring from it
    recovers the saved values.
    """
    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
    os.mkdir(save_dir1)
    save_path1 = os.path.join(save_dir1, "save_copy_restore")

    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default():
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_v1.VariableV1(10.0, name="v0")
      v1 = variable_v1.VariableV1(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      save = saver_module.Saver(
          var_list={
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          },
          restore_sequentially=True,
          save_relative_paths=True)
      init_all_op = [variables.global_variables_initializer(), v2_init]

      with self.cached_session() as sess:
        # Initialize all variables
        self.evaluate(init_all_op)

        # Check that the parameter nodes have been initialized.
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))
        self.assertEqual(b"k1", self.evaluate(v2.keys()))
        self.assertEqual(30.0, self.evaluate(v2.values()))

        # Save the initialized values in the file at "save_path"
        val = save.save(sess, save_path1)
        self.assertIsInstance(val, str)
        self.assertEqual(save_path1, val)

      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir1), save_path1)
      # Move (rename, not copy) the whole checkpoint directory.
      save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
      os.renames(save_dir1, save_dir2)
      save_path2 = os.path.join(save_dir2, "save_copy_restore")
      self.assertEqual(
          checkpoint_management.latest_checkpoint(save_dir2), save_path2)

      # Start a second session.  In that session the parameter nodes
      # have not been initialized either.
      # NOTE(review): this is still the same graph (and, being cached,
      # presumably the same session), so the new v0/v1/v2 objects get
      # uniquified names while the originals stay initialized. The Saver's
      # dict keys map the new objects back to the checkpoint names.
      with self.cached_session() as sess:
        v0 = variable_v1.VariableV1(-1.0, name="v0")
        v1 = variable_v1.VariableV1(-1.0, name="v1")
        v2 = saver_test_utils.CheckpointedOp(name="v2")
        save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})

        # Assert that the variables are not initialized.
        self.assertEqual(
            len(variables.report_uninitialized_variables().eval()), 2)
        self.assertEqual(0, len(self.evaluate(v2.keys())))
        self.assertEqual(0, len(self.evaluate(v2.values())))

        # Restore the saved values in the parameter nodes.
        save.restore(sess, save_path2)
        # Check that the parameter nodes have been restored.
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))
        self.assertEqual(b"k1", self.evaluate(v2.keys()))
        self.assertEqual(30.0, self.evaluate(v2.values()))

  def testFilenameTensor(self):
    """The saver's filename tensor evaluates to the `filename` passed in."""
    expected_filename = b"somerandomfilename"
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default():
      var = variable_v1.VariableV1(0, name="v0")
      saver = saver_module.Saver({"v0": var}, filename=expected_filename)
      with self.cached_session() as sess:
        tensor_name = saver.saver_def.filename_tensor_name
        filename_tensor = sess.graph.get_tensor_by_name(tensor_name)
        self.assertEqual(self.evaluate(filename_tensor), expected_filename)

  def testInvalidPath(self):
    """Restoring from a non-checkpoint path raises for both V1 and V2."""
    var = variable_v1.VariableV1(0, name="v0")
    write_versions = (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2)
    for write_version in write_versions:
      with self.cached_session() as sess:
        saver = saver_module.Saver({"v0": var}, write_version=write_version)
        with self.assertRaisesRegex(
            ValueError, "The passed save_path is not a valid checkpoint:"):
          saver.restore(sess, "invalid path")

  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testInt64(self):
    """Saves and restores an int64 variable.

    NOTE(review): the second `cached_session` below is nested inside the
    first and, being cached, presumably reuses the same session and graph.
    The second `v` would then live in the same graph as the first one, under
    a uniquified op name such as "v_1", so the "uninitialized value v" check
    relies on a substring match. Confirm before restructuring.
    """
    save_path = os.path.join(self.get_temp_dir(), "int64")

    with self.cached_session() as sess:
      # Build a graph with 1 node, and save and restore for them.
      v = variable_v1.VariableV1(np.int64(15), name="v")
      save = saver_module.Saver({"v": v}, restore_sequentially=True)
      self.evaluate(variables.global_variables_initializer())

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertIsInstance(val, str)
      self.assertEqual(save_path, val)

      with self.cached_session() as sess:
        # A new, never-initialized variable; the Saver's dict key maps it to
        # the "v" checkpoint entry.
        v = variable_v1.VariableV1(np.int64(-1), name="v")
        save = saver_module.Saver({"v": v})

      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
        self.evaluate(v)

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(np.int64(15), self.evaluate(v))

  def testSomeErrors(self):
    """Saver rejects var lists whose entries map to the same checkpoint name."""
    with ops_lib.Graph().as_default():
      var0 = variable_v1.VariableV1([10.0], name="v0")
      var1 = variable_v1.VariableV1([20.0], name="v1")
      var2 = variable_v1.VariableV1([20.0], name="v2")
      # Declare v2 to be a slice of "v1", so it is saved under that name.
      var2._set_save_slice_info(
          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))

      # By default the name used for "v2" will be "v1" and raise an error.
      with self.assertRaisesRegex(ValueError, "same name: v1"):
        saver_module.Saver([var0, var1, var2])

      # Explicit, distinct dict keys avoid the conflict.
      saver_module.Saver({"vee1": var1, "other": [var2]})

      def _partitioned(name):
        # A [4, 5] variable split into two shards.
        return variable_scope.get_variable(
            name,
            shape=[4, 5],
            partitioner=partitioned_variables.fixed_size_partitioner(
                num_shards=2))

      # Partitioned variables also cause name conflicts.
      part1 = _partitioned("p_v1")
      part2 = _partitioned("p_v2")
      part2._name = "p_v1"
      with self.assertRaisesRegex(ValueError, "same name: p_v1"):
        saver_module.Saver([part1, part2])

  def testSameName(self):
    """One saveable may not be registered under two different keys."""
    with ops_lib.Graph().as_default():
      var0 = variable_v1.VariableV1([10.0], name="v0")
      table = saver_test_utils.CheckpointedOp(name="v2")

      # A plain variable listed under two keys is rejected...
      with self.assertRaisesRegex(
          ValueError, "The same saveable will be restored with two names: v0"):
        saver_module.Saver({"v0": var0, "v0too": var0})

      # ...and so is a custom saveable.
      with self.assertRaisesRegex(
          ValueError, "The same saveable will be restored with two names: v2"):
        saver_module.Saver({"v2": table.saveable, "v2too": table.saveable})

      # Distinct saveables under distinct keys are accepted.
      saver_module.Saver({"v0": var0, "v2": table.saveable})

  @test_util.run_v1_only("train.Saver and VariableV1 are V1 only APIs.")
  def testBasicsWithListOfVariables(self):
    """Like basicSaveRestore, but the Saver is given a list, not a dict.

    With a list, checkpoint keys come from the objects' own names ("v0",
    "v1", "v2"), which is why each graph reuses the same names.
    """
    save_path = os.path.join(self.get_temp_dir(), "basics_with_list")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_v1.VariableV1(10.0, name="v0")
      v1 = variable_v1.VariableV1(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      save = saver_module.Saver([v0, v1, v2.saveable])
      self.evaluate(variables.global_variables_initializer())
      v2_init.run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertIsInstance(val, str)
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the variables
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_v1.VariableV1(-1.0, name="v0")
      v1 = variable_v1.VariableV1(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      save = saver_module.Saver([v0, v1, v2.saveable])

      # Reading an uninitialized variable must fail before the restore.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
        self.evaluate(v0)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        self.evaluate(v1)
      self.assertEqual(0, len(self.evaluate(v2.keys())))
      self.assertEqual(0, len(self.evaluate(v2.values())))

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variable_v1.VariableV1(1000.0, name="v0")
      v1_2 = variable_v1.VariableV1(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
      v2_2.insert("k1000", 3000.0).run()
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))
      self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
      self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      # Restore the values saved earlier in the parameter nodes.
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))

  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
    """Saves `var_name` = `var_value`, then restores it over `other_value`.

    Each half runs in its own fresh graph. The test fails unless the restored
    value matches `var_value`.
    """

    def _make_var_and_saver(initial_value):
      variable = resource_variable_ops.ResourceVariable(
          initial_value, name=var_name)
      return variable, saver_module.Saver({var_name: variable})

    with self.session(graph=ops_lib.Graph()) as sess:
      var, saver = _make_var_and_saver(var_value)
      if not context.executing_eagerly():
        self.evaluate(var.initializer)
      self.assertEqual(save_path, saver.save(sess, save_path))

    with self.session(graph=ops_lib.Graph()) as sess:
      var, saver = _make_var_and_saver(other_value)
      saver.restore(sess, save_path)
      self.assertAllClose(var_value, self.evaluate(var))

  def testCacheRereadsFile(self):
    """Rewriting a checkpoint file is seen by later restores from the path."""
    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
    # Both round trips share one file but use different variable names, so
    # any cached reader must notice the file changed and re-read it.
    round_trips = (("var0", 0.0, 1.0), ("var1", 1.1, 2.2))
    for name, saved_value, initial_value in round_trips:
      self._SaveAndLoad(name, saved_value, initial_value, save_path)

  def testAllowEmpty(self):
    """An allow_empty Saver with no variables saves nothing; restore is OK."""
    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default(), self.cached_session() as sess:
      constant_op.constant(1)
      empty_saver = saver_module.Saver(allow_empty=True)
      # With nothing to save, save() returns None instead of a path.
      self.assertIsNone(empty_saver.save(sess, save_path))
    with ops_lib.Graph().as_default(), self.cached_session() as sess:
      # Restoring with an empty saver must not raise.
      saver_module.Saver(allow_empty=True).restore(sess, save_path)

  def testGPU(self):
    """Saves a GPU-placed variable and restores it into a new GPU graph."""
    if not test.is_gpu_available():
      # Report the test as skipped instead of silently passing.
      self.skipTest("Test requires a GPU.")
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variable_v1.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1})
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variable_v1.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2})
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(543.21, self.evaluate(v0_2))
      # The restore must replace the freshly initialized value with the
      # saved one; without it this graph would test nothing.
      save.restore(sess, save_path)
      self.assertAllClose(123.45, self.evaluate(v0_2))

  def testSharedServerOnGPU(self):
    """Sharded save of a GPU-placed variable, restored into a new GPU graph.

    NOTE(review): "Shared" in the name presumably means "Sharded"; the name
    is kept unchanged so existing test selectors keep working.
    """
    if not test.is_gpu_available():
      # Report the test as skipped instead of silently passing.
      self.skipTest("Test requires a GPU.")
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variable_v1.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variable_v1.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(543.21, self.evaluate(v0_2))
      # The restore must replace the freshly initialized value with the
      # value from the sharded checkpoint.
      save.restore(sess, save_path)
      self.assertAllClose(123.45, self.evaluate(v0_2))

  def testVariables(self):
    """A Saver built with no var_list saves and restores the whole graph.

    The CheckpointedOp table is restored as well as the two variables.
    """
    save_path = os.path.join(self.get_temp_dir(), "variables")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      # `one` and `twos` are unnamed; they are only referenced through the
      # default Saver below.
      one = variable_v1.VariableV1(1.0)
      twos = variable_v1.VariableV1([2.0, 2.0, 2.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      v2.insert("k1", 3.0).run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      # Created in the same order as above, so the unnamed variables
      # presumably receive the same default names as in the saved graph.
      one = variable_v1.VariableV1(0.0)
      twos = variable_v1.VariableV1([0.0, 0.0, 0.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(3.0, self.evaluate(v2.values()))

  def testVarListShouldBeEmptyInDeferredBuild(self):
    """Passing a var_list together with defer_build=True is rejected."""
    with ops_lib.Graph().as_default():
      var_list = [variable_v1.VariableV1(1.0)]
      with self.assertRaisesRegex(ValueError, "defer_build"):
        saver_module.Saver(var_list, defer_build=True)

  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
    """save() on a deferred-build Saver fails until build() is called."""
    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
    with ops_lib.Graph().as_default(), session.Session() as sess:
      variable_v1.VariableV1(1.0)
      deferred_saver = saver_module.Saver(defer_build=True)
      with self.assertRaisesRegex(RuntimeError, "build"):
        deferred_saver.save(sess, save_path)

  def testDeferredBuild(self):
    """A deferred-build Saver also covers variables created after it.

    The statement order in the first graph is the point of the test: `twos`
    is created after the Saver constructor but before `build()`.
    """
    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variable_v1.VariableV1(1.0)
      save = saver_module.Saver(defer_build=True)
      # if build is not deferred, saver cannot save the `twos`.
      twos = variable_v1.VariableV1([2.0, 2.0, 2.0])
      init = variables.global_variables_initializer()
      save.build()
      init.run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variable_v1.VariableV1(0.0)
      twos = variable_v1.VariableV1([0.0, 0.0, 0.0])
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))

  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testReshape(self):
    """Restoring into a differently shaped variable requires reshape=True.

    A 2x3 value is saved and restored into a 3x2 variable. This fails by
    default and succeeds with `reshape=True`, which lays the saved elements
    out in row-major order.
    """
    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variable_v1.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      save.save(sess, save_path)

    # Error when restoring with default reshape=False
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variable_v1.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver()
      with self.assertRaisesRegex(
          errors_impl.InvalidArgumentError,
          "Assign requires shapes of both tensors to match."):
        save.restore(sess, save_path)

    # Restored to new shape with reshape=True
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variable_v1.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver(reshape=True)
      save.restore(sess, save_path)
      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                          self.evaluate(var))

  @test_util.run_in_graph_and_eager_modes
  def testSaveWithGlobalStep(self, pad_step_number=False):
    """save() appends the global step to the returned checkpoint path.

    Args:
      pad_step_number: Forwarded to the Saver. When True, the step suffix is
        expected to be zero-padded to eight digits.
    """
    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
    global_step_int = 5
    if pad_step_number:
      step_suffix = "{:08d}".format(global_step_int)
    else:
      step_suffix = "%d" % global_step_int
    expected_save_path = "%s-%s" % (save_path, step_suffix)

    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    for use_tensor in (True, False):
      with self.session(graph=ops_lib.Graph()):
        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
        saver = saver_module.Saver({var._shared_name: var},
                                   pad_step_number=pad_step_number)
        if context.executing_eagerly():
          sess = None
        else:
          self.evaluate(var.initializer)
          sess = ops_lib.get_default_session()
        # The step may be passed either as a tensor or as a Python int.
        global_step = (constant_op.constant(global_step_int) if use_tensor
                       else global_step_int)
        val = saver.save(sess, save_path, global_step=global_step)
        self.assertEqual(expected_save_path, val)

  def testSaveWithGlobalStepWithPadding(self):
    """Runs testSaveWithGlobalStep with a zero-padded step suffix."""
    self.testSaveWithGlobalStep(pad_step_number=True)

  def testSaveToNonexistingPath(self):
    """Saving where the parent dir is missing either works or fails cleanly.

    Three prefixes are tried: one under a missing directory, one under two
    levels of missing directories, and one whose "parent directory" is
    actually a regular file. For each, the save/restore round trip must
    either succeed or raise the specific "Parent directory ... doesn't
    exist" ValueError.
    """
    file_io.write_string_to_file(
        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
    paths = [
        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
    ]

    for save_path in paths:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      # NOTE(review): no fresh graph is created per iteration, so every
      # iteration adds new "v0"/"v1" objects to the current default graph.
      v0 = variable_v1.VariableV1(10.0, name="v0")
      v1 = variable_v1.VariableV1(20.0, name="v1")
      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
      init_all_op = variables.global_variables_initializer()

      # In the case where the parent directory doesn't exist, whether or not the
      # save succeeds or fails is implementation dependent.  Therefore we allow
      # both cases.
      try:
        with self.cached_session() as sess:
          # Initialize all variables
          self.evaluate(init_all_op)

          # Check that the parameter nodes have been initialized.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))

          # Save the graph.
          save.save(sess, save_path)

        with self.cached_session() as sess:
          # Restore the saved values in the parameter nodes.
          save.restore(sess, save_path)
          # Check that the parameter nodes have been restored.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))
      except ValueError as exc:
        # Only this exact failure is acceptable; anything else is a bug.
        error_msg_template = "Parent directory of {} doesn't exist, can't save."
        self.assertEqual(error_msg_template.format(save_path), str(exc))

  def testSaveToURI(self):
    """Saves two variables to a `file://` URI prefix."""
    # ParseURI functions don't work on Windows yet.
    # TODO(jhseu): Remove this check when it works.
    if os.name == "nt":
      self.skipTest("Local URI support doesn't work on Windows")
    uri_path = "file://" + os.path.join(self.get_temp_dir(), "uri")

    # Two parameter nodes plus Save/Restore ops for them.
    v0 = variable_v1.VariableV1(10.0, name="v0")
    v1 = variable_v1.VariableV1(20.0, name="v1")
    saver = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
    init_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      self.evaluate(init_op)
      # The initializers must have run before saving.
      for expected, var in ((10.0, v0), (20.0, v1)):
        self.assertEqual(expected, self.evaluate(var))
      saver.save(sess, uri_path)

  def testSaveRestoreAndValidateVariableDtype(self):
    """Restoring into a variable of a different dtype raises an error."""
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
    var_types = (variables.Variable, resource_variable_ops.ResourceVariable)

    for make_var in var_types:
      # Save a float32 variable under the key "v0".
      with self.session(graph=ops_lib.Graph()) as sess:
        float_var = make_var(10.0, name="v0", dtype=dtypes.float32)
        if not context.executing_eagerly():
          self.evaluate([variables.global_variables_initializer()])
        saver_module.Saver({"v0": float_var}).save(sess, save_path)

      # Restoring the same key into an int32 variable must be rejected.
      with self.session(graph=ops_lib.Graph()) as sess:
        int_var = make_var(1, name="v0", dtype=dtypes.int32)
        saver = saver_module.Saver({"v0": int_var})
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "original dtype"):
          saver.restore(sess, save_path)

  # Test restoring large tensors (triggers a thread pool)
  def testRestoreLargeTensors(self):
    """Round-trips a mix of small and large resource variables."""
    save_dir = self.get_temp_dir()
    # Use a prefix *inside* the temp dir. Passing the directory itself as the
    # prefix would place "<tmp>.meta" and the "checkpoint" state file in the
    # temp dir's parent, outside the test's sandbox.
    save_path = os.path.join(save_dir, "large_tensors")

    def _model():
      small_v = [variable_scope.get_variable(
          "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)]
      large_v = [variable_scope.get_variable(
          "large%d" % i, shape=[32000, 1000], use_resource=True)
                 for i in range(3)]
      return small_v + large_v

    save_graph = ops_lib.Graph()
    with save_graph.as_default(), self.session(graph=save_graph) as sess:
      orig_vars = _model()
      # Initialize exactly once; the three 32000x1000 variables make a
      # redundant second initialization pass expensive.
      self.evaluate(variables.global_variables_initializer())
      save = saver_module.Saver(max_to_keep=1)
      save.save(sess, save_path)
      orig_vals = self.evaluate(orig_vars)

    restore_graph = ops_lib.Graph()
    with restore_graph.as_default(), self.session(
        graph=restore_graph) as sess:
      restored_vars = _model()
      save = saver_module.Saver(max_to_keep=1)
      save.restore(sess, save_path)
      restored_vals = self.evaluate(restored_vars)

    for orig, restored in zip(orig_vals, restored_vals):
      self.assertAllEqual(orig, restored)

  def test_metrics_save_restore(self):
    """Checks checkpoint metrics reported around Saver.save / Saver.restore."""
    api_label = saver_module._SAVER_LABEL

    def _parse_histogram(get_durations):
      # `get_durations` returns a serialized HistogramProto for `api_label`.
      proto = summary_pb2.HistogramProto()
      proto.ParseFromString(get_durations(api_label=api_label))
      return proto

    def _write_count():
      return _parse_histogram(metrics.GetCheckpointWriteDurations).num

    def _read_count():
      return _parse_histogram(metrics.GetCheckpointReadDurations).num

    def _time_saved():
      return metrics.GetTrainingTimeSaved(api_label=api_label)

    save_path = os.path.join(self.get_temp_dir(), "metrics_save_restore")
    # Baselines captured before this test saves or restores anything.
    time_start = _time_saved()
    writes_start = _write_count()
    reads_start = _read_count()

    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = resource_variable_ops.ResourceVariable(10.0, name="v0")
      v1 = resource_variable_ops.ResourceVariable(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      if not context.executing_eagerly():
        self.evaluate([variables.global_variables_initializer()])

      saver = saver_module.Saver(
          {"v0": v0, "v1": v1, "v2": v2.saveable}, restore_sequentially=True)
      ckpt_prefix = saver.save(sess, save_path)
      filesize = saver_module._get_checkpoint_size(ckpt_prefix)
      size_count_after_save = metrics.GetCheckpointSize(
          api_label=api_label, filesize=filesize)

      # One save: one more write sample, and training time saved grows.
      self.assertEqual(_write_count(), writes_start + 1)
      time_after_one_save = _time_saved()
      self.assertGreater(time_after_one_save, time_start)

    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = resource_variable_ops.ResourceVariable(-1.0, name="v0")
      v1 = resource_variable_ops.ResourceVariable(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      saver = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
      saver.restore(sess, save_path)

      # A restore adds a read sample only; training time saved is unchanged.
      self.assertEqual(_write_count(), writes_start + 1)
      self.assertEqual(_read_count(), reads_start + 1)
      self.assertEqual(_time_saved(), time_after_one_save)

      saver.save(sess, save_path)

      # A second save adds a write sample and more training time saved.
      self.assertEqual(_write_count(), writes_start + 2)
      self.assertEqual(_read_count(), reads_start + 1)
      self.assertGreater(_time_saved(), time_after_one_save)
      self.assertEqual(
          metrics.GetCheckpointSize(api_label=api_label, filesize=filesize),
          size_count_after_save + 1)


class SaveRestoreShardedTest(test.TestCase):
  """Tests for sharded Savers; subclasses override `_WRITE_VERSION`."""

  _WRITE_VERSION = saver_pb2.SaverDef.V1

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns `dirname` under the test temp dir."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _two_cpu_session(self):
    """Returns a new Session whose config exposes two CPU devices."""
    return session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2}))

  def testBasics(self):
    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
    # Compare the enum value with `==`: `is` on ints only happened to work
    # because of CPython's small-int cache.
    is_v1 = self._WRITE_VERSION == saver_pb2.SaverDef.V1

    # Build a graph with 2 parameter nodes on different devices.
    with self._two_cpu_session() as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variable_v1.VariableV1(10, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variable_v1.VariableV1(20, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k1", 30.0).run()
      t1.insert("k2", 40.0).run()
      val = save.save(sess, save_path)
      if is_v1:
        self.assertEqual(save_path + "-?????-of-00002", val)
      else:
        self.assertEqual(save_path, val)
      meta_graph_filename = checkpoint_management.meta_graph_filename(val)
      self.assertEqual(save_path + ".meta", meta_graph_filename)

    if is_v1:
      # Restore different ops from shard 0 of the saved files.
      with self._two_cpu_session() as sess:
        with sess.graph.device("/cpu:0"):
          v0 = variable_v1.VariableV1(111, name="v0")
          t0 = saver_test_utils.CheckpointedOp(name="t0")
        save = saver_module.Saver(
            {
                "v0": v0,
                "t0": t0.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t0.insert("k11", 33.0).run()
        self.assertEqual(111, self.evaluate(v0))
        self.assertEqual(b"k11", self.evaluate(t0.keys()))
        self.assertEqual(33.0, self.evaluate(t0.values()))
        save.restore(sess, save_path + "-00000-of-00002")
        self.assertEqual(10, self.evaluate(v0))
        self.assertEqual(b"k1", self.evaluate(t0.keys()))
        self.assertEqual(30.0, self.evaluate(t0.values()))

      # Restore different ops from shard 1 of the saved files.
      with self._two_cpu_session() as sess:
        with sess.graph.device("/cpu:0"):
          v1 = variable_v1.VariableV1(222, name="v1")
          t1 = saver_test_utils.CheckpointedOp(name="t1")
        save = saver_module.Saver(
            {
                "v1": v1,
                "t1": t1.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t1.insert("k22", 44.0).run()
        self.assertEqual(222, self.evaluate(v1))
        self.assertEqual(b"k22", self.evaluate(t1.keys()))
        self.assertEqual(44.0, self.evaluate(t1.values()))
        save.restore(sess, save_path + "-00001-of-00002")
        self.assertEqual(20, self.evaluate(v1))
        self.assertEqual(b"k2", self.evaluate(t1.keys()))
        self.assertEqual(40.0, self.evaluate(t1.values()))

    # Now try a restore with the sharded filename.
    with self._two_cpu_session() as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variable_v1.VariableV1(111, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variable_v1.VariableV1(222, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k11", 33.0).run()
      t1.insert("k22", 44.0).run()
      self.assertEqual(111, self.evaluate(v0))
      self.assertEqual(222, self.evaluate(v1))
      self.assertEqual(b"k11", self.evaluate(t0.keys()))
      self.assertEqual(33.0, self.evaluate(t0.values()))
      self.assertEqual(b"k22", self.evaluate(t1.keys()))
      self.assertEqual(44.0, self.evaluate(t1.values()))
      if is_v1:
        save.restore(sess, save_path + "-?????-of-?????")
      else:
        save.restore(sess, save_path)
      self.assertEqual(10, self.evaluate(v0))
      self.assertEqual(20, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(t0.keys()))
      self.assertEqual(30.0, self.evaluate(t0.values()))
      self.assertEqual(b"k2", self.evaluate(t1.keys()))
      self.assertEqual(40.0, self.evaluate(t1.values()))

    if is_v1:
      expected_latest = save_path + "-?????-of-00002"
    else:
      expected_latest = save_path
    self.assertEqual(
        checkpoint_management.latest_checkpoint(self.get_temp_dir()),
        expected_latest)

  def testSaverDef(self):
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default(), self.cached_session():
      v0 = variable_v1.VariableV1(123, name="v0")
      save = saver_module.Saver({"v0": v0}, sharded=True)
      sd = save.as_saver_def()
      self.assertTrue(sd.sharded)

  def _testPartitionedVariables(self, use_resource):
    """Saves/restores a [10, 3] variable across different partitionings."""
    var_full_shape = [10, 3]
    # Allows save/restore mechanism to work w/ different slicings.
    var_name = "my_var"
    saved_dir = self._get_test_dir("partitioned_variables")
    saved_path = os.path.join(saved_dir, "ckpt")

    call_saver_with_dict = False  # updated by test loop below

    def _save(partitioner=None):
      # train.Saver is V1 only API.
      with ops_lib.Graph().as_default(), self.session() as sess:
        # Calls .eval() to return the ndarray that makes up the full variable.
        rnd = random_ops.random_uniform(var_full_shape).eval()

        if partitioner:
          vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=rnd,
                  partitioner=partitioner,
                  use_resource=use_resource)
          ]
        else:
          if use_resource:
            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
          else:
            vs = [variable_v1.VariableV1(rnd, name=var_name)]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({var_name: vs[0]})
        else:
          saver = saver_module.Saver(vs)
        actual_path = saver.save(sess, saved_path)
        self.assertEqual(saved_path, actual_path)

        return rnd

    def _restore(partitioner=None):
      # train.Saver is V1 only API.
      with ops_lib.Graph().as_default(), self.session() as sess:
        if partitioner:
          new_vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=array_ops.zeros(var_full_shape),
                  partitioner=partitioner)
          ]
        else:
          new_vs = [
              variable_v1.VariableV1(
                  array_ops.zeros(
                      shape=var_full_shape),  # != original contents.
                  name=var_name)
          ]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({
              var_name: new_vs[0]
          })
        else:
          saver = saver_module.Saver(new_vs)
        saver.restore(sess, saved_path)

        if partitioner:
          return new_vs[0].as_tensor().eval()
        else:
          return new_vs[0].eval()

    # A tuple states the iteration order explicitly (a set literal only
    # happens to yield False before True).
    for call_saver_with_dict in (False, True):
      # Save PartitionedVariable and restore into full variable.
      saved_full = _save(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      restored_full = _restore()
      self.assertAllEqual(saved_full, restored_full)

      # Restores into the same number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      self.assertAllEqual(saved_full, restored_full)

      # Restores into a different number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

      # Now, saves a full variable and restores PartitionedVariable.
      saved_full = _save()
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

  def testPartitionedVariable(self):
    self._testPartitionedVariables(use_resource=False)

  def testPartitionedResourceVariable(self):
    self._testPartitionedVariables(use_resource=True)


class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
  """Reruns the sharded tests with V2 checkpoints and adds iterator tests."""

  _WRITE_VERSION = saver_pb2.SaverDef.V2

  def _make_iterator_saver(self, sess, sharded):
    """Builds two initialized dataset iterators and a Saver covering both.

    The iterators range over 10 and 20 elements and are placed on "/cpu:0"
    and "/cpu:1" of `sess.graph`. This is called from inside a `with sess:`
    block, so `self.evaluate` runs the initializers in `sess`.

    Args:
      sess: Session whose graph receives the new ops.
      sharded: Forwarded to `saver_module.Saver`.

    Returns:
      A `(get_next0, get_next1, saver)` tuple.
    """
    with sess.graph.device("/cpu:0"):
      ds0 = dataset_ops.Dataset.range(10)
      it0 = dataset_ops.make_initializable_iterator(ds0)
      get_next0 = it0.get_next()
    saveable0 = iterator_ops._IteratorSaveable(
        it0._iterator_resource, name="saveable_it0")

    with sess.graph.device("/cpu:1"):
      ds1 = dataset_ops.Dataset.range(20)
      it1 = dataset_ops.make_initializable_iterator(ds1)
      get_next1 = it1.get_next()
    saveable1 = iterator_ops._IteratorSaveable(
        it1._iterator_resource, name="saveable_it1")

    saver = saver_module.Saver(
        {
            "it0": saveable0,
            "it1": saveable1
        },
        write_version=self._WRITE_VERSION,
        sharded=sharded)
    self.evaluate(it0.initializer)
    self.evaluate(it1.initializer)
    return get_next0, get_next1, saver

  def _save_iterators(self, save_path):
    """Advances both iterators, saves them sharded, and checks the output."""
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      get_next0, get_next1, saver = self._make_iterator_saver(
          sess, sharded=True)
      self.assertEqual(0, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next0))
      self.assertEqual(0, self.evaluate(get_next1))
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
      # The saveables live on two devices and the saver is sharded, so two
      # data files are expected.
      data_files = glob.glob(save_path + ".data*")
      self.assertEqual(2, len(data_files))

  def _restore_iterators(self, save_path, sharded):
    """Restores into fresh iterators and checks they resume where saved."""
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      get_next0, get_next1, saver = self._make_iterator_saver(
          sess, sharded=sharded)
      saver.restore(sess, save_path)
      self.assertEqual(2, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next1))

  def testIterators(self):
    save_path = os.path.join(self.get_temp_dir(), "sharded_iterators")
    self._save_iterators(save_path)
    self._restore_iterators(save_path, sharded=True)

  def testIteratorsUnshardedRestore(self):
    save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators")
    self._save_iterators(save_path)
    # The sharded checkpoint is restored through a non-sharded Saver.
    self._restore_iterators(save_path, sharded=False)


class MaxToKeepTest(test.TestCase):

  def _get_test_dir(self, dirname):
    """Returns `<temp dir>/<dirname>`, creating the directory first."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    """Asserts the checkpoint state recorded in `save_dir`.

    Args:
      model_checkpoint_path: Expected most recent checkpoint prefix.
      all_model_checkpoint_paths: Expected list of retained checkpoint
        prefixes, oldest first.
      save_dir: Directory whose checkpoint state is read.
    """
    checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
    # A missing state yields None; fail with a clear message rather than an
    # AttributeError on the attribute accesses below.
    self.assertIsNotNone(
        checkpoint_state, "No checkpoint state found in %s" % save_dir)
    self.assertEqual(checkpoint_state.model_checkpoint_path,
                     model_checkpoint_path)
    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def testMaxToKeepEager(self):
    """Checks max_to_keep eviction and checkpoint-state bookkeeping eagerly."""
    with context.eager_mode():
      save_dir = self._get_test_dir("max_to_keep_eager")

      v = variable_v1.VariableV1(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      # Inside eager_mode() an `if not context.executing_eagerly()` guard is
      # always False, so this check used to be dead code; run it directly.
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
          save_dir=save_dir)

      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
          save_dir=save_dir)

      s3 = save.save(None, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
          save_dir=save_dir)

      # Create a second helper, identical to the first.
      save2 = saver_module.Saver({"v": v}, max_to_keep=2)
      save2.set_last_checkpoints(save.last_checkpoints)

      # Exercise the first helper.

      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the second helper, seeded with the history [s2, s3].
      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save2.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      # Deleted by the first helper.
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

  def testNonSharded(self):
    """Checks max_to_keep eviction across three Savers sharing one config."""
    save_dir = self._get_test_dir("max_to_keep_non_sharded")
    ckpt_exists = checkpoint_management.checkpoint_exists
    meta_path = checkpoint_management.meta_graph_filename

    def prefix(name):
      return os.path.join(save_dir, name)

    def expect_files(ckpt, present):
      # Checks both the checkpoint and its companion meta graph file.
      check = self.assertTrue if present else self.assertFalse
      check(ckpt_exists(ckpt))
      check(ckpt_exists(meta_path(ckpt)))

    def expect_state(latest, history):
      self.assertCheckpointState(
          model_checkpoint_path=latest,
          all_model_checkpoint_paths=history,
          save_dir=save_dir)

    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default(), self.cached_session() as sess:
      v = variable_v1.VariableV1(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(sess, prefix("s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(ckpt_exists(s1))
      expect_state(s1, [s1])

      s2 = save.save(sess, prefix("s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(ckpt_exists(s1))
      self.assertTrue(ckpt_exists(s2))
      expect_state(s2, [s1, s2])

      s3 = save.save(sess, prefix("s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(ckpt_exists(s1))
      self.assertTrue(ckpt_exists(s2))
      self.assertTrue(ckpt_exists(s3))
      expect_state(s3, [s2, s3])

      # save2 mirrors save, including its checkpoint history.
      save2 = saver_module.Saver(saver_def=save.as_saver_def())
      save2.set_last_checkpoints(save.last_checkpoints)
      # save3 shares the configuration but knows no earlier checkpoints.
      save3 = saver_module.Saver(saver_def=save.as_saver_def())

      # First helper: re-adding s2 drops the old entry, then appends it.
      s2 = save.save(sess, prefix("s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      expect_files(s1, False)
      expect_files(s3, True)
      expect_files(s2, True)
      expect_state(s2, [s3, s2])

      # First helper: adding s1 evicts s3, the oldest entry.
      s1 = save.save(sess, prefix("s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      expect_files(s3, False)
      expect_files(s2, True)
      expect_files(s1, True)
      expect_state(s1, [s2, s1])

      # Second helper: re-adding s2 drops the old entry, then appends it.
      s2 = save2.save(sess, prefix("s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      expect_files(s1, True)  # Created by the first helper.
      expect_files(s3, False)  # Deleted by the first helper.
      expect_files(s2, True)
      expect_state(s2, [s3, s2])

      # Second helper: adding s1 evicts s3, the oldest entry.
      s1 = save2.save(sess, prefix("s1"))
      self.assertEqual([s2, s1], save2.last_checkpoints)
      expect_files(s3, False)
      expect_files(s2, True)
      expect_files(s1, True)
      expect_state(s1, [s2, s1])

      # Third helper: adding s2 with no knowledge of earlier checkpoints.
      s2 = save3.save(sess, prefix("s2"))
      self.assertEqual([s2], save3.last_checkpoints)
      expect_files(s1, True)  # Created by the first helper.
      expect_files(s3, False)  # Deleted by the first helper.
      expect_files(s2, True)
      # s1 is on disk, but save3 doesn't know about it, so it is absent
      # from the checkpoint state.
      expect_state(s2, [s2])

      # Third helper: adding s1 (it would not delete s3, which it never saw).
      s1 = save3.save(sess, prefix("s1"))
      self.assertEqual([s2, s1], save3.last_checkpoints)
      expect_files(s3, False)
      expect_files(s2, True)
      expect_files(s1, True)
      expect_state(s1, [s2, s1])

  def testSharded(self):
    """Sharded saves honor max_to_keep and remove every file of old ckpts."""
    save_dir = self._get_test_dir("max_to_keep_sharded")

    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # Place the variables on two different devices so the sharded saver has
      # two shards to write.
      with sess.graph.device("/cpu:0"):
        v0 = variable_v1.VariableV1(111, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variable_v1.VariableV1(222, name="v1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1
          }, sharded=True, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], save.last_checkpoints)

      def assert_checkpoint_kept(prefix):
        """Asserts the expected data files and the meta graph of `prefix`."""
        # `_write_version` holds a SaverDef enum value (an int), so compare by
        # value; `is` only worked thanks to CPython's small-int caching.
        if save._write_version == saver_pb2.SaverDef.V1:
          self.assertEqual(2, len(gfile.Glob(prefix)))
        else:
          self.assertEqual(4, len(gfile.Glob(prefix + "*")))
        self.assertTrue(
            gfile.Exists(checkpoint_management.meta_graph_filename(prefix)))

      def assert_checkpoint_deleted(prefix):
        """Asserts that no file starting with `prefix` and no meta graph remain."""
        self.assertEqual(0, len(gfile.Glob(prefix + "*")))
        self.assertFalse(
            gfile.Exists(checkpoint_management.meta_graph_filename(prefix)))

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      assert_checkpoint_kept(s1)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      assert_checkpoint_kept(s1)
      assert_checkpoint_kept(s2)

      # With max_to_keep=2, saving s3 evicts the oldest checkpoint, s1.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      assert_checkpoint_deleted(s1)
      assert_checkpoint_kept(s2)
      assert_checkpoint_kept(s3)

  def testNoMaxToKeep(self):
    """With max_to_keep None or 0, last_checkpoints stays empty."""
    dir_for_none = self._get_test_dir("no_max_to_keep")
    dir_for_zero = self._get_test_dir("max_to_keep_0")

    with self.cached_session() as sess:
      var = variable_v1.VariableV1(10.0, name="v")
      self.evaluate(variables.global_variables_initializer())

      # Exercise max_to_keep=None first, then max_to_keep=0. Each setting
      # gets its own Saver and its own directory.
      for limit, ckpt_dir in ((None, dir_for_none), (0, dir_for_zero)):
        saver = saver_module.Saver({"v": var}, max_to_keep=limit)
        self.assertEqual([], saver.last_checkpoints)
        for basename in ("s1", "s2"):
          written = saver.save(sess, os.path.join(ckpt_dir, basename))
          # Nothing is tracked, but the checkpoint itself is written.
          self.assertEqual([], saver.last_checkpoints)
          self.assertTrue(checkpoint_management.checkpoint_exists(written))

  def testNoMetaGraph(self):
    """write_meta_graph=False writes the checkpoint but no meta graph file."""
    ckpt_dir = self._get_test_dir("no_meta_graph")

    with self.cached_session() as sess:
      var = variable_v1.VariableV1(10.0, name="v")
      saver = saver_module.Saver({"v": var})
      self.evaluate(variables.global_variables_initializer())

      prefix = saver.save(
          sess, os.path.join(ckpt_dir, "s1"), write_meta_graph=False)
      meta_path = checkpoint_management.meta_graph_filename(prefix)
      self.assertTrue(checkpoint_management.checkpoint_exists(prefix))
      self.assertFalse(gfile.Exists(meta_path))


class RecoverLastCheckpointsTest(test.TestCase):
  """Tests for `Saver.recover_last_checkpoints`."""

  def _get_test_dir(self, dirname):
    """Returns `dirname` under the test temp dir, creating it if needed."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    """Asserts the checkpoint state recorded in `save_dir`."""
    state = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(state.model_checkpoint_path, model_checkpoint_path)
    self.assertEqual(state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def test_recover_last_checkpoints(self):
    with context.eager_mode():
      save_dir = self._get_test_dir("recover_last_checkpoints")

      var = variable_v1.VariableV1(10.0, name="v")
      saver = saver_module.Saver({"v": var}, max_to_keep=10)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], saver.last_checkpoints)

      # Write three checkpoints. All stay tracked because max_to_keep=10.
      paths = [
          saver.save(None, os.path.join(save_dir, "ckpt-%d" % step))
          for step in (1, 2, 3)
      ]
      s1, s2, s3 = paths
      self.assertEqual(paths, saver.last_checkpoints)
      for path in paths:
        self.assertTrue(checkpoint_management.checkpoint_exists(path))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=paths,
          save_dir=save_dir)

      # A fresh saver only knows about existing checkpoints once told.
      saver2 = saver_module.Saver({"v": var}, max_to_keep=10)
      self.assertEqual([], saver2.last_checkpoints)
      saver2.recover_last_checkpoints(list(paths))
      self.assertEqual(paths, saver2.last_checkpoints)

      # Delete every file belonging to s1 ...
      for fname in gfile.Glob(f"{s1}*"):
        gfile.Remove(fname)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))

      # ... and a new saver recovering the same list omits it.
      saver3 = saver_module.Saver({"v": var}, max_to_keep=10)
      self.assertEqual([], saver3.last_checkpoints)
      saver3.recover_last_checkpoints(list(paths))
      self.assertEqual([s2, s3], saver3.last_checkpoints)
      s4 = saver3.save(None, os.path.join(save_dir, "ckpt-4"))
      self.assertCheckpointState(
          model_checkpoint_path=s4,
          all_model_checkpoint_paths=[s2, s3, s4],
          save_dir=save_dir)


class KeepCheckpointEveryNHoursTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(saver_module, "time")
  def testNonSharded(self, mock_time):
    """Checkpoints spaced by keep_checkpoint_every_n_hours outlive max_to_keep.

    `saver_module.time` is mocked, so the saver's view of the current time is
    whatever `mock_time.time.return_value` holds. The test never sleeps, and
    its checks do not depend on wall-clock timing.
    """
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.cached_session() as sess:
      v = variable_v1.VariableV1([10.0], name="v")
      # Initialize the variable before any checkpoint is written.
      self.evaluate(variables.global_variables_initializer())
      # Create a saver that will keep the last 2 checkpoints plus one every 0.7
      # seconds. Pin the mocked clock before constructing the saver.
      start_time = time.time()
      mock_time.time.return_value = start_time
      save = saver_module.Saver(
          {
              "v": v
          }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
      self.assertEqual([], save.last_checkpoints)

      # Advance the mocked clock by 1 second so s1 is old enough to keep.
      mock_time.time.return_value = start_time + 1.0
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)

      # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save()
      # would normally delete s1, because max_to_keep is 2. However, s1 was
      # written 1s (> 0.7s) after the saver was created, so it must be kept.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)

      # s1 left last_checkpoints but must still be on disk. Time is mocked,
      # so this can be checked right away without flakiness.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))

      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
      # call to Save() will delete s2, because max_to_keep is 2 and the old s1
      # was already kept. s2 is very close in time to s1, so it gets deleted.
      s4 = save.save(sess, os.path.join(save_dir, "s4"))
      self.assertEqual([s3, s4], save.last_checkpoints)

      # Check that s1 is still here, but s2 is gone.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s4))


class SaveRestoreWithVariableNameMap(test.TestCase):
  """Saving and restoring through an explicit checkpoint-key -> variable map."""

  def _testNonReshape(self, variable_op):
    """Round-trips two scalars through a checkpoint under remapped keys.

    Args:
      variable_op: constructor used to build the variables, e.g.
        `variables.Variable` or `resource_variable_ops.ResourceVariable`.
    """
    ckpt_path = os.path.join(self.get_temp_dir(), "non_reshape")

    def make_remapped_saver(first, second):
      # The checkpoint keys deliberately differ from the variables' names.
      return saver_module.Saver({
          "save_prefix/v0": first,
          "save_prefix/v1": second
      })

    def assert_uninitialized(*tensors):
      # Only checked when not executing eagerly.
      if context.executing_eagerly():
        return
      for tensor in tensors:
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(tensor)

    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      saver = make_remapped_saver(v0, v1)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      # save() returns exactly the requested prefix.
      written = saver.save(sess, ckpt_path)
      self.assertIsInstance(written, str)
      self.assertEqual(ckpt_path, written)

      # The variables' own names were not used as checkpoint keys.
      plain_saver = saver_module.Saver({"v0": v0, "v1": v1})
      with self.assertRaisesOpError("not found in checkpoint"):
        plain_saver.restore(sess, ckpt_path)

    # Fresh variables with the same names restore through the remapped keys.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")
      assert_uninitialized(v0, v1)

      make_remapped_saver(v0, v1).restore(sess, ckpt_path)
      if not context.executing_eagerly():
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))

    # Variables with different names restore through the same keys.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="restore_prefix/v0")
      v1 = variable_op(-1.0, name="restore_prefix/v1")
      assert_uninitialized(v0, v1)

      make_remapped_saver(v0, v1).restore(sess, ckpt_path)
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

  @test_util.run_in_graph_and_eager_modes
  def testNonReshapeResourceVariable(self):
    self._testNonReshape(resource_variable_ops.ResourceVariable)

  def testNonReshapeVariable(self):
    self._testNonReshape(variables.Variable)


class MetaGraphTest(test.TestCase):

  def _get_test_dir(self, dirname):
    """Creates `dirname` under the per-test temp dir and returns its path."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  @test_util.run_v1_only(
      "Queue-based input pipelines have been replaced by `tf.data` "
      "and not supported in V2.")
  def testAddCollectionDef(self):
    """Round-trips a graph carrying many collection kinds via a meta graph.

    Exports a graph with user-added collections: int, float, string,
    variable, QueueRunner, and a user-defined proto stored as str, bytes and
    `Any`. It then imports the graph into a fresh graph, re-exports it, and
    checks that the two MetaGraphDefs are equal.
    """
    test_dir = self._get_test_dir("good_collection")
    filename = os.path.join(test_dir, "metafile")
    with self.cached_session():
      # Creates a graph.
      v0 = variable_v1.VariableV1(1.0, name="v0")
      # The cond and while_loop outputs are unused. Presumably they exist only
      # so the exported graph contains control-flow state to round-trip.
      cond.cond(
          math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
          lambda: math_ops.subtract(v0, 1))
      while_loop.while_loop(lambda i: math_ops.less(i, 10),
                            lambda i: math_ops.add(i, 1), [v0])
      # An int64 counter that feeds the QueueRunner built below.
      var = variable_v1.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
      count_up_to = var.count_up_to(3)
      input_queue = data_flow_ops.FIFOQueue(
          30, dtypes.float32, shared_name="collection_queue")
      qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])
      # The initializer op is created, so it is part of the exported graph,
      # but it is never run: this test evaluates no tensors.
      variables.global_variables_initializer()
      # Creates a saver.
      save = saver_module.Saver({"v0": v0})
      # Adds a set of collections.
      ops_lib.add_to_collection("int_collection", 3)
      ops_lib.add_to_collection("float_collection", 3.5)
      ops_lib.add_to_collection("string_collection", "hello")
      ops_lib.add_to_collection("variable_collection", v0)
      # Add QueueRunners.
      queue_runner_impl.add_queue_runner(qr)
      # Adds user_defined proto in three formats: string, bytes and Any.
      queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
      ops_lib.add_to_collection("user_defined_string_collection",
                                str(queue_runner))
      ops_lib.add_to_collection("user_defined_bytes_collection",
                                queue_runner.SerializeToString())
      any_buf = Any()
      any_buf.Pack(queue_runner)
      ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      # Generates MetaGraphDef.
      meta_graph_def = save.export_meta_graph(filename)
      self.assertTrue(meta_graph_def.HasField("saver_def"))
      self.assertTrue(meta_graph_def.HasField("graph_def"))
      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
                          "")
      collection_def = meta_graph_def.collection_def
      # 8 collections are added explicitly above. The remaining ones are
      # presumably created implicitly by the variable and control-flow ops
      # (variables, cond/while contexts) -- TODO confirm the exact set.
      self.assertEqual(len(collection_def), 12)

    with ops_lib.Graph().as_default():
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef.
      new_meta_graph_def = new_saver.export_meta_graph()
      # It should be the same as the original.

    test_util.assert_meta_graph_protos_equal(
        self, meta_graph_def, new_meta_graph_def)

  def testAddCollectionDefFails(self):
    """Collections that cannot be serialized are skipped, not exported."""
    with self.cached_session():
      # A single variable is enough to build a Saver.
      v0 = variable_v1.VariableV1(10.0, name="v0")
      saver = saver_module.Saver({"v0": v0})
      target = meta_graph_pb2.MetaGraphDef()

      # A collection keyed by the Saver object itself (an unsupported key)
      # is not added.
      ops_lib.add_to_collection(saver, 3)
      saver._add_collection_def(target, saver)
      self.assertEqual(len(target.collection_def), 0)

      # A collection whose items do not all share the expected type (int and
      # float mixed) is not added either.
      ops_lib.add_to_collection("int_collection", 3)
      ops_lib.add_to_collection("int_collection", 3.5)
      saver._add_collection_def(target, "int_collection")
      self.assertEqual(len(target.collection_def), 0)

  def _testMultiSaverCollectionSave(self, test_dir):
    """Saves two savers' checkpoints and a meta graph that lists both savers."""
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_v1.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                                  name="v0")
      v1 = variable_v1.VariableV1(11.0, name="v1")
      # Two savers, each owning one variable, both recorded in "savers".
      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
      for saver in (saver0, saver1):
        ops_lib.add_to_collection("savers", saver)
      self.evaluate(variables.global_variables_initializer())

      # Each saver writes its own checkpoint.
      saver0.save(sess, saver0_ckpt)
      saver1.save(sess, saver1_ckpt)

      # One export through the module, and one through each saver.
      meta_graph_def = saver_module.export_meta_graph(filename)
      meta_graph_def0 = saver0.export_meta_graph()
      meta_graph_def1 = saver1.export_meta_graph()

      # Only the saver-bound exports carry a saver_def.
      self.assertFalse(meta_graph_def.HasField("saver_def"))
      self.assertTrue(meta_graph_def0.HasField("saver_def"))
      self.assertTrue(meta_graph_def1.HasField("saver_def"))

      # The module-level export and saver0's export both store the two
      # savers as a bytes_list collection.
      for exported in (meta_graph_def, meta_graph_def0):
        savers_def = exported.collection_def["savers"]
        kind = savers_def.WhichOneof("kind")
        self.assertEqual(kind, "bytes_list")
        self.assertEqual(2, len(getattr(savers_def, kind).value))

  def _testMultiSaverCollectionRestore(self, test_dir):
    """Restores each imported saver; each one restores only its variable."""
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      saver_module.import_meta_graph(filename)
      # Both savers come back through the "savers" collection.
      savers = ops_lib.get_collection("savers")
      self.assertEqual(2, len(savers))
      new_saver0, new_saver1 = savers
      v0 = sess.graph.get_tensor_by_name("v0:0")
      v1 = sess.graph.get_tensor_by_name("v1:0")

      # saver0 restores v0 and leaves v1 uninitialized.
      new_saver0.restore(sess, saver0_ckpt)
      self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                          self.evaluate(v0))
      self.assertEqual([3, 2], v0.get_shape())
      self.assertEqual([], v1.get_shape())
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        self.evaluate(v1)

      # saver1 then restores v1.
      new_saver1.restore(sess, saver1_ckpt)
      self.assertEqual(11.0, self.evaluate(v1))

  @test_util.run_v1_only(
      "Exporting/importing meta graphs is only supported in V1.")
  def testMultiSaverCollection(self):
    """Round-trips two savers via a meta graph's "savers" collection."""
    test_dir = self._get_test_dir("saver_collection")
    # The restore phase reads the files written by the save phase.
    for phase in (self._testMultiSaverCollectionSave,
                  self._testMultiSaverCollectionRestore):
      phase(test_dir)

  @test_util.run_v1_only(
      "Exporting/importing meta graphs is only supported in V1.")
  def testClearExtraneousSavers(self):
    """`clear_extraneous_savers=True` strips other savers from an export.

    Builds two savers over separate variables. Exporting via saver1 with
    `clear_extraneous_savers=True` must keep only one entry in the "savers"
    collection and drop saver0's nodes from the graph_def. The latter is
    checked by comparing node counts with saver0's plain export.
    """
    test_dir = self._get_test_dir("clear_extraneous_savers")
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Creates a graph.
      v0 = variable_v1.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                                  name="v0")
      v1 = variable_v1.VariableV1(11.0, name="v1")

      # Creates 2 savers.
      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
      ops_lib.add_to_collection("savers", saver0)
      ops_lib.add_to_collection("savers", saver1)
      self.evaluate(variables.global_variables_initializer())

      # Saves to different checkpoints.
      saver0.save(sess, saver0_ckpt)
      saver1.save(sess, saver1_ckpt)

      # Generates MetaGraphDef.
      meta_graph_def = saver_module.export_meta_graph(filename)
      meta_graph_def0 = saver0.export_meta_graph()
      meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)

      # Verifies that there is no saver_def in meta_graph_def.
      self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is saver_def in meta_graph_def0 and 1.
      self.assertTrue(meta_graph_def0.HasField("saver_def"))
      self.assertTrue(meta_graph_def1.HasField("saver_def"))

      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
      collection_def = meta_graph_def.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there are 2 entries in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(2, len(savers.value))

      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.
      collection_def = meta_graph_def1.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there is 1 entry in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(1, len(savers.value))

      # Verifies that saver0 graph nodes are omitted from the saver1 export.
      # NOTE: these are exact node counts. They depend on every op emitted by
      # the variable and Saver constructions above, so adding ops to this
      # test (or changing what those constructors emit) changes them.
      self.assertEqual(33, len(meta_graph_def0.graph_def.node))
      self.assertEqual(21, len(meta_graph_def1.graph_def.node))

  def testBinaryAndTextFormat(self):
    """Meta graphs round-trip in binary and text form; bad files raise IOError.

    The failure cases check the error message as well as the exception type.
    """
    test_dir = self._get_test_dir("binary_and_text")
    filename = os.path.join(test_dir, "metafile")
    # train.Saver is V1 only API.
    with ops_lib.Graph().as_default(), self.session():
      # Creates a graph.
      variable_v1.VariableV1(10.0, name="v0")
      # Exports the graph as binary format.
      saver_module.export_meta_graph(filename, as_text=False)
    with ops_lib.Graph().as_default(), self.session():
      # Imports the binary format graph.
      saver = saver_module.import_meta_graph(filename)
      self.assertIsNotNone(saver)
      # Exports the graph as text format.
      saver.export_meta_graph(filename, as_text=True)
    with ops_lib.Graph().as_default(), self.session():
      # Imports the text format graph.
      saver_module.import_meta_graph(filename)
      # Overwrites the file with a SaverDef, which import_meta_graph is
      # expected to reject.
      graph_io.write_graph(saver.as_saver_def(),
                           os.path.dirname(filename),
                           os.path.basename(filename))
    with ops_lib.Graph().as_default(), self.session():
      # Import should fail. The predicates must inspect the exception: a bare
      # string literal is always truthy and would accept any IOError.
      with self.assertRaisesWithPredicateMatch(
          IOError, lambda e: "Cannot parse file" in str(e)):
        saver_module.import_meta_graph(filename)
      # Deletes the file
      gfile.Remove(filename)
      with self.assertRaisesWithPredicateMatch(
          IOError, lambda e: "does not exist" in str(e)):
        saver_module.import_meta_graph(filename)

  @test_util.run_v1_only(
      "Exporting/importing meta graphs is only supported in V1.")
  def testSliceVariable(self):
    """A saver over a sliced variable survives a meta graph round trip."""
    test_dir = self._get_test_dir("slice_saver")
    filename = os.path.join(test_dir, "metafile")
    with self.cached_session():
      first = variable_v1.VariableV1([20.0], name="v1")
      second = variable_v1.VariableV1([20.0], name="v2")
      # Attach save-slice info to `second` that names "v1" as its full
      # variable.
      second._set_save_slice_info(
          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))

      # The checkpoint keys are distinct, so this Saver can be built.
      slice_saver = saver_module.Saver({"first": first, "second": second})
      self.evaluate(variables.global_variables_initializer())
      original_def = slice_saver.export_meta_graph(filename)

    with ops_lib.Graph().as_default():
      imported_saver = saver_module.import_meta_graph(filename)
      self.assertIsNotNone(imported_saver)
      # Re-exporting the imported graph must reproduce the original proto.
      test_util.assert_meta_graph_protos_equal(
          self, original_def, imported_saver.export_meta_graph())

  def _testGraphExtensionSave(self, test_dir):
    """Builds an inference graph, saves its variables and exports it.

    Builds into the current default graph (testGraphExtension supplies a
    fresh one). Writes the checkpoint "saver0.ckpt" and the meta graph
    "metafile" under `test_dir`, which `_testGraphExtensionRestore` reads.

    Args:
      test_dir: directory that receives the checkpoint and meta graph files.
    """
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    # Creates an inference graph.
    # Hidden 1
    images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
    with ops_lib.name_scope("hidden1"):
      weights = variable_v1.VariableV1(
          random_ops.truncated_normal([28, 128],
                                      stddev=1.0 / math.sqrt(float(28))),
          name="weights")
      # The use of cond.cond here is purely for adding test coverage of
      # the save and restore of control flow context (which doesn't make any
      # sense here from a machine learning perspective).  The typical biases is
      # a simple Variable without the conditions.
      # random.random() runs in Python at graph-construction time, so the
      # predicate's inputs are fixed within a run, but either branch may be
      # taken across runs.
      biases = variable_v1.VariableV1(
          cond.cond(
              math_ops.less(random.random(), 0.5),
              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
          name="biases")
      hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
    # Hidden 2
    with ops_lib.name_scope("hidden2"):
      weights = variable_v1.VariableV1(
          random_ops.truncated_normal([128, 32],
                                      stddev=1.0 / math.sqrt(float(128))),
          name="weights")

      # The use of while_loop.while_loop here is purely for adding test
      # coverage of the save and restore of control flow context (which
      # doesn't make any sense here from a machine learning perspective).  The
      # typical biases is a simple Variable without the conditions.
      def loop_cond(it, _):
        return it < 2

      def loop_body(it, biases):
        biases += constant_op.constant(0.1, shape=[32])
        return it + 1, biases

      _, biases = while_loop.while_loop(loop_cond, loop_body, [
          constant_op.constant(0),
          variable_v1.VariableV1(array_ops.zeros([32]))
      ])
      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
    # Linear
    with ops_lib.name_scope("softmax_linear"):
      weights = variable_v1.VariableV1(
          random_ops.truncated_normal([32, 10],
                                      stddev=1.0 / math.sqrt(float(32))),
          name="weights")
      biases = variable_v1.VariableV1(array_ops.zeros([10]), name="biases")
      logits = math_ops.matmul(hidden2, weights) + biases
      # Publish logits so the importer can find them after import.
      ops_lib.add_to_collection("logits", logits)
    init_all_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      # Initializes all the variables.
      self.evaluate(init_all_op)
      # Runs to logit.
      self.evaluate(logits)
      # Creates a saver.
      saver0 = saver_module.Saver()
      saver0.save(sess, saver0_ckpt)
      # Generates MetaGraphDef.
      saver0.export_meta_graph(filename)

  def _testGraphExtensionRestore(self, test_dir):
    """Imports the saved graph, extends it with training ops, and runs them.

    Reads "metafile" and "saver0.ckpt" written by `_testGraphExtensionSave`.
    It adds a softmax cross-entropy loss and a gradient-descent train_op, runs
    one training step, and exports the extended graph to "train_metafile".

    Args:
      test_dir: directory holding the files from `_testGraphExtensionSave`.
    """
    filename = os.path.join(test_dir, "metafile")
    train_filename = os.path.join(test_dir, "train_metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef. The result is discarded; presumably this
      # only checks that an imported graph can be re-exported.
      new_saver.export_meta_graph()
      # Restores from checkpoint.
      new_saver.restore(sess, saver0_ckpt)
      # Adds loss and train.
      # Builds one-hot labels (every label is class 0) of shape
      # [batch_size, 10] by scattering 1.0 at (row, label) index pairs.
      labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")
      batch_size = array_ops.size(labels)
      labels = array_ops.expand_dims(labels, 1)
      indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)
      concated = array_ops.concat([indices, labels], 1)
      onehot_labels = sparse_ops.sparse_to_dense(
          concated, array_ops_stack.stack([batch_size, 10]), 1.0, 0.0)
      # The logits tensor was published via the "logits" collection when the
      # graph was built.
      logits = ops_lib.get_collection("logits")[0]
      cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
          labels=onehot_labels, logits=logits, name="xentropy")
      loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")

      summary.scalar("loss", loss)
      # Creates the gradient descent optimizer with the given learning rate.
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

      # Creates train_op and records it in the "train_op" collection so the
      # exported training graph exposes it.
      train_op = optimizer.minimize(loss)
      ops_lib.add_to_collection("train_op", train_op)

      # Runs train_op.
      self.evaluate(train_op)

      # Generates MetaGraphDef.
      saver_module.export_meta_graph(train_filename)

  def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
    """Imports the training meta graph, restores weights, runs train_op once."""
    meta_path = os.path.join(test_dir, "train_metafile")
    ckpt_path = os.path.join(test_dir, "saver0.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      restorer = saver_module.import_meta_graph(meta_path)
      restorer.restore(sess, ckpt_path)
      # The imported graph includes the cond/while_loop control flow built by
      # _testGraphExtensionSave. Running train_op exercises it after import.
      self.evaluate(ops_lib.get_collection("train_op")[0])

  def testGraphExtension(self):
    """Saves a graph, extends it with training ops, then reruns the result."""
    test_dir = self._get_test_dir("graph_extension")
    phases = (
        self._testGraphExtensionSave,
        self._testGraphExtensionRestore,
        self._testRestoreFromTrainGraphWithControlContext,
    )
    # train.Saver and train.import_meta_graph are V1 only APIs.
    with ops_lib.Graph().as_default():
      for phase in phases:
        phase(test_dir)

  def _testGradientSerDes(self, graph_fn):
    """Tests that gradients can be computed after exporting and importing.

    Builds a graph, exports it, and verifies that it can be imported and the
    gradient can be built and run correctly.

    Args:
      graph_fn: takes a single float Tensor argument as input, outputs a single
        Tensor
    """
    test_dir = self._get_test_dir("nested_control_flow")
    filename = os.path.join(test_dir, "metafile")
    saver_ckpt = os.path.join(test_dir, "saver.ckpt")

    # Build the graph under test with `graph_fn`. Only tensor names are kept,
    # so the same tensors can be looked up again after the import below.
    with ops_lib.Graph().as_default():
      var = variable_v1.VariableV1(0.0)
      var_name = var.name
      output = graph_fn(var)
      output_name = output.name
      init_op = variables.global_variables_initializer()

      # Save a checkpoint and a MetaGraphDef of the graph built by `graph_fn`.
      with session.Session() as sess:
        self.evaluate(init_op)
        self.evaluate(output)
        saver = saver_module.Saver()
        saver.save(sess, saver_ckpt)
        saver.export_meta_graph(filename)

      # Build and run d(output)/d(var) in the original graph. We use this below
      # to verify that the gradients are correct with an imported MetaGraphDef.
      grad = gradients_impl.gradients([output], [var])
      # Turn off constant folding to avoid breaking testNestedControlFlowSerDes.
      # It appears that a missing control dependency in the gradient graph
      # causes the fetch node to not be triggered.
      no_constfold_config = config_pb2.ConfigProto()
      no_constfold_config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    # To avoid graph name collisions between original and loaded code.
    context._reset_context()   # pylint: disable=protected-access

    # Restore the MetaGraphDef into a new Graph.
    with ops_lib.Graph().as_default():
      with session.Session() as sess:
        saver = saver_module.import_meta_graph(filename)
        saver.restore(sess, saver_ckpt)

      # Make sure we can still build gradients and get the same result.
      var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
      output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
      grad = gradients_impl.gradients([output], [var])

      # A fresh initializer for the imported graph: the restore above ran in a
      # session that has already been closed.
      init_op = variables.global_variables_initializer()

      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        actual_grad_value = self.evaluate(grad)
        self.assertEqual(expected_grad_value, actual_grad_value)

  def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
    """Round-trips a while loop driven by `outer_body_fn` with gradients.

    The loop starts from `[0, x]` (x being the tested float variable), keeps
    iterating while the counter is below 5 and uses `outer_body_fn(i, x)` as
    its body. Its second loop output is the tensor handed to
    `_testGradientSerDes`.
    """

    def keep_going(i, unused_y):
      return i < 5

    def build_loop(x):
      loop_outputs = while_loop.while_loop(keep_going, outer_body_fn, [0, x])
      return loop_outputs[1]

    return self._testGradientSerDes(build_loop)

  def testNestedWhileLoopsSerDes(self):
    """Two simply nested while loops survive export/import with gradients."""

    def outer_body(outer_i, outer_x):
      def inner_cond(inner_j, unused_acc):
        return inner_j < 3

      def inner_body(inner_j, acc):
        return inner_j + 1, acc + outer_x

      _, inner_sum = while_loop.while_loop(inner_cond, inner_body, [0, 0.0])
      return outer_i + 1, outer_x + inner_sum

    self._testWhileLoopAndGradientSerDes(outer_body)

  def testNestedControlFlowSerDes(self):
    """A while loop inside a cond inside a while loop survives the round trip."""

    def outer_body(i, x):
      def run_inner_loop():
        def inner_cond(j, unused_acc):
          return j < 3

        def inner_body(j, acc):
          return j + 1, acc + x

        return while_loop.while_loop(inner_cond, inner_body, [0, 0.0])[1]

      def passthrough():
        return x

      # The cond is built before `i + 1`, as in the original graph layout.
      branch_out = cond.cond(i > 0, run_inner_loop, passthrough)
      return i + 1, branch_out

    self._testWhileLoopAndGradientSerDes(outer_body)

  def testNestedCondsSerDes(self):
    """Conds nested in both branches of a cond survive the round trip."""

    def nested_conds(x):
      def positive_branch():
        return cond.cond(x > 3,
                         lambda: array_ops.identity(x),
                         lambda: math_ops.multiply(x, 2.0))

      def non_positive_branch():
        return cond.cond(x < -3,
                         lambda: constant_op.constant(1.0),
                         lambda: math_ops.multiply(x, -1.0))

      return cond.cond(x > 0, positive_branch, non_positive_branch)

    self._testGradientSerDes(nested_conds)

  @test_util.run_v1_only("This exercises Tensor.op which is meaningless in V2.")
  def testStrippedOpListDef(self):
    """The exported stripped_op_list names exactly the ops the graph uses.

    Builds two variables, an add, a Defun that subtracts one and a Saver, then
    checks the op names recorded in the MetaGraphDef (the save op depends on
    the Saver's checkpoint format version). Also checks that
    `meta_graph.stripped_op_list_for_graph` agrees and carries no op docs.
    """
    with self.cached_session():
      # Creates a graph.
      v0 = variable_v1.VariableV1(0.0)
      var = variable_v1.VariableV1(10.0)
      math_ops.add(v0, var)

      @function.Defun(dtypes.float32)
      def minus_one(x):
        return x - 1

      minus_one(array_ops.identity(v0))
      save = saver_module.Saver({"v0": v0})
      # Built but never run; it still adds ops to the exported graph.
      variables.global_variables_initializer()

      # Generates MetaGraphDef.
      meta_graph_def = save.export_meta_graph()
      ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
      # V1 checkpoints are written with SaveSlices, V2 with SaveV2. The write
      # version is a protobuf enum value (a plain int), so it must be compared
      # with `==`; `is` only happened to work through small-int caching.
      if save._write_version == saver_pb2.SaverDef.V1:
        save_op = "SaveSlices"
      else:
        save_op = "SaveV2"
      self.assertEqual(ops, [
          "AddV2", "Assign", "Const", "Identity", "NoOp",
          "PlaceholderWithDefault", "RestoreV2", save_op, "Sub", "VariableV2"
      ])

      # Test calling stripped_op_list_for_graph directly
      op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)
      self.assertEqual(ops, [o.name for o in op_list.op])
      # The stripped op list keeps op signatures but drops documentation.
      for o in op_list.op:
        self.assertEqual(o.summary, "")
        self.assertEqual(o.description, "")

  def testStripDefaultValuedAttrs(self):
    """Verifies that default valued attrs are stripped, unless disabled."""

    def export_complex_node_def(session_ctx, strip_default_attrs):
      # train.Saver and train.export_meta_graph are V1 only APIs, so each
      # variant is built in its own fresh graph.
      with ops_lib.Graph().as_default(), session_ctx():
        real_num = variable_v1.VariableV1(
            1.0, dtype=dtypes.float32, name="real")
        imag_num = variable_v1.VariableV1(
            2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")

        save = saver_module.Saver(
            {"real_num": real_num, "imag_num": imag_num})
        variables.global_variables_initializer()

        meta_graph_def = save.export_meta_graph(
            strip_default_attrs=strip_default_attrs)
        return test_util.get_node_def_from_graph("complex",
                                                 meta_graph_def.graph_def)

    # Enabled: "T" (float32) and "Tout" (complex64) hold their defaults on the
    # "Complex" op and must be removed.
    stripped = export_complex_node_def(self.cached_session, True)
    for attr_name in ("T", "Tout"):
      self.assertNotIn(attr_name, stripped.attr)

    # Disabled: the same attrs must be kept even though they map to defaults.
    kept = export_complex_node_def(self.session, False)
    for attr_name in ("T", "Tout"):
      self.assertIn(attr_name, kept.attr)

  def testImportIntoNamescope(self):
    """A MetaGraphDef imported with `import_scope` can be restored and run."""
    ckpt_prefix = os.path.join(
        self._get_test_dir("import_into_namescope"), "ckpt")

    # train.Saver is V1 only API, so build and save the model in a graph.
    with ops_lib.Graph().as_default():
      inputs = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
      labels = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
      with session.Session() as save_sess:
        w = variable_v1.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        b = variable_v1.VariableV1(array_ops.zeros([10]), name="bias")
        logits = nn_ops.relu(math_ops.matmul(inputs, w) + b, name="logits")
        nn_ops.softmax(logits, name="prediction")
        loss = nn_ops.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits, name="cost")
        adam.AdamOptimizer().minimize(loss, name="optimize")
        model_saver = saver_module.Saver()
        self.evaluate(variables.global_variables_initializer())
        model_saver.save(save_sess, ckpt_prefix)

    target_graph = ops_lib.Graph()
    with session.Session(graph=target_graph) as restore_sess:
      restorer = saver_module.import_meta_graph(
          ckpt_prefix + ".meta", graph=target_graph, import_scope="new_model")
      restorer.restore(restore_sess, ckpt_prefix)
      feeds = {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(10, size=[1, 10]),
      }
      restore_sess.run(["new_model/optimize"], feeds)

  def testImportIntoNamescopeWithoutVariables(self):
    """import_meta_graph returns a Saver only if the scope has variables."""
    ckpt_prefix = os.path.join(self._get_test_dir("no_vars_graph"), "ckpt")
    meta_path = ckpt_prefix + ".meta"

    # Save a simple graph that contains no variables into a checkpoint.
    source_graph = ops_lib.Graph()
    with session.Session(graph=source_graph) as sess:
      constant_op.constant([1, 2, 3], name="x")
      constant_op.constant([1, 2, 3], name="y")
      saver_module.Saver(allow_empty=True).save(sess, ckpt_prefix)

    target_graph = ops_lib.Graph()
    with session.Session(graph=target_graph):
      # Importing under "subgraph_1": no variables, so no Saver comes back.
      self.assertIsNone(
          saver_module.import_meta_graph(
              meta_path, graph=target_graph, import_scope="subgraph_1"))

      # A variable living under "my_scope" is invisible to an import into
      # "subgraph_2", so there is still nothing to restore.
      variable_v1.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
      self.evaluate(variables.global_variables_initializer())
      self.assertIsNone(
          saver_module.import_meta_graph(
              meta_path, graph=target_graph, import_scope="subgraph_2"))

      # Importing into "my_scope" detects that variable and returns a Saver
      # for it, even though the variable does not originate from the saved
      # graph.
      self.assertIsInstance(
          saver_module.import_meta_graph(
              meta_path, graph=target_graph, import_scope="my_scope"),
          saver_module.Saver)

  def testImportIntoImplicitNamescope(self):
    """An enclosing name_scope acts as the scope for import_meta_graph."""
    ckpt_prefix = os.path.join(
        self._get_test_dir("import_into_namescope"), "ckpt")

    # train.Saver is V1 only API, so build and save the model in a graph.
    with ops_lib.Graph().as_default():
      inputs = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
      labels = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
      with session.Session() as save_sess:
        w = variable_v1.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        b = variable_v1.VariableV1(array_ops.zeros([10]), name="bias")
        logits = nn_ops.relu(math_ops.matmul(inputs, w) + b, name="logits")
        nn_ops.softmax(logits, name="prediction")
        loss = nn_ops.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits, name="cost")
        adam.AdamOptimizer().minimize(loss, name="optimize")
        model_saver = saver_module.Saver()
        self.evaluate(variables.global_variables_initializer())
        model_saver.save(save_sess, ckpt_prefix)

    target_graph = ops_lib.Graph()
    with session.Session(graph=target_graph) as restore_sess:
      # No explicit import_scope: the surrounding name_scope provides it.
      with ops_lib.name_scope("new_model"):
        restorer = saver_module.import_meta_graph(
            ckpt_prefix + ".meta", graph=target_graph)

      restorer.restore(restore_sess, ckpt_prefix)
      feeds = {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(10, size=[1, 10]),
      }
      restore_sess.run(["new_model/optimize"], feeds)

  def testClearDevicesOnImport(self):
    """Tests import_meta_graph(clear_devices=...) on a device-pinned graph.

    The model is built entirely under a "/job:ps/.../device:GPU:0" device
    scope and exported without clearing devices. Importing with
    clear_devices=False keeps that placement and variable initialization is
    expected to raise InvalidArgumentError. Importing with clear_devices=True
    lets the model initialize and run one "optimize" step.
    """
    # Test that we import a graph without its devices and run successfully.
    with ops_lib.Graph().as_default():
      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
        weights = variable_v1.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        bias = variable_v1.VariableV1(array_ops.zeros([10]), name="bias")
        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
        nn_ops.softmax(logit, name="prediction")
        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
                                                        logits=logit)
        adam.AdamOptimizer().minimize(cost, name="optimize")
      meta_graph_def = saver_module.export_meta_graph()

    with session.Session(graph=ops_lib.Graph()) as sess:
      saver_module.import_meta_graph(
          meta_graph_def, clear_devices=False, import_scope="new_model")
      # The imported ops still request the "/job:ps" GPU device, which this
      # local session presumably cannot place them on, so initialization is
      # expected to fail.
      with self.assertRaises(errors_impl.InvalidArgumentError):
        self.evaluate(variables.global_variables_initializer())

    with session.Session(graph=ops_lib.Graph()) as sess:
      saver_module.import_meta_graph(
          meta_graph_def, clear_devices=True, import_scope="new_model")
      self.evaluate(variables.global_variables_initializer())
      sess.run(["new_model/optimize"], {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(
              10, size=[1, 10])
      })

  def testClearDevicesOnExport(self):
    """export_meta_graph(clear_devices=True) yields a graph that runs locally."""
    device_spec = "/job:ps/replica:0/task:0/device:GPU:0"
    with ops_lib.Graph().as_default():
      with ops_lib.device(device_spec):
        inputs = array_ops.placeholder(
            dtypes.float32, [None, 784], name="image")
        labels = array_ops.placeholder(
            dtypes.float32, [None, 10], name="label")
        w = variable_v1.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        b = variable_v1.VariableV1(array_ops.zeros([10]), name="bias")
        logits = nn_ops.relu(math_ops.matmul(inputs, w) + b)
        nn_ops.softmax(logits, name="prediction")
        loss = nn_ops.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        adam.AdamOptimizer().minimize(loss, name="optimize")
      exported = saver_module.export_meta_graph(clear_devices=True)
      # Also dump a text copy of the exported proto into the temp dir.
      graph_io.write_graph(exported, self.get_temp_dir(), "meta_graph.pbtxt")

    with session.Session(graph=ops_lib.Graph()) as sess:
      saver_module.import_meta_graph(exported, import_scope="new_model")
      self.evaluate(variables.global_variables_initializer())
      feeds = {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(10, size=[1, 10]),
      }
      sess.run(["new_model/optimize"], feeds)

  def testPreserveDatasetAndFunctions(self):
    """Datasets and their map functions survive every export code path."""
    with ops_lib.Graph().as_default() as source_graph:
      squares = dataset_ops.Dataset.range(10).map(lambda x: x * x)
      next_square = dataset_ops.make_one_shot_iterator(squares).get_next()
      array_ops.identity(next_square, name="output")

      # Three MetaGraphDefs, each produced through a different code path.
      candidates = [
          saver_module.export_meta_graph(),
          saver_module.export_meta_graph(clear_devices=True),
          saver_module.export_meta_graph(
              clear_devices=True, graph_def=source_graph.as_graph_def()),
      ]

    for candidate in candidates:
      with session.Session(graph=ops_lib.Graph()) as sess:
        saver_module.import_meta_graph(candidate, import_scope="new_model")
        self.evaluate(variables.global_variables_initializer())
        for expected in (i * i for i in range(10)):
          self.assertEqual(expected, sess.run("new_model/output:0"))
        # After ten elements the one-shot iterator is exhausted.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run("new_model/output:0")


class CheckpointReaderTest(test.TestCase):
  """Checks py_checkpoint_reader against a checkpoint written by Saver."""

  # Checkpoint format used by the Saver; subclasses override it.
  _WRITE_VERSION = saver_pb2.SaverDef.V1

  def testDebugString(self):
    """The reader exposes tensors, shapes and a readable summary."""
    # Builds a graph with two float variables of different ranks.
    v0 = variable_v1.VariableV1([[1, 2, 3], [4, 5, 6]],
                                dtype=dtypes.float32,
                                name="v0")
    v1 = variable_v1.VariableV1([[[1], [2]], [[3], [4]], [[5], [6]]],
                                dtype=dtypes.float32,
                                name="v1")
    init_all_op = variables.global_variables_initializer()
    saved_vars = {"v0": v0, "v1": v1}
    save = saver_module.Saver(saved_vars, write_version=self._WRITE_VERSION)
    save_path = os.path.join(
        self.get_temp_dir(),
        "ckpt_for_debug_string" + str(self._WRITE_VERSION))

    # (checkpoint key, source variable, expected shape, debug_string line)
    expectations = (
        ("v0", v0, [2, 3], "v0 (DT_FLOAT) [2,3]"),
        ("v1", v1, [3, 2, 1], "v1 (DT_FLOAT) [3,2,1]"),
    )
    with self.cached_session() as sess:
      self.evaluate(init_all_op)
      save.save(sess, save_path)

      reader = py_checkpoint_reader.NewCheckpointReader(save_path)
      debug_string = reader.debug_string()
      var_map = reader.get_variable_to_shape_map()
      for key, source_var, shape, summary_line in expectations:
        self.assertTrue(reader.has_tensor(key))
        self.assertIn(compat.as_bytes(summary_line), debug_string)
        self.assertEqual(shape, var_map[key])
        self.assertAllEqual(source_var, reader.get_tensor(key))

      # Looking up a key that was never saved must fail.
      with self.assertRaisesRegex(errors.NotFoundError,
                                  "v3 not found in checkpoint"):
        reader.get_tensor("v3")

  def testNonexistentPath(self):
    """Opening a reader on a missing checkpoint raises NotFoundError."""
    missing_path = "non-existent"
    with self.assertRaisesRegex(errors.NotFoundError,
                                "Unsuccessful TensorSliceReader"):
      py_checkpoint_reader.NewCheckpointReader(missing_path)


class CheckpointReaderForV2Test(CheckpointReaderTest):
  """Reruns every CheckpointReaderTest case with V2-format checkpoints."""
  _WRITE_VERSION = saver_pb2.SaverDef.V2


class WriteGraphTest(test.TestCase):
  """Tests where graph_io.write_graph puts its output."""

  def _get_test_dir(self, dirname):
    """Returns `dirname` under the test temp dir, creating it first."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  def testWriteGraph(self):
    """A Graph is written into a not-yet-existing subdirectory."""
    base_dir = self._get_test_dir("write_graph_dir")
    variable_v1.VariableV1([[1, 2, 3], [4, 5, 6]],
                           dtype=dtypes.float32,
                           name="v0")
    out_dir = os.path.join(base_dir, "l1")
    written = graph_io.write_graph(ops_lib.get_default_graph(), out_dir,
                                   "graph.pbtxt")
    self.assertEqual(written, os.path.join(out_dir, "graph.pbtxt"))
    self.assertTrue(os.path.exists(written))

  def testRecursiveCreate(self):
    """A GraphDef is written into several nested, missing directories."""
    base_dir = self._get_test_dir("deep_dir")
    variable_v1.VariableV1([[1, 2, 3], [4, 5, 6]],
                           dtype=dtypes.float32,
                           name="v0")
    out_dir = os.path.join(base_dir, "l1", "l2", "l3")
    written = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
                                   out_dir, "graph.pbtxt")
    self.assertEqual(written, os.path.join(out_dir, "graph.pbtxt"))
    self.assertTrue(os.path.exists(written))


class ScopedGraphTest(test.TestCase):
  """Tests exporting, importing and copying name-scoped subgraphs."""

  def _get_test_dir(self, dirname):
    """Returns `dirname` under the test temp dir, creating it if needed."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
    """Builds a small MLP, exports its "hidden1" scope and checkpoints it.

    Args:
      test_dir: Directory receiving both output files.
      exported_filename: File name (inside `test_dir`) for the scoped
        MetaGraphDef of "hidden1".
      ckpt_filename: Checkpoint prefix (inside `test_dir`) for the variables
        exported from "hidden1".
    """
    graph = ops_lib.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      images = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      with ops_lib.name_scope("hidden1"):
        weights1 = variable_v1.VariableV1(
            random_ops.truncated_normal([28, 128],
                                        stddev=1.0 / math.sqrt(float(28))),
            name="weights")
        # The use of cond.cond here is purely for adding test
        # coverage the save and restore of control flow context (which doesn't
        # make any sense here from a machine learning perspective).  The typical
        # biases is a simple Variable without the conditions.
        biases1 = variable_v1.VariableV1(
            cond.cond(
                math_ops.less(random.random(), 0.5),
                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
            name="biases")
        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)

      # Hidden 2
      with ops_lib.name_scope("hidden2"):
        weights2 = variable_v1.VariableV1(
            random_ops.truncated_normal([128, 32],
                                        stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of while_loop.while_loop here is purely for adding test
        # coverage the save and restore of control flow context (which doesn't
        # make any sense here from a machine learning perspective).  The typical
        # biases is a simple Variable without the conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases2):
          biases2 += constant_op.constant(0.1, shape=[32])
          return it + 1, biases2

        _, biases2 = while_loop.while_loop(loop_cond, loop_body, [
            constant_op.constant(0),
            variable_v1.VariableV1(array_ops.zeros([32]))
        ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
      # Linear
      with ops_lib.name_scope("softmax_linear"):
        weights3 = variable_v1.VariableV1(
            random_ops.truncated_normal([32, 10],
                                        stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = variable_v1.VariableV1(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights3) + biases3
        ops_lib.add_to_collection("logits", logits)

        # Adds user_defined proto in three formats: string, bytes and Any.
        # Any proto should just pass through.
        queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
        ops_lib.add_to_collection("user_defined_string_collection",
                                  str(queue_runner))
        ops_lib.add_to_collection("user_defined_bytes_collection",
                                  queue_runner.SerializeToString())
        any_buf = Any()
        any_buf.Pack(queue_runner)
        ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      _, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=ops_lib.get_default_graph(),
          export_scope="hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    with graph.as_default(), self.session() as sess:
      self.evaluate(variables.global_variables_initializer())
      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
      saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)

  def _testScopedRestore(self, test_dir, exported_filename,
                         new_exported_filename, ckpt_filename):
    """Imports the "hidden1" subgraph as "new_hidden1" and builds on top of it.

    Args:
      test_dir: Directory holding the files written by `_testScopedSave`.
      exported_filename: Scoped MetaGraphDef file name inside `test_dir`.
      new_exported_filename: Currently unused; kept so existing callers keep
        working.
      ckpt_filename: Checkpoint prefix inside `test_dir`.
    """
    graph = ops_lib.Graph()
    # Create all the missing inputs.
    with graph.as_default():
      new_image = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      var_list = meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filename),
          graph=graph,
          input_map={"$unbound_inputs_images": new_image},
          import_scope="new_hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
      hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
      weights1 = graph.as_graph_element("new_hidden1/weights:0")
      biases1 = graph.as_graph_element("new_hidden1/biases:0")

    with graph.as_default():
      # Hidden 2
      with ops_lib.name_scope("hidden2"):
        weights = variable_v1.VariableV1(
            random_ops.truncated_normal([128, 32],
                                        stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of while_loop.while_loop here is purely for adding test
        # coverage the save and restore of control flow context (which doesn't
        # make any sense here from a machine learning perspective).  The typical
        # biases is a simple Variable without the conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases):
          biases += constant_op.constant(0.1, shape=[32])
          return it + 1, biases

        _, biases = while_loop.while_loop(loop_cond, loop_body, [
            constant_op.constant(0),
            variable_v1.VariableV1(array_ops.zeros([32]))
        ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
      # Linear
      with ops_lib.name_scope("softmax_linear"):
        weights = variable_v1.VariableV1(
            random_ops.truncated_normal([32, 10],
                                        stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases = variable_v1.VariableV1(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights) + biases
        ops_lib.add_to_collection("logits", logits)

      # The rest of the variables. `var_list` maps names ("weights:0", ...) to
      # the imported Variables, so the Variables (its values) must be removed.
      # Subtracting the name strings removed nothing, and the restored
      # weights1/biases1 were then overwritten by `init_rest_op`.
      rest_variables = list(
          set(variables.global_variables()) - set(var_list.values()))
      init_rest_op = variables.variables_initializer(rest_variables)

    with graph.as_default(), self.session() as sess:
      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
      saver.restore(sess, os.path.join(test_dir, ckpt_filename))
      # Verify that we have restored weights1 and biases1.
      self.evaluate([weights1, biases1])
      # Initialize the rest of the variables and run logits.
      self.evaluate(init_rest_op)
      self.evaluate(logits)

  def testScopedSaveAndRestore(self):
    """Saves the subgraph under "hidden1" and restores it as "new_hidden1"."""
    test_dir = self._get_test_dir("scoped_export_import")
    ckpt_filename = "ckpt"
    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
                            "exported_new_hidden1.pbtxt", ckpt_filename)

  def testCopyScopedGraph(self):
    """Copies "hidden1" to another scope in the same and in another graph."""
    test_dir = self._get_test_dir("scoped_copy")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variable_v1.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                                          name="weights")
        biases1 = variable_v1.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

    # Run the graph and save scoped checkpoint.
    with graph1.as_default(), self.session(graph=graph1) as sess:
      self.evaluate(variables.global_variables_initializer())
      _, var_list_1 = meta_graph.export_scoped_meta_graph(
          export_scope="hidden1")
      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver.save(sess, saver0_ckpt, write_state=False)

    # relu(ones([3, 2]) @ weights + 0.1) for every one of the three rows.
    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies copy to the same graph with the same name fails.
    with graph1.as_default():
      with self.assertRaisesWithPredicateMatch(
          ValueError, lambda e: "need to be different" in str(e)):
        meta_graph.copy_scoped_meta_graph(
            from_scope="hidden1", to_scope="hidden1")

    # Verifies copy to the same graph.
    with graph1.as_default():
      var_list_2 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1", to_scope="hidden2")

    with graph1.as_default(), self.session(graph=graph1) as sess:
      saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver1.restore(sess, saver0_ckpt)
      saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
      saver2.restore(sess, saver0_ckpt)
      self.assertAllClose(expected, sess.run("hidden1/relu:0"))
      self.assertAllClose(expected, sess.run("hidden2/relu:0"))

    # Verifies copy to different graph.
    graph2 = ops_lib.Graph()
    with graph2.as_default():
      new_var_list_1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph1,
          to_graph=graph2)

      with self.session() as sess:
        saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
        saver3.restore(sess, saver0_ckpt)
        self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

  def testExportGraphDefWithScope(self):
    """Scoped export from an explicit GraphDef can be copied and restored."""
    test_dir = self._get_test_dir("export_graph_def")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variable_v1.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                                          name="weights")
        biases1 = variable_v1.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

      # Run the graph and save scoped checkpoint.
      with self.session(graph=graph1) as sess:
        self.evaluate(variables.global_variables_initializer())
        _, var_list_1 = meta_graph.export_scoped_meta_graph(
            graph_def=graph1.as_graph_def(), export_scope="hidden1")
        saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
        saver.save(sess, saver0_ckpt, write_state=False)

    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies that we can run successfully after restoring.
    graph2 = ops_lib.Graph()
    with graph2.as_default():
      new_var_list_1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph1,
          to_graph=graph2)

      with self.session(graph=graph2) as sess:
        saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
        saver3.restore(sess, saver0_ckpt)
        self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

  def testSerializeSaverWithScope(self):
    """Savers in the SAVERS collection are copied along with their scope."""
    test_dir = self._get_test_dir("export_graph_def")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")
    graph = ops_lib.Graph()
    with graph.as_default():
      with ops_lib.name_scope("hidden1"):
        variable1 = variable_v1.VariableV1([1.0], name="variable1")
        saver1 = saver_module.Saver(var_list=[variable1])
        graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)

      with ops_lib.name_scope("hidden2"):
        variable2 = variable_v1.VariableV1([2.0], name="variable2")
      # Created outside the scope but given an explicit "hidden2/" name.
      saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
      graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)

      with self.session(graph=graph) as sess:
        self.evaluate(variables.global_variables_initializer())
        saver1.save(sess, saver1_ckpt, write_state=False)
        saver2.save(sess, saver2_ckpt, write_state=False)

    graph1 = ops_lib.Graph()
    with graph1.as_default():
      var_dict1 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1",
          to_scope="new_hidden1",
          from_graph=graph,
          to_graph=graph1)
      self.assertEqual(1, len(var_dict1))

      saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
      self.assertEqual(1, len(saver_list1))

      with self.session(graph=graph1) as sess:
        saver_list1[0].restore(sess, saver1_ckpt)
        self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"]))

    graph2 = ops_lib.Graph()
    with graph2.as_default():
      var_dict2 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden2",
          to_scope="new_hidden2",
          from_graph=graph,
          to_graph=graph2)
      self.assertEqual(1, len(var_dict2))

      saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
      self.assertEqual(1, len(saver_list2))

      with self.session(graph=graph2) as sess:
        saver_list2[0].restore(sess, saver2_ckpt)
        self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))


class _OwnsAVariableSimple(trackable_base.Trackable):
  """Trackable wrapping a single resource variable, savable by tf.train.Saver.

  The variable is not registered as a trackable dependency; it is exposed to
  checkpointing only through `_gather_saveables_for_checkpoint`.
  """

  def __init__(self):
    variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)
    self.non_dep_variable = variable

  def _gather_saveables_for_checkpoint(self):
    """Maps the single-value key to the wrapped variable."""
    saveables = {}
    saveables[trackable_base.VARIABLE_VALUE_KEY] = self.non_dep_variable
    return saveables

  # A `name` property is required because the Saver sorts its inputs by name
  # before parsing them.
  @property
  def name(self):
    """Name of the wrapped variable."""
    variable = self.non_dep_variable
    return variable.name


class _MirroringSaveable(
    saver_module.BaseSaverBuilder.ResourceVariableSaveable):
  """Saves one resource variable and restores its value into two variables.

  Only `primary_variable` is handed to the base saveable (with an empty slice
  spec); on restore, the loaded tensor is assigned to both `primary_variable`
  and `mirrored_variable`.
  """

  def __init__(self, primary_variable, mirrored_variable, name):
    self._primary_variable = primary_variable
    self._mirrored_variable = mirrored_variable
    super().__init__(self._primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the single restored tensor to both variables.

    Args:
      restored_tensors: Sequence that must hold exactly one tensor.
      restored_shapes: Unused.

    Returns:
      An op grouping the assignment to the primary and mirrored variables.
    """
    [value] = restored_tensors
    assign_primary = self._primary_variable.assign(value)
    assign_mirror = self._mirrored_variable.assign(value)
    return control_flow_ops.group(assign_primary, assign_mirror)


class _OwnsMirroredVariables(trackable_base.Trackable):
  """Trackable whose saveable is a factory returning a `_MirroringSaveable`.

  `non_dep_variable` is the variable passed to the saveable as the primary;
  restoring assigns the loaded value to both `non_dep_variable` and
  `mirrored`.
  """

  def __init__(self):
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)
    self.mirrored = variable_scope.get_variable(
        name="mirrored", initializer=15., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    """Returns a factory that builds the mirroring saveable on demand."""
    # The default name is captured now, when the saveables are gathered.
    default_name = self.non_dep_variable.name

    def _make_saveable(name=default_name):
      return _MirroringSaveable(
          primary_variable=self.non_dep_variable,
          mirrored_variable=self.mirrored,
          name=name)

    return {trackable_base.VARIABLE_VALUE_KEY: _make_saveable}

  # A `name` property is required because the Saver sorts its inputs by name
  # before parsing them.
  @property
  def name(self):
    """Name of the primary (checkpointed) variable."""
    primary = self.non_dep_variable
    return primary.name


class TrackableCompatibilityTests(test.TestCase):
  """Checks that tf.train.Saver interoperates with Trackable saveables.

  Covers Trackable objects passed in `var_list` (as a list or a dict),
  saveable factories returning custom SaveableObjects, how often a SaveSpec
  tensor callable is evaluated, and the errors raised by failing restores.
  """

  # TODO(allenl): Track down python3 reference cycles in these tests.
  @test_util.run_in_graph_and_eager_modes
  def testNotSaveableButIsTrackable(self):
    """A Trackable's gathered variable round-trips through save/restore."""
    v = _OwnsAVariableSimple()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        self.evaluate(v.non_dep_variable.assign(42.))
        save_path = saver.save(sess, prefix)
        # Overwrite the value so that the restore below is observable.
        self.evaluate(v.non_dep_variable.assign(43.))
        saver.restore(sess, save_path)
        self.assertEqual(42., self.evaluate(v.non_dep_variable))

  @test_util.run_in_graph_and_eager_modes
  def testMoreComplexSaveableReturned(self):
    """A saveable factory's custom restore writes into both variables."""
    v = _OwnsMirroredVariables()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    self.evaluate(v.non_dep_variable.assign(42.))
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.evaluate(v.non_dep_variable.assign(43.))
        self.evaluate(v.mirrored.assign(44.))
        saver.restore(sess, save_path)
        # The mirrored variable receives the primary variable's saved value.
        self.assertEqual(42., self.evaluate(v.non_dep_variable))
        self.assertEqual(42., self.evaluate(v.mirrored))

  def testSingleTensorEvaluation(self):
    """In eager mode the SaveSpec tensor callable runs once, only on save."""

    class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
      """SaveableObject that counts calls to its SaveSpec tensor callable."""

      def __init__(self, name):
        self.eval_count = 0
        def _tensor():
          self.eval_count += 1
          return constant_op.constant([1.])
        dummy_op = constant_op.constant([2.])
        super(_CountingSaveable, self).__init__(
            dummy_op,
            [saver_module.BaseSaverBuilder.SaveSpec(
                _tensor, "", name, dtype=dummy_op.dtype,
                device=dummy_op.device)],
            name)

      def restore(self, restored_tensors, restored_shapes):
        """No-op: this test only counts evaluations of the saved tensor."""
        pass

    with context.eager_mode():
      v = _CountingSaveable("foo")
      saver = saver_module.Saver(var_list=[v])
      test_dir = self.get_temp_dir()
      prefix = os.path.join(test_dir, "ckpt")
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.assertEqual(1, v.eval_count)
        saver.restore(sess, save_path)
        # Restoring must not re-evaluate the tensor callable.
        self.assertEqual(1, v.eval_count)

  def testVariableNotFoundErrorRaised(self):
    """Restoring a key absent from the checkpoint raises a clean NotFound."""
    # Restore does some tricky exception handling to figure out if it should
    # load an object-based checkpoint. Tests that the exception handling isn't
    # too broad.
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    a = resource_variable_ops.ResourceVariable(1., name="a")
    b = resource_variable_ops.ResourceVariable(1., name="b")
    a_saver = saver_module.Saver([a])
    b_saver = saver_module.Saver([b])
    with self.cached_session() as sess:
      self.evaluate(a.initializer)
      save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
      # A single failing restore is enough to check both the message and the
      # absence of chained-exception noise.
      with self.assertRaisesRegex(errors.NotFoundError,
                                  "Key b not found in checkpoint") as cs:
        b_saver.restore(sess=sess, save_path=save_path)

      # Make sure we don't have a confusing "During handling of the above
      # exception" block in Python 3.
      self.assertNotIn("NewCheckpointReader", cs.exception.message)

  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testGraphChangedForRestoreErrorRaised(self):
    """Restoring into a variable of a different shape reports a graph mismatch."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    with ops_lib.Graph().as_default() as g:
      a = variable_v1.VariableV1(1., name="a")
      a_saver = saver_module.Saver([a])

      with self.session(graph=g) as sess:
        self.evaluate(a.initializer)
        save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)

    # Same variable name, but a rank-1 value instead of the saved scalar.
    with ops_lib.Graph().as_default() as g:
      a = variable_v1.VariableV1([1.], name="a")
      a_saver = saver_module.Saver([a])
      with self.session(graph=g) as sess:
        with self.assertRaisesRegex(
            errors.InvalidArgumentError,
            "a mismatch between the current graph and the graph"):
          a_saver.restore(sess=sess, save_path=save_path)


# Run this module's tests via `test.main()` when executed as a script.
if __name__ == "__main__":
  test.main()