tensorflow/tensorflow

View on GitHub
tensorflow/python/saved_model/load_test.py

Summary

Maintainability
F
1 wk
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trackable object SavedModel loading."""

import collections
import contextlib
import functools
import gc
import io
import os
import pathlib
import sys
import tempfile
import unittest
import weakref

from absl.testing import parameterized
import numpy as np

# Import for py bindings to runtime
from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import resource
from tensorflow.python.training import monitored_session
from tensorflow.python.types import core as types_core
from tensorflow.python.util import tf_inspect


def cycle(
    obj,
    cycles,
    signatures=None,
    save_option=None,
    load_option=None,
    use_cpp_bindings=False,
):
  """Round-trips `obj` through save/load `cycles` times.

  Each iteration saves the object produced by the previous iteration
  (initially `obj`) into a fresh temporary directory and loads it back. The
  loaded object's `signatures` are passed to the next save so they are carried
  through every cycle.

  Args:
    obj: The trackable object saved by the first iteration.
    cycles: Number of save/load round trips; must be at least 1.
    signatures: Signatures passed to the first `save.save` call.
    save_option: Options forwarded to every `save.save` call.
    load_option: Options forwarded to every `test_load` call.
    use_cpp_bindings: Forwarded to `test_load`.

  Returns:
    The object returned by the final load.

  Raises:
    ValueError: If `cycles` is less than 1, since no object would ever be
      loaded and there would be nothing to return.
  """
  if cycles < 1:
    raise ValueError(f"cycles must be at least 1, got {cycles}.")
  to_save = obj
  # TODO(vbardiovsky): It would be nice if exported protos reached a fixed
  # point w.r.t. saving/restoring, ideally after 2nd saving.
  for _ in range(cycles):
    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures, options=save_option)
      loaded = test_load(
          path, options=load_option, use_cpp_bindings=use_cpp_bindings
      )
      signatures = loaded.signatures
    to_save = loaded
  return loaded


def _test_load_base(path, tags=None, options=None, use_cpp_bindings=False):
  """Loads the SavedModel at `path` through the Python `load.load` API.

  Args:
    path: Directory containing the SavedModel.
    tags: Optional tags forwarded to `load.load`.
    options: Optional load options forwarded to `load.load`.
    use_cpp_bindings: Accepted so the signature matches
      `_test_load_internal`; ignored by this loader.

  Returns:
    The value returned by `load.load`.
  """
  del use_cpp_bindings  # Unused: this path always uses the Python loader.
  return load.load(path, tags=tags, options=options)


def _test_load_internal(path, tags=None, options=None, use_cpp_bindings=False):
  """Loads `path` via the C++ runtime bindings or the Python loader.

  `test_load` only calls this when `run_external` is False.

  NOTE(review): `runtime_pybind` is not imported anywhere in the visible
  source. Presumably the import is injected by copy.bara.sky for internal
  builds; without it, `use_cpp_bindings=True` raises NameError. Confirm.

  Args:
    path: Directory containing the SavedModel.
    tags: Forwarded to the Python loader only.
    options: Forwarded to the Python loader only.
    use_cpp_bindings: If True, import via `runtime_pybind.Runtime().Import`.

  Returns:
    The imported object from whichever loader was used.
  """
  if use_cpp_bindings:
    # `tags` and `options` are not passed to the C++ runtime import.
    runtime = runtime_pybind.Runtime()
    return runtime.Import(path)
  return _test_load_base(path, tags=tags, options=options,
                         use_cpp_bindings=use_cpp_bindings)

# Replaced by copy.bara.sky. While True, `test_load` always uses the Python
# loader and no C++-binding test parameterizations are generated; presumably
# the internal build flips this to False — confirm against the copybara config.
run_external = True


def test_load(path, **kwargs):
  """Loads a SavedModel, routing to the internal loader when enabled.

  Dispatches to `_test_load_base` when `run_external` is True and to
  `_test_load_internal` otherwise; keyword arguments are forwarded unchanged.
  """
  loader = _test_load_base if run_external else _test_load_internal
  return loader(path, **kwargs)


def _load_test_params():
  """Returns named-parameter dicts for the reload-cycle test class.

  Python-loader variants reload one, two and three times. When `run_external`
  is False, a single-reload C++-loader variant is appended.
  """
  python_variants = (
      ("ReloadOncePy", 1),
      ("ReloadTwicePy", 2),
      ("ReloadThricePy", 3),
  )
  params = [
      dict(testcase_name=name, cycles=count, use_cpp_bindings=False)
      for name, count in python_variants
  ]
  if run_external:
    return params
  return params + [
      dict(testcase_name="ReloadOnceCpp", cycles=1, use_cpp_bindings=True)
  ]


def _test_params():
  """Returns named-parameter dicts selecting which loader a test uses.

  The Python loader is always included; the C++ loader is added only when
  `run_external` is False.
  """
  loaders = [("LoadWithPython", False)]
  if not run_external:
    loaders.append(("LoadWithCpp", True))
  return [
      dict(testcase_name=name, use_cpp_bindings=use_cpp)
      for name, use_cpp in loaders
  ]


@parameterized.named_parameters(*_load_test_params())
class LoadTest(test.TestCase, parameterized.TestCase):

  def test_structure_import(self, cycles, use_cpp_bindings):
    """An object reachable by two paths is restored as a single object."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    tree = autotrackable.AutoTrackable()
    tree.dep_one = autotrackable.AutoTrackable()
    tree.dep_two = autotrackable.AutoTrackable()
    shared_leaf = autotrackable.AutoTrackable()
    tree.dep_two.dep = shared_leaf
    tree.dep_three = shared_leaf
    restored = cycle(tree, cycles, use_cpp_bindings=use_cpp_bindings)
    # Shared identity is preserved; distinct children stay distinct.
    self.assertIs(restored.dep_three, restored.dep_two.dep)
    self.assertIsNot(restored.dep_one, restored.dep_two)

  @test_util.run_in_graph_and_eager_modes
  def test_variables(self, cycles, use_cpp_bindings):
    """Variable values and trainability survive repeated save/load cycles."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(1.0, trainable=True)
    root.v2 = variables.Variable(2.0, trainable=False)
    self.evaluate([root.v1.initializer, root.v2.initializer])

    # Do one round trip at a time so the restored variables are initialized
    # before they are saved again. Each round trip starts from the previous
    # result; re-saving `root` every time would make every multi-cycle
    # parameterization equivalent to a single cycle.
    imported = root
    for _ in range(cycles):
      imported = cycle(imported, 1, use_cpp_bindings=use_cpp_bindings)
      self.evaluate([imported.v1.initializer, imported.v2.initializer])

    if not context.executing_eagerly():
      self.assertIsInstance(imported.v1.initializer, ops.Operation)
      self.assertIsInstance(imported.v2.initializer, ops.Operation)

    self.assertEqual(self.evaluate(imported.v1), 1.0)
    self.assertTrue(imported.v1.trainable)
    self.assertEqual(self.evaluate(imported.v2), 2.0)
    self.assertFalse(imported.v2.trainable)

  def test_variables_name(self, cycles, use_cpp_bindings):
    """Restored variables keep their names and honor an enclosing scope."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    holder = autotrackable.AutoTrackable()
    # Both variables deliberately share the name "v1": the checkpoint is keyed
    # by object (attribute) name rather than by variable name.
    holder.v1 = variables.Variable(1.0, trainable=True, name="v1")
    holder.v2 = variables.Variable(2.0, trainable=False, name="v1")

    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(restored.v1.numpy(), 1.0)
    self.assertEqual(restored.v2.numpy(), 2.0)
    self.assertEqual(restored.v1.name, holder.v1.name)
    self.assertEqual(restored.v2.name, holder.v2.name)

    # Loading inside a variable scope prefixes the restored names.
    with variable_scope.variable_scope("foo"):
      scoped = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
      self.assertTrue(scoped.v1.name.startswith("foo/"))
      self.assertTrue(scoped.v2.name.startswith("foo/"))

  @test_util.disable_xla("This test never passed for XLA")
  def test_partially_defined_variable_shape(self, cycles, use_cpp_bindings):
    """A variable created from a [None]-shaped input keeps an unknown dim."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class MakeVariable(module.Module):

      def __init__(self):
        self.v = None

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec([None], dtypes.int64)]
      )
      def make_variable(self, initial_value):
        # The variable is created lazily, on the first call only.
        if self.v is None:
          self.v = variables.Variable(initial_value)

    maker = MakeVariable()
    maker.make_variable([1, 2, 3])
    restored = cycle(maker, cycles, use_cpp_bindings=use_cpp_bindings)
    # Assign a length-4 value to a variable created from a length-3 input.
    restored.v.assign([1, 2, 3, 4])
    restored_shape = tensor_shape.as_shape(restored.v.shape).as_list()
    self.assertEqual([None], restored_shape)

  @test_util.run_in_graph_and_eager_modes
  def test_capture_variables(self, cycles, use_cpp_bindings):
    """A function capturing a variable reads the restored variable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.weights = variables.Variable(2.0)
    self.evaluate(root.weights.initializer)
    root.f = def_function.function(
        lambda x: root.weights * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )
    # Do one round trip at a time so the restored variable is initialized
    # before it is saved again. Chain from the previous result so each extra
    # cycle re-saves a loaded object instead of re-saving `root`.
    imported = root
    for _ in range(cycles):
      imported = cycle(imported, 1, use_cpp_bindings=use_cpp_bindings)
      self.evaluate(imported.weights.initializer)
    self.assertEqual(4.0, self.evaluate(imported.f(constant_op.constant(2.0))))
    # Updating the restored variable must change the restored function's
    # result, i.e. the function captures the variable, not its value.
    self.evaluate(imported.weights.assign(4.0))
    self.assertEqual(8.0, self.evaluate(imported.f(constant_op.constant(2.0))))

  @test_util.run_in_graph_and_eager_modes
  def test_capture_constant(self, cycles, use_cpp_bindings):
    """A function closing over a constant still works after reload."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    holder = autotrackable.AutoTrackable()
    two = constant_op.constant(2.0)
    holder.f = def_function.function(
        lambda x: two * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )
    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    result = self.evaluate(restored.f(constant_op.constant(2.0)))
    self.assertEqual(4.0, result)

  def test_control_outputs(self, cycles, use_cpp_bindings):
    """An op that is only a control output is kept through save/load."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    op_name = "should_be_control_output"
    source = autotrackable.AutoTrackable()
    source.v = variables.Variable(1.0)
    source.f = def_function.function(
        lambda: source.v.assign(2.0, name=op_name)
    )
    source_graph = source.f.get_concrete_function().graph
    self.assertIn(
        source_graph.get_operation_by_name(op_name),
        source_graph.control_outputs,
    )

    restored = cycle(source, cycles, use_cpp_bindings=use_cpp_bindings)
    # get_concrete_function() would wrap the body in another call op. Read
    # the stored concrete function directly so its own graph is inspected.
    (restored_concrete,) = restored.f.concrete_functions
    restored_graph = restored_concrete.graph
    self.assertIn(
        restored_graph.get_operation_by_name(op_name),
        restored_graph.control_outputs,
    )

  def _make_asset(self, contents):
    """Writes `contents` to a new temp file and returns the file's path."""
    descriptor, asset_path = tempfile.mkstemp(prefix=self.get_temp_dir())
    # Wrap the raw descriptor from mkstemp so it is closed when writing ends.
    with os.fdopen(descriptor, "w") as stream:
      stream.write(contents)
    return asset_path

  @test_util.run_in_graph_and_eager_modes
  def test_assets(self, cycles, use_cpp_bindings):
    """Assets are exported with the model and survive a directory move."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    first_file = self._make_asset("contents 1")
    second_file = self._make_asset("contents 2")

    holder = autotrackable.AutoTrackable()
    holder.asset1 = asset.Asset(first_file)
    holder.asset2 = asset.Asset(second_file)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(holder, save_dir)

    # Delete the originals and move the export, so the loaded assets can only
    # resolve to the copies stored inside the SavedModel.
    for source_file in (first_file, second_file):
      file_io.delete_file(source_file)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    restored = test_load(load_dir, use_cpp_bindings=use_cpp_bindings)
    expectations = (("asset1", "contents 1"), ("asset2", "contents 2"))
    for attr_name, expected in expectations:
      asset_path = self.evaluate(getattr(restored, attr_name).asset_path)
      with open(asset_path, "r") as f:
        self.assertEqual(expected, f.read())

  def test_cond_prune(self, cycles, use_cpp_bindings):
    """A pruned wrap_function containing cond_v2 survives save/load cycles."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    # While wrap_function traces `f`, it appends its `x` input and the cond
    # result here, so those tensors can be used as prune feeds/fetches.
    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = cond_v2.cond_v2(
          math_ops.less(1, 2),
          lambda: x + 1,
          lambda: x + 2,
      )
      x_out.append(xx)
      return xx, 2 * y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2
    )
    # Keep only the `x` feed and the cond output as the fetch.
    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])

    class Adder(module.Module):

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
          ]
      )
      def add(self, x):
        return f_pruned(x)

    root = Adder()
    root.add(constant_op.constant(1.0))
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    # The test passes as long as calling the restored function does not raise.
    root.add(constant_op.constant(1.0))

  def test_capture_assets(self, cycles, use_cpp_bindings):
    """A function returning an asset path points at the exported copy."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    holder = autotrackable.AutoTrackable()
    holder.vocab = asset.Asset(self._make_asset("contents"))
    holder.f = def_function.function(
        lambda: holder.vocab.asset_path, input_signature=[]
    )
    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    path_before = holder.f().numpy()
    path_after = restored.f().numpy()
    # The restored path differs from the original temp file but still
    # resolves to a file with the same contents.
    self.assertNotEqual(path_before, path_after)
    with open(path_after, "r") as f:
      self.assertEqual("contents", f.read())

  def test_capture_assets_in_graph(self, cycles, use_cpp_bindings):
    """Captured asset paths resolve when the model is loaded into a Graph."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.vocab = asset.Asset(self._make_asset("contents"))
    root.f = def_function.function(
        lambda: root.vocab.asset_path, input_signature=[]
    )

    original_output = root.f().numpy()

    # Do all but the last cycle outside the graph; the final save is loaded
    # below inside `ops.Graph().as_default()`.
    if cycles > 1:
      root = cycle(root, cycles - 1, use_cpp_bindings=use_cpp_bindings)
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, path)

    with ops.Graph().as_default():
      imported = test_load(path, use_cpp_bindings=use_cpp_bindings)
      imported_tensor = imported.f()
      with monitored_session.MonitoredSession() as sess:
        imported_output = sess.run(imported_tensor)
        # Exactly one asset path should be registered in the graph's
        # ASSET_FILEPATHS collection.
        self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1)
        self.assertNotEqual(original_output, imported_output)
        with open(imported_output, "r") as f:
          self.assertEqual("contents", f.read())

  def test_dedup_assets(self, cycles, use_cpp_bindings):
    """Two Asset objects for the same file restore to the same path."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    vocab_path = self._make_asset("contents")
    holder = autotrackable.AutoTrackable()
    holder.asset1 = asset.Asset(vocab_path)
    holder.asset2 = asset.Asset(vocab_path)
    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    first_path = restored.asset1.asset_path.numpy()
    second_path = restored.asset2.asset_path.numpy()
    self.assertEqual(first_path, second_path)

  def test_asset_fspath(self, cycles, use_cpp_bindings):
    """An Asset constructed from a pathlib.Path can be saved and restored."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    vocab_path = pathlib.Path(self._make_asset("contents"))
    holder = autotrackable.AutoTrackable()
    holder.asset = asset.Asset(vocab_path)
    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertTrue(hasattr(restored, "asset"))

  def test_implicit_input_signature(self, cycles, use_cpp_bindings):
    """Traces made before saving (float and int inputs) remain callable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x):
      return 2 * x

    root = autotrackable.AutoTrackable()
    root.f = func

    # Add two traces (a float constant and an int constant); without an
    # input_signature only these traces are available after loading.
    root.f(constant_op.constant(1.0))
    root.f(constant_op.constant(1))

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertEqual(4.0, imported.f(constant_op.constant(2.0)).numpy())
    self.assertEqual(14, imported.f(constant_op.constant(7)).numpy())

  def test_explicit_input_signature(self, cycles, use_cpp_bindings):
    """A function with an explicit input_signature works after reload."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)

    @def_function.function(input_signature=[float_spec])
    def func(x):
      return 2 * x

    holder = autotrackable.AutoTrackable()
    holder.f = func

    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    doubled = restored.f(constant_op.constant(2.0)).numpy()
    self.assertEqual(4.0, doubled)

  def test_explicit_save_signature(self, cycles, use_cpp_bindings):
    """A concrete function exported as signature "f" is callable on reload."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x):
      return 2 * x

    holder = autotrackable.AutoTrackable()
    holder.f = func

    # Trace a float32 version explicitly and export it under the key "f".
    float_concrete = holder.f.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32)
    )
    restored = cycle(
        holder,
        cycles,
        signatures={"f": float_concrete},
        use_cpp_bindings=use_cpp_bindings,
    )
    doubled = restored.f(constant_op.constant(2.0)).numpy()
    self.assertEqual(4.0, doubled)

  def test_nested_functions(self, cycles, use_cpp_bindings):
    """A tf.function that calls another tf.function is restorable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    double = def_function.function(
        lambda x: x * 2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )
    # The outer function calls `double`, so its graph contains a nested call.
    double_plus_one = def_function.function(
        lambda x: double(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )

    holder = autotrackable.AutoTrackable()
    holder.g = double_plus_one
    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)
    # Only checks that calling the restored function does not raise.
    restored.g(constant_op.constant([1.0]))

  def test_function_with_default_bool_input(self, cycles, use_cpp_bindings):
    """Python-bool arguments, given or defaulted, select restored traces."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, training=False):
      if training:
        return 2 * x
      else:
        return 7

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func)

    # These calls trace both the training=True and the default
    # (training=False) variants before saving.
    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
    self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
    self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())

  def test_function_with_defaults_input_tensor(self, cycles, use_cpp_bindings):
    """A Tensor default argument is preserved on the restored function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
    def func(x=constant_op.constant(5.0)):
      return x

    root = autotrackable.AutoTrackable()
    root.f = func

    self.assertAllEqual(5.0, root.f())
    self.assertAllEqual(7.0, root.f(7.0))

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # Calling with no argument must fall back to the saved default of 5.0.
    self.assertEqual(5.0, imported.f().numpy())
    self.assertEqual(7.0, imported.f(constant_op.constant(7.0)).numpy())

    # imported.signatures with defaults are not supported.
    # TODO(b/277814477) support defaults in loaded.signatures
    # self.assertEqual(
    #     {"output_0": 5.0},
    #     self.evaluate(
    #         imported.signatures["serving_default"]()
    #     ),
    # )

  def test_function_with_defaults_input_numpy(self, cycles, use_cpp_bindings):
    """A NumPy-array default argument is preserved on the restored function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
    def func(x=np.array(5.0)):
      return x

    root = autotrackable.AutoTrackable()
    root.f = func

    self.assertAllEqual(5.0, root.f())
    self.assertAllEqual(7.0, root.f(np.array(7.0)))

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # Calling with no argument must fall back to the saved default of 5.0.
    self.assertEqual(5.0, imported.f().numpy())
    self.assertEqual(7.0, imported.f(np.array(7.0)).numpy())

    # imported.signatures with defaults are not supported.
    # TODO(b/277814477) support defaults in loaded.signatures
    # self.assertEqual(
    #     {"output_0": 5.0},
    #     self.evaluate(
    #         imported.signatures["serving_default"]()
    #     ),
    # )

  def test_function_with_default_none_input(self, cycles, use_cpp_bindings):
    """A `dtype=None` default and explicit dtypes both survive save/load."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, dtype=None):
      if dtype:
        return array_ops.zeros(shape=x.shape, dtype=dtype)
      else:
        return array_ops.zeros(shape=x.shape, dtype=dtypes.float32)

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func)

    # Four distinct calls: int length-3, float length-3, int length-4, and
    # float length-3 with dtype=int32. The count asserted below (4) depends on
    # exactly this call sequence.
    self.assertAllEqual(
        [0.0, 0.0, 0.0], root.f(constant_op.constant([1, 2, 3])).numpy()
    )
    self.assertAllEqual(
        [0.0, 0.0, 0.0], root.f(constant_op.constant([1.0, 2.0, 3.0])).numpy()
    )
    self.assertAllEqual(
        [0.0, 0.0, 0.0, 0.0], root.f(constant_op.constant([1, 2, 3, 4])).numpy()
    )
    self.assertAllEqual(
        [0, 0, 0],
        root.f(
            constant_op.constant([1.0, 2.0, 3.0]), dtype=dtypes.int32
        ).numpy(),
    )

    concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    self.assertLen(concrete_functions, 4)

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    restored_concrete_functions = imported.f._list_all_concrete_functions()  # pylint: disable=protected-access
    self.assertLen(restored_concrete_functions, 4)

    # Passing None explicitly must match the trace made with the default.
    self.assertAllEqual(
        [0.0, 0.0, 0.0],
        imported.f(constant_op.constant([1, 2, 3]), None).numpy(),
    )
    self.assertAllEqual(
        [0.0, 0.0, 0.0],
        imported.f(constant_op.constant([1.0, 2.0, 3.0])).numpy(),
    )
    self.assertAllEqual(
        [0.0, 0.0, 0.0, 0.0],
        imported.f(constant_op.constant([1, 2, 3, 4])).numpy(),
    )
    self.assertAllEqual(
        [0, 0, 0],
        imported.f(
            constant_op.constant([1.0, 2.0, 3.0]), dtype=dtypes.int32
        ).numpy(),
    )

  def test_function_with_str_bytes_input(self, cycles, use_cpp_bindings):
    """Mixed Python-str and string-Tensor arguments restore all traces."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x, y):
      return string_ops.string_join([x, y])

    root = autotrackable.AutoTrackable()
    root.f = func

    # Three str/Tensor argument combinations; the count asserted below (3)
    # depends on exactly these calls.
    self.assertAllEqual(b"ab", root.f("a", "b"))
    self.assertAllEqual(b"ab", root.f("a", constant_op.constant("b")))
    self.assertAllEqual(b"ab", root.f(constant_op.constant("a"), "b"))

    concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    self.assertLen(concrete_functions, 3)

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    restored_concrete_functions = imported.f._list_all_concrete_functions()  # pylint: disable=protected-access
    self.assertLen(restored_concrete_functions, 3)

    self.assertAllEqual(b"ab", imported.f("a", "b"))
    self.assertAllEqual(b"ab", imported.f("a", constant_op.constant("b")))
    self.assertAllEqual(b"ab", imported.f(constant_op.constant("a"), "b"))

  def test_function_no_return(self, cycles, use_cpp_bindings):
    """A method with side effects and no return value is restorable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class TrackableWithOneVariable(autotrackable.AutoTrackable):

      def __init__(self, initial_value=0.0):
        super().__init__()
        self.variable = variables.Variable(initial_value)

      @def_function.function
      def increase(self, by=1.0):
        self.variable.assign_add(by)

    counter = TrackableWithOneVariable(5.0)

    counter.increase(constant_op.constant(10.0))
    self.assertEqual(15.0, counter.variable.numpy())
    # This call relies on the Python-float default `by=1.0`.
    counter.increase()
    self.assertEqual(16.0, counter.variable.numpy())

    restored = cycle(counter, cycles, use_cpp_bindings=use_cpp_bindings)

    restored.increase(constant_op.constant(10.0))
    self.assertEqual(26.0, restored.variable.numpy())
    restored.increase(constant_op.constant(1.0))
    self.assertEqual(27.0, restored.variable.numpy())

  def test_structured_inputs(self, cycles, use_cpp_bindings):
    """Nested-structure inputs match only traces with a compatible structure."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, training=True):
      # x is a nested structure, we care about one particular tensor.
      _, (a, b) = x
      if training:
        return 2 * a["a"] + b
      else:
        return 7

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func)

    x = constant_op.constant(10)
    y = constant_op.constant(11)

    # The leading Python int is part of the structure: input2 differs from
    # input1 only in that int, while input3 swaps tensors of the same kind.
    input1 = [6, ({"a": x}, y)]
    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.

    # Note: by only calling f(input1) before serialization, only inputs with
    # matching signature will be valid on the loaded model.
    self.assertEqual(31, root.f(input1).numpy())

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    with self.assertRaisesRegex(
        ValueError, "Could not find matching concrete function to call"
    ):
      imported.f(input2)

    self.assertEqual(31, imported.f(input1).numpy())
    self.assertEqual(32, imported.f(input3).numpy())

  def test_structured_inputs_bare_concrete_function(
      self, cycles, use_cpp_bindings
  ):
    """A bare concrete function rejects a structurally incompatible input."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, training=True):
      # x is a nested structure, we care about one particular tensor.
      _, (a, b) = x
      if training:
        return 2 * a["a"] + b
      else:
        return 7

    x = constant_op.constant(10)
    y = constant_op.constant(11)

    input1 = [6, ({"a": x}, y)]
    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func).get_concrete_function(input1)

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # Unlike the polymorphic function in test_structured_inputs, the bare
    # concrete function surfaces the mismatch as a TypeError.
    with self.assertRaises(TypeError):
      imported.f(input2)

    self.assertEqual(31, imported.f(input1, True).numpy())
    self.assertEqual(32, imported.f(input3, True).numpy())

  def test_structured_output(self, cycles, use_cpp_bindings):
    """Nested outputs (namedtuple, tensor, dict) keep structure and order."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    # Use fields with non-alphabetical order
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
      return [named_tuple, input2, {"x": 0.5}]

    def check_output(result, expected_sum, expected_product, expected_second):
      """Asserts the [namedtuple, tensor, dict] output produced by `func`."""
      self.assertEqual(expected_sum, result[0].a.numpy())
      self.assertEqual(expected_product, result[0].b.numpy())
      self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
      self.assertEqual(expected_second, result[1].numpy())
      self.assertEqual(0.5, result[2]["x"].numpy())

    holder = autotrackable.AutoTrackable()
    holder.f = def_function.function(func)

    before = holder.f(constant_op.constant(2), constant_op.constant(3))
    check_output(before, 5, 6, 3)

    restored = cycle(holder, cycles, use_cpp_bindings=use_cpp_bindings)

    after = restored.f(constant_op.constant(2), constant_op.constant(5))
    check_output(after, 7, 10, 5)

  def testConcreteFunctionType(self, cycles, use_cpp_bindings):
    """A restored concrete function exposes its parameters, captures, output."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    # `foo` closes over `y`, which becomes a capture of the concrete function.
    y = constant_op.constant(1)

    @def_function.function
    def foo(x):
      return {"input": x, "capture": y}

    root = autotrackable.AutoTrackable()
    root.f = foo.get_concrete_function(tensor_spec.TensorSpec([], dtypes.int32))

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    x = constant_op.constant(2)
    output = imported.f(x)
    self.assertEqual(set(output.keys()), {"input", "capture"})
    self.assertEqual(output["input"].numpy(), 2)
    self.assertEqual(output["capture"].numpy(), 1)

    # The single parameter keeps the Python argument name "x".
    parameters = list(imported.f.function_type.parameters.values())
    self.assertLen(parameters, 1)
    self.assertEqual(parameters[0].name, "x")
    self.assertEqual(
        parameters[0].type_constraint,
        tensor_spec.TensorSpec([], dtypes.int32, name="x"),
    )

    captures = imported.f.function_type.captures
    self.assertLen(captures, 1)
    self.assertEqual(
        list(captures.values())[0], tensor_spec.TensorSpec([], dtypes.int32)
    )

    # Output specs are named after the dict keys returned by `foo`.
    output = imported.f.function_type.output
    self.assertEqual(
        output.mapping,
        {
            "input": tensor_spec.TensorSpec(
                shape=(), dtype=dtypes.int32, name="input"
            ),
            "capture": tensor_spec.TensorSpec(
                shape=(), dtype=dtypes.int32, name="capture"
            ),
        },
    )

  def test_pretty_print_signature(self, cycles, use_cpp_bindings):
    """The restored concrete function pretty-prints its full signature."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
      return [named_tuple, input2, {"x": 0.5}]

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func).get_concrete_function(
        constant_op.constant(2), constant_op.constant(3)
    )

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    # Expected text is compared exactly; output spec names such as
    # 'tensor_0_b' follow the position/field path in the returned structure.
    self.assertEqual(
        imported.f.pretty_printed_signature(),
        "Input Parameters:\n"
        + "  input1 (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(),"
        " dtype=tf.int32, name='input1')\n"
        + "  input2 (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(),"
        " dtype=tf.int32, name='input2')\n"
        + "Output Type:\n"
        + "  List[NamedTupleHello[['b', TensorSpec(shape=(), dtype=tf.int32,"
        " name='tensor_0_b')], ['a', TensorSpec(shape=(), dtype=tf.int32,"
        " name='tensor_0_a')]], TensorSpec(shape=(), dtype=tf.int32,"
        " name='tensor_1'), Dict[['x', TensorSpec(shape=(), dtype=tf.float32,"
        " name='tensor_2_x')]]]\n"
        + "Captures:\n"
        + "  None",
    )

  def test_positional_arguments(self, cycles, use_cpp_bindings):
    """Defaulted Python arguments passed positionally or by keyword work."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, training=False, abc=7.1, defg=7.7):
      del abc
      if training:
        return 2 * x
      if defg == 7:
        return 6
      else:
        return 7

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func)

    # Only the argument combinations exercised here are traced and therefore
    # available on the restored function.
    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
    self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
    self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy())

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
    self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
    self.assertEqual(6, imported.f(constant_op.constant(1), defg=7.0).numpy())

  def test_additional_kwargs(self, cycles, use_cpp_bindings):
    """Extra **kwargs values select among the traced concrete functions."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def func(x, training=False, **options):
      del options
      if not training:
        return 7
      return 2 * x

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(func)

    x = constant_op.constant(10)
    self.assertEqual(7, root.f(x, learning_rate=0.5, epochs=3).numpy())

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # Only the kwarg values used above (epochs=3) have a saved concrete
    # function, so a different `epochs` value cannot be matched.
    with self.assertRaisesRegex(
        ValueError, "Could not find matching concrete function to call.*"
    ):
      restored.f(x, learning_rate=0.5, epochs=4)

    self.assertEqual(7, restored.f(x, learning_rate=0.5, epochs=3).numpy())

  def test_member_function(self, cycles, use_cpp_bindings):
    """A tf.function method keeps reading instance state after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class TrackableWithMember(autotrackable.AutoTrackable):

      def __init__(self):
        super().__init__()
        self._some_value = 20

      @def_function.function
      def f(self, x, training=False):
        if not training:
          return 7 + self._some_value
        return 2 * x

    root = TrackableWithMember()

    for expected, args in (
        (20, (constant_op.constant(10), True)),
        (27, (constant_op.constant(1),)),
        (2, (constant_op.constant(1), True)),
    ):
      self.assertEqual(expected, root.f(*args).numpy())

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    for expected, args in (
        (4, (constant_op.constant(2), True)),
        (27, (constant_op.constant(2),)),
    ):
      self.assertEqual(expected, restored.f(*args).numpy())

  def test_side_effect_listing(self, cycles, use_cpp_bindings):
    """The original object stays usable after `cycle` has saved it.

    `M.f` creates `self.var` the first time it is traced, so saving the
    object triggers that side effect before `m.f` is called below.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class M(autotrackable.AutoTrackable):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]
      )
      def f(self, x):
        if self.var is None:
          self.var = variables.Variable(2.0)
        return x * self.var

    m = M()
    # Forward `use_cpp_bindings` like every other test in this class so the
    # round trip matches the test's parameterization.
    cycle(m, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(4.0, m.f(constant_op.constant(2.0)).numpy())

  def test_basic_backprop(self, cycles, use_cpp_bindings):
    """Gradients flow through a restored function to restored variables."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    weight = variables.Variable(1.0, trainable=True)
    bias = variables.Variable(0.0, trainable=True)
    g = def_function.function(
        lambda x: x * weight + bias,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )

    root = autotrackable.AutoTrackable()
    root.weight = weight
    root.bias = bias
    root.g = g
    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    with backprop.GradientTape() as t:
      x = constant_op.constant([3.5])
      loss = imported.g(x)
    # Take the gradient and check it after the tape context closes, so only
    # the forward pass is inside the recorded region.
    grad = t.gradient(loss, [imported.weight, imported.bias])
    # loss = x * weight + bias, so d/dweight = x = 3.5 and d/dbias = 1.
    self.assertAllClose(grad, [3.5, 1.0])

  def test_nested_backprop(self, cycles, use_cpp_bindings):
    """Gradients flow through several layers of nested restored functions."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    weight = variables.Variable(1.0, trainable=True)
    bias = variables.Variable(0.0, trainable=True)
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)

    # Invoked from the other function defs through a "PartitionedCall" node.
    @def_function.function(input_signature=[float_spec, float_spec])
    def mul(x, y):
      return x * y

    # Invoked from the other function defs through a
    # "StatefulPartitionedCall" node.
    @def_function.function(input_signature=[float_spec])
    def f(x):
      return mul(weight.read_value(), x)

    @def_function.function(input_signature=[float_spec])
    def g(x):
      return (f(x) + bias,)

    @def_function.function(input_signature=[float_spec])
    def h(x):
      return (g(x) + bias,)

    root = autotrackable.AutoTrackable()
    root.weight = weight
    root.bias = bias
    root.g = h

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    with backprop.GradientTape() as tape:
      outputs = restored.g(constant_op.constant([3.5]))
    gradients = tape.gradient(outputs, [restored.weight, restored.bias])
    # `weight` is read once (gradient 3.5); `bias` is added in both g and h
    # (gradient 2).
    self.assertAllClose(gradients, [3.5, 2.0])

  def test_while_loop_backprop(self, cycles, use_cpp_bindings):
    """Gradients flow through a while_loop inside a restored function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    weight = variables.Variable(2.0, trainable=True)

    @def_function.function(
        input_signature=[
            tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))
        ]
    )
    def g(x):
      """Returns the row sums of `x` after scaling every entry by `weight`.

      The loop walks the columns of `x` and adds `weight * x[:, i]` to the
      accumulator. The accumulator starts as a length-2 vector, so `x` is
      expected to have exactly two rows.
      """
      i_0 = constant_op.constant(0)
      s_0 = constant_op.constant([0.0, 0.0])
      cond = lambda i, _: i < array_ops.shape(x)[1]
      body = lambda i, s: (i + 1, s + weight * x[:, i])
      i_end, s_end = while_loop.while_loop(cond, body, (i_0, s_0))
      del i_end
      return s_end

    root = autotrackable.AutoTrackable()
    root.weight = weight
    root.g = g
    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    def get_gradient(obj):
      with backprop.GradientTape() as t:
        x = constant_op.constant([[1.0, 2.0, 3.0], [1.0, -2.0, 3.0]])
        y = obj.g(x)
        # Row sums of x are [6, 2].
        self.assertAllClose(y, obj.weight * [6.0, 2.0])
        loss = math_ops.reduce_sum(y)  # weight * 8.
        self.assertAllEqual(t.watched_variables(), [obj.weight])
        return t.gradient(loss, obj.weight)

    imported_gradient = get_gradient(imported)
    original_gradient = get_gradient(root)
    self.assertIsNotNone(original_gradient)
    self.assertAllClose(original_gradient, 8.0)
    self.assertIsNotNone(imported_gradient)
    self.assertAllClose(imported_gradient, 8.0)

  def _test_restored_func_with_captured_var_backprop(
      self, cycles, use_cpp_bindings, dtype
  ):
    """Checks gradients of a restored function that captures a variable.

    Args:
      cycles: forwarded to `cycle`.
      use_cpp_bindings: forwarded to `cycle`.
      dtype: dtype of both the captured variable and the function input.
    """
    weight = variables.Variable(2.0, trainable=True, dtype=dtype)
    input_spec = tensor_spec.TensorSpec(dtype=dtype, shape=())

    @def_function.function(input_signature=[input_spec])
    def g(x):
      return x * weight

    root = autotrackable.AutoTrackable()
    root.weight = weight
    root.g = g
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    def gradient_of(obj):
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.0, dtype=dtype)
        y = obj.g(x)
        self.assertAllClose(y, obj.weight * 2.0)
        self.assertAllEqual(tape.watched_variables(), [obj.weight])
        return tape.gradient(y, obj.weight)

    restored_gradient = gradient_of(restored)
    root_gradient = gradient_of(root)
    # y = x * weight, so dy/dweight = x = 2 for both objects.
    for gradient in (root_gradient, restored_gradient):
      self.assertIsNotNone(gradient)
      self.assertAllClose(gradient, 2.0)

  def test_nested_fn_backprop(self, cycles, use_cpp_bindings):
    """Gradients reach a variable whose handle goes through a nested fn."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    weight = variables.Variable(2.0, trainable=True)
    matrix_spec = tensor_spec.TensorSpec(
        dtype=dtypes.float32, shape=(None, None)
    )

    @def_function.function(input_signature=[matrix_spec])
    def g(x):
      weight.read_value()  # Just get the tape to watch the variable
      handle = array_ops.identity(weight.handle)

      @def_function.function
      def launder_var_handle():
        return array_ops.identity(handle)

      laundered = launder_var_handle()
      return x + resource_variable_ops.read_variable_op(
          laundered, dtypes.float32
      )

    root = autotrackable.AutoTrackable()
    root.weight = weight
    root.g = g
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    def gradient_of(obj, persistent):
      with backprop.GradientTape(persistent=persistent) as tape:
        x = constant_op.constant([[1.0, 2.0, 3.0], [1.0, -2, 3.0]])
        y = obj.g(x)
        self.assertAllClose(y, obj.weight + x)
        total = math_ops.reduce_sum(y)
        return tape.gradient(total, obj.weight)

    restored_gradient = gradient_of(restored, persistent=False)
    root_gradient = gradient_of(root, persistent=False)
    # `weight` is added to all six entries of y, so d(sum)/d(weight) = 6.
    for gradient in (root_gradient, restored_gradient):
      self.assertIsNotNone(gradient)
      self.assertAllClose(gradient, 6.0)

  def test_restored_func_with_captured_var_backprop_float32(
      self, cycles, use_cpp_bindings
  ):
    """Runs the captured-variable backprop check with float32 values."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    self._test_restored_func_with_captured_var_backprop(
        cycles=cycles,
        use_cpp_bindings=use_cpp_bindings,
        dtype=dtypes.float32,
    )

  def test_restored_func_with_captured_var_backprop_float64(
      self, cycles, use_cpp_bindings
  ):
    """Runs the captured-variable backprop check with float64 values."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    self._test_restored_func_with_captured_var_backprop(
        cycles=cycles,
        use_cpp_bindings=use_cpp_bindings,
        dtype=dtypes.float64,
    )

  def test_callable(self, cycles, use_cpp_bindings):
    """Objects with a `__call__` function are callable after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)

    class M1(autotrackable.AutoTrackable):

      @def_function.function(input_signature=[float_spec])
      def __call__(self, x):
        return x

    root = autotrackable.AutoTrackable()
    root.m1 = M1()
    root.m2 = autotrackable.AutoTrackable()
    triple = def_function.function(input_signature=[float_spec])(
        lambda x: x * 3.0
    )
    root.m2.__call__ = triple
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    one = constant_op.constant(1.0)

    self.assertTrue(callable(restored.m1))
    self.assertAllEqual(root.m1(one), restored.m1(one))

    # `root.m2` is not callable, because `__call__` was set on the instance
    # rather than on the class. The object restored from a save/load cycle
    # is callable.
    self.assertTrue(callable(restored.m2))
    self.assertAllEqual(root.m2.__call__(one), restored.m2(one))

    # User objects without a `__call__` attribute remain non-callable.
    self.assertFalse(callable(restored))

  def test_chain_callable(self, cycles, use_cpp_bindings):
    """A chain of `__call__` attributes resolves down to the function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    func = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]
    )(lambda x: x * 3.0)

    # Build root -> outer -> inner bottom-up, with the function on `inner`.
    inner = autotrackable.AutoTrackable()
    inner.__call__ = func
    outer = autotrackable.AutoTrackable()
    outer.__call__ = inner
    root = autotrackable.AutoTrackable()
    root.__call__ = outer

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertTrue(callable(restored))
    one = constant_op.constant(1.0)
    self.assertAllEqual(restored(one).numpy(), 3.0)

  def test_load_in_graph_mode(self, cycles, use_cpp_bindings):
    """Loading inside a Graph registers variables in graph collections."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(1.0, name="v_one", trainable=False)
    root.v2 = variables.Variable(2.0, name="v_two", trainable=True)
    root.f = def_function.function(
        lambda x: root.v2 * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )

    # All but the last round trip happen outside the graph; the final load
    # happens below, inside it.
    if cycles > 1:
      root = cycle(root, cycles - 1, use_cpp_bindings=use_cpp_bindings)
    export_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, export_dir)

    with ops.Graph().as_default() as graph:
      loaded = test_load(export_dir, use_cpp_bindings=use_cpp_bindings)
      frozen_var, trainable_var = loaded.v1, loaded.v2
      self.assertFalse(frozen_var.trainable)
      self.assertTrue(trainable_var.trainable)
      doubled = loaded.f(constant_op.constant(2.0))
      with monitored_session.MonitoredSession() as sess:
        self.assertEqual(1.0, sess.run(frozen_var))
        self.assertEqual(4.0, sess.run(doubled))
      self.assertCountEqual(
          [frozen_var, trainable_var],
          graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
      )
      # load() should not add to TRAINABLE_VARIABLES. Higher levels of model
      # building control retraining or frozen use of imported SavedModels.
      self.assertCountEqual(
          [], graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      )

  def test_load_in_func_graph(self, cycles, use_cpp_bindings):
    """A SavedModel can be loaded while a tf.function is being traced."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(1.0)
    root.v2 = variables.Variable(2.0)
    root.f = def_function.function(
        lambda x: root.v2 * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )

    if cycles > 1:
      root = cycle(root, cycles - 1, use_cpp_bindings=use_cpp_bindings)
    export_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, export_dir)

    holder = autotrackable.AutoTrackable()

    @def_function.function
    def func(x):
      # Load only on the first trace and keep the result on `holder`.
      if not hasattr(holder, "model"):
        holder.model = load.load(export_dir)
      return holder.model.f(x)

    self.assertEqual(4.0, func(constant_op.constant(2.0)).numpy())

  def test_soft_matching(self, cycles, use_cpp_bindings):
    """A restored [None]-shaped signature accepts vectors of any length."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]
    )
    def func(x):
      return 2 * x

    root = autotrackable.AutoTrackable()
    root.f = func

    def check_doubles(fn, values):
      self.assertAllEqual(
          [2 * v for v in values], fn(constant_op.constant(values)).numpy()
      )

    check_doubles(root.f, [1])
    check_doubles(root.f, [1, 2])

    saved_fns = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    self.assertLen(saved_fns, 1)

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    restored_fns = restored.f._list_all_concrete_functions()  # pylint: disable=protected-access
    self.assertLen(restored_fns, 1)

    # A scalar cannot be bound to the [None]-shaped input signature.
    with self.assertRaisesRegex(
        TypeError, "Binding inputs to tf.function failed"
    ):
      restored.f(constant_op.constant(2)).numpy()

    # TODO(vbardiovsky): When classes are revived with input_signatures, we
    # should also check that the calls below are not generating any more
    # concrete functions.
    check_doubles(restored.f, [1, 2, 3, 4])
    check_doubles(restored.f, [1, 2, 3])

  def test_jit_compile(self, cycles, use_cpp_bindings):
    """The `jit_compile` setting of a tf.function survives save/load."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    # It'd be nice to use parameterize here, but the library does not support
    # having parameterized test methods inside already-parameterized classes.
    for jit_compile in (None, True, False):

      @def_function.function(jit_compile=jit_compile)
      def f(x):
        return x + 1.0

      root = module.Module()
      root.f = f
      # `cycle` saves and reloads `root` itself, so no separate `save.save`
      # into a fixed directory is needed here.
      imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

      self.assertEqual(imported.f._jit_compile, jit_compile)  # pylint: disable=protected-access

  def test_get_concrete_function(self, cycles, use_cpp_bindings):
    """Only argument combinations traced before saving can be looked up."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x, training=False):
      factor = 2 if training else 3
      return factor * x

    int_vector = tensor_spec.TensorSpec([None], dtypes.int32)
    func.get_concrete_function(int_vector, True)
    func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32))

    root = autotrackable.AutoTrackable()
    root.f = func

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    training_fn = restored.f.get_concrete_function(
        training=True, x=int_vector
    )
    doubled = training_fn(x=constant_op.constant([1, 2, 3, 4]))
    self.assertAllEqual([2, 4, 6, 8], doubled.numpy())
    # No int32 trace with the default `training=False` was made before save.
    with self.assertRaisesRegex(
        ValueError, "Could not find matching concrete function to call"
    ):
      restored.f.get_concrete_function(int_vector)
    restored.f.get_concrete_function(int_vector, True)

  def test_concrete_function(self, cycles, use_cpp_bindings):
    """A concrete-function attribute round-trips with its [None] shape."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]
    )
    def func(x):
      return 2 * x

    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function()

    def check_doubles(fn, values):
      self.assertAllEqual(
          [2 * v for v in values], fn(constant_op.constant(values)).numpy()
      )

    check_doubles(root.f, [1])
    check_doubles(root.f, [1, 2])

    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )

    check_doubles(restored.f, [1, 2, 3, 4])
    check_doubles(restored.f, [1, 2, 3])

  def test_concrete_function_captures(self, cycles, use_cpp_bindings):
    """Restored functions and signatures capture the restored variables."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class Root(module.Module):

      def __init__(self):
        self.v = variables.Variable(1.0)
        self.v1 = variables.Variable(1.0)

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]
      )
      def use_v(self, x):
        return self.v + self.v1 + 1.0

    root = Root()
    original_captures = root.use_v.get_concrete_function().graph.external_captures
    self.assertIn(root.v.handle, original_captures)
    root = cycle(
        root,
        cycles,
        signatures=root.use_v.get_concrete_function(),
        use_cpp_bindings=use_cpp_bindings,
    )

    def assert_captures_both_variables(captures):
      # Exactly the two restored variable handles, compared by identity.
      self.assertLen(captures, 2)
      for var in (root.v, root.v1):
        self.assertTrue(any(var.handle is t for t in captures))

    assert_captures_both_variables(
        root.use_v.get_concrete_function().graph.external_captures
    )
    assert_captures_both_variables(
        root.signatures["serving_default"].graph.external_captures
    )

  def test_concrete_function_arg_names(self, cycles, use_cpp_bindings):
    """A restored concrete function accepts its argument by keyword."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]
    )
    def func(x):
      return 2 * x

    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function()

    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())

    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )

    doubled = restored.f(x=constant_op.constant([1, 2, 3]))
    self.assertAllEqual([2, 4, 6], doubled.numpy())

  def test_concrete_function_no_signature(self, cycles, use_cpp_bindings):
    """A concrete function traced from an example tensor can be reloaded."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x):
      return 2 * x

    root = autotrackable.AutoTrackable()
    # Trace from an example value instead of an input_signature.
    example = constant_op.constant([1])
    root.f = func.get_concrete_function(example)
    self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    tripled_input = constant_op.constant([3])
    self.assertAllEqual([6], restored.f(tripled_input).numpy())

  @test_util.run_in_graph_and_eager_modes
  def test_concrete_function_backprop(self, cycles, use_cpp_bindings):
    """Input gradients can be taken through a restored concrete function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.float32)]
    )
    def func(x):
      return x**2.0

    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function()

    def input_gradient(fn):
      with backprop.GradientTape() as tape:
        inp = constant_op.constant(1.0)
        tape.watch(inp)
        out = fn(inp)
      return tape.gradient(out, inp)

    # d(x**2)/dx evaluated at x = 1 is 2.
    self.assertAllEqual(2.0, input_gradient(root.f))
    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    self.assertAllEqual(2.0, input_gradient(restored.f))

  def test_revived_concrete_function_kwargs(self, cycles, use_cpp_bindings):
    """Restored concrete functions bind arguments passed by keyword."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(x, y):
      return x * (y + 1.0)

    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function(
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.float32),
    )

    def call_with_kwargs(fn):
      # Keywords are given out of positional order on purpose.
      result = fn(y=constant_op.constant(3.0), x=constant_op.constant(2.0))
      return result.numpy()

    self.assertEqual(8.0, call_with_kwargs(root.f))
    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    self.assertEqual(8.0, call_with_kwargs(restored.f))

  def test_revived_concrete_function_tensorspec_kwargs(
      self, cycles, use_cpp_bindings
  ):
    """TensorSpec names act as keywords for a variadic concrete function."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(*args):
      x, y = args
      return x * (y + 1.0)

    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function(
        tensor_spec.TensorSpec([], dtypes.float32, name="x"),
        tensor_spec.TensorSpec([], dtypes.float32, name="y"),
    )

    def call_with_kwargs(fn):
      # Keywords are given out of positional order on purpose.
      result = fn(y=constant_op.constant(3.0), x=constant_op.constant(2.0))
      return result.numpy()

    self.assertEqual(8.0, call_with_kwargs(root.f))
    restored = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    self.assertEqual(8.0, call_with_kwargs(restored.f))

  def test_concrete_function_variable_argument(self, cycles, use_cpp_bindings):
    """A concrete function taking a variable argument works after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    capture = variables.Variable(0)

    @def_function.function
    def func(v):
      v.assign_add(1)
      capture.assign_sub(1)

    arg_before_save = variables.Variable(1)
    root = autotrackable.AutoTrackable()
    root.f = func.get_concrete_function(arg_before_save)
    root.capture = capture

    self.assertEqual(1, arg_before_save.numpy())
    root.f(arg_before_save)
    self.assertEqual(2, arg_before_save.numpy())
    self.assertEqual(-1, capture.numpy())

    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # Each call bumps the argument and decrements the restored capture; the
    # second call passes the variable by keyword.
    arg_after_load = variables.Variable(1)
    for expected, args, kwargs in (
        (2, (arg_after_load,), {}),
        (3, (), {"v": arg_after_load}),
    ):
      restored.f(*args, **kwargs)
      self.assertEqual(expected, arg_after_load.numpy())
      self.assertEqual(-expected, restored.capture.numpy())

    # The original capture is untouched by calls on the restored object.
    self.assertEqual(-1, capture.numpy())

  def test_function_and_component(self, cycles, use_cpp_bindings):
    """A function and one of its concrete functions can both be saved."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def func(v):
      return v + 1

    root = autotrackable.AutoTrackable()
    root.func = func
    root.concrete_func = func.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.int32)
    )
    one = constant_op.constant(1)

    def check_both(obj):
      self.assertEqual(2, obj.func(one).numpy())
      self.assertEqual(2, obj.concrete_func(one).numpy())

    check_both(root)
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    check_both(restored)

  def test_dict(self, cycles, use_cpp_bindings):
    """Dict attributes of variables and functions round-trip."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.variables = {"a": variables.Variable(1.0)}
    root.variables["b"] = variables.Variable(2.0)
    root.variables["c"] = 1
    root.funcs = {
        "a": def_function.function(lambda: constant_op.constant(100.0))
    }
    root.funcs["conc"] = root.funcs["a"].get_concrete_function()
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    for key, expected in (("a", 1.0), ("b", 2.0)):
      self.assertEqual(expected, restored.variables[key].numpy())
    # The plain int stored under "c" does not come back.
    self.assertEqual({"a", "b"}, set(restored.variables.keys()))
    for key in ("a", "conc"):
      self.assertEqual(100.0, restored.funcs[key]().numpy())

  def test_list(self, cycles, use_cpp_bindings):
    """List attributes round-trip, with non-trackable entries as None."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    root.variables = [variables.Variable(1.0)]
    root.variables.append(1)
    root.variables.append(variables.Variable(3.0))
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    restored_list = restored.variables
    self.assertEqual(1.0, restored_list[0].numpy())
    self.assertEqual(3.0, restored_list[2].numpy())
    # The int in the middle slot comes back as None.
    self.assertIsNone(restored_list[1])
    self.assertLen(restored_list, 3)

  def test_tuple(self, cycles, use_cpp_bindings):
    """Tuple attributes round-trip, with non-trackable entries as None."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    first, last = variables.Variable(1.0), variables.Variable(3.0)
    root.variables = (first, 1, last)
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    restored_tuple = restored.variables
    self.assertEqual(1.0, restored_tuple[0].numpy())
    self.assertEqual(3.0, restored_tuple[2].numpy())
    # The int in the middle slot comes back as None.
    self.assertIsNone(restored_tuple[1])
    self.assertLen(restored_tuple, 3)

  def test_functions_list(self, cycles, use_cpp_bindings):
    """A list of loss functions keeps reading the restored variables."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    first_var = variables.Variable(1.0)
    root.losses = [
        def_function.function(lambda: math_ops.reduce_sum(first_var**2))
    ]
    root.variables = [first_var]

    @def_function.function
    def _v2_loss():
      # The second variable is created on the first trace only.
      if len(root.variables) == 1:
        root.variables.append(variables.Variable(2.0))
      return math_ops.reduce_sum(root.variables[1] ** 2)

    root.losses.append(_v2_loss)

    def evaluate(losses):
      return [loss() for loss in losses]

    self.assertAllClose([1.0, 4.0], evaluate(root.losses))
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllClose([1.0, 4.0], evaluate(restored.losses))
    restored.variables[0].assign(3.0)
    restored.variables[1].assign(4.0)
    self.assertAllClose([9.0, 16.0], evaluate(restored.losses))

  def test_captured_constant(self, cycles, use_cpp_bindings):
    """Functions capturing one eager constant share it after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    const = array_ops.zeros([100])
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda: const + 1.0)
    root.g = def_function.function(lambda: const + 2.0)

    def check_outputs(obj):
      self.assertAllClose(array_ops.ones([100]), obj.f())
      self.assertAllClose(2.0 * array_ops.ones([100]), obj.g())

    check_outputs(root)
    restored = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    check_outputs(restored)

    # TODO(b/123408994): Use the public get_concrete_function.
    f_concrete, g_concrete = [
        fn._list_all_concrete_functions_for_serialization()[0]  # pylint: disable=protected-access
        for fn in (restored.f, restored.g)
    ]
    for concrete in (f_concrete, g_concrete):
      self.assertLen(concrete.captured_inputs, 1)
    # We should be using the same captured EagerTensor in both functions, not
    # duplicating the constant.
    self.assertIs(f_concrete.captured_inputs[0], g_concrete.captured_inputs[0])

  def test_functions_accessed_once(self, cycles, use_cpp_bindings):
    """Saving reads a function-returning property exactly once."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class Exported(autotrackable.AutoTrackable):

      def __init__(self):
        self._counter = 0

      @property
      def make_func(self):
        # Every access builds a new function, traces it while it reads the
        # current counter, then bumps the counter.
        @def_function.function
        def f():
          return constant_op.constant(self._counter)

        f.get_concrete_function()  # force a trace
        self._counter += 1
        return f

    original = Exported()
    restored = cycle(original, cycles, use_cpp_bindings=use_cpp_bindings)
    # The saved function was traced with counter 0; the next access on the
    # original object therefore sees counter 1.
    self.assertEqual(0, restored.make_func().numpy())
    self.assertEqual(1, original.make_func().numpy())

  def test_overwritten_signatures_error(self, cycles, use_cpp_bindings):
    """Re-saving after replacing the loaded `signatures` raises an error."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    exported = autotrackable.AutoTrackable()
    exported.f = def_function.function(lambda: constant_op.constant(1.0))
    signature_fn = exported.f.get_concrete_function()
    restored = cycle(
        exported,
        cycles,
        signatures={"key": signature_fn},
        use_cpp_bindings=use_cpp_bindings,
    )
    self.assertEqual(1.0, restored.signatures["key"]()["output_0"].numpy())
    # Swap in a mapping with a different key, then try to save again.
    restored.signatures = {"key1": restored.signatures["key"]}
    export_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
    with self.assertRaisesRegex(ValueError, "signatures"):
      save.save(restored, export_dir)

  def test_signature_loading(self, cycles, use_cpp_bindings):
    """A serving signature reads the restored variable and is immutable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class Exported(autotrackable.AutoTrackable):

      def __init__(self):
        self.v = variables.Variable(3.0)

      @def_function.function
      def do(self, x):
        return self.v * x

    exported = Exported()
    serving_fn = exported.do.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32)
    )
    restored = cycle(
        exported,
        cycles,
        signatures=serving_fn,
        use_cpp_bindings=use_cpp_bindings,
    )
    self.assertEqual(["serving_default"], list(restored.signatures.keys()))
    signature = restored.signatures["serving_default"]
    two = constant_op.constant(2.0)

    def first_output(*args, **kwargs):
      return signature(*args, **kwargs)["output_0"].numpy()

    self.assertEqual(6.0, first_output(x=two))
    restored.v.assign(4.0)
    # The signature picks up the new value of the restored variable.
    self.assertEqual(8.0, first_output(x=two))
    self.assertEqual(8.0, first_output(two))
    with self.assertRaises(TypeError):
      # The signatures mapping is immutable
      restored.signatures["random_key"] = 3

  def test_names_normalized(self, cycles, use_cpp_bindings):
    """Input names with unsupported characters are renamed and logged."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    signature = [
        tensor_spec.TensorSpec([], dtype=dtypes.int32, name=input_name)
        for input_name in ("A-b", "A/D", "bar", "e")
    ]

    class ObjWithFunction(module.Module):

      @def_function.function(input_signature=signature)
      def foo(self, a, b, c, d=10, **options):
        del options
        return a + b + c + d

    exported = ObjWithFunction()

    with self.assertLogs(level="INFO") as logs:
      imported = cycle(exported, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertIn(
        "INFO:absl:Function `foo` contains input name(s) A-b, A/D with "
        "unsupported characters which will be renamed to a_b, a_d in the "
        "SavedModel.",
        logs.output,
    )

    # "A-b" and "A/D" must show up under their sanitized names.
    loaded_input_names = {
        arg.name for arg in imported.signatures["serving_default"].inputs
    }
    self.assertTrue({"a_b:0", "a_d:0"}.issubset(loaded_input_names))

  def test_multiple_argument_signatures_no_positional(
      self, cycles, use_cpp_bindings
  ):
    """Multi-argument signatures must be called with keyword arguments only."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class Exported(autotrackable.AutoTrackable):

      @def_function.function
      def do(self, x, y):
        return x + y

    exported = Exported()
    concrete = exported.do.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32),
    )
    imported = cycle(
        exported,
        cycles,
        signatures=concrete,
        use_cpp_bindings=use_cpp_bindings,
    )
    serving_default = imported.signatures["serving_default"]
    one = constant_op.constant(1.0)
    two = constant_op.constant(2.0)
    # Passing `x` positionally is rejected.
    with self.assertRaises(TypeError):
      serving_default(one, y=two)
    self.assertEqual(
        {"output_0": 3.0}, self.evaluate(serving_default(x=one, y=two))
    )

  def _make_model_with_tables(self):
    """Builds a trackable root holding two hash tables and lookup functions.

    `table1` is initialized from in-memory keys/values and `table2` from an
    asset text file (via `self._make_asset`). Missing keys map to -1 in both.

    Returns:
      An `AutoTrackable` with attributes `table1`, `lookup1`, `table2` and
      `lookup2`.
    """
    missing_value = -1

    in_memory_init = lookup_ops.KeyValueTensorInitializer(
        constant_op.constant(["brain", "salad", "surgery"]),
        constant_op.constant([0, 1, 2], dtypes.int64),
    )
    table1 = lookup_ops.HashTable(in_memory_init, missing_value)

    vocab_path = self._make_asset("test\nfoo\nbrain\n")
    file_init = lookup_ops.TextFileIdTableInitializer(vocab_path)
    table2 = lookup_ops.HashTable(file_init, missing_value)

    def _make_lookup_function(table):
      # One string input of unknown shape.
      lookup_signature = [tensor_spec.TensorSpec(None, dtypes.string)]
      return def_function.function(input_signature=lookup_signature)(
          lambda x: table.lookup(x))  # pylint: disable=unnecessary-lambda

    root = autotrackable.AutoTrackable()
    root.table1 = table1
    root.lookup1 = _make_lookup_function(table1)
    root.table2 = table2
    root.lookup2 = _make_lookup_function(table2)
    return root

  def test_table(self, cycles, use_cpp_bindings):
    """In-memory and asset-backed hash tables survive save/load cycles."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = self._make_model_with_tables()
    # Forward `use_cpp_bindings` like every other test in this class, so the
    # C++ loader is exercised here once the skip above is removed.
    imported = cycle(
        root, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    keys = constant_op.constant(["brain", "test", "foo", "surgery"])
    # table1 maps brain/salad/surgery -> 0/1/2; table2 maps the asset's lines
    # test/foo/brain -> 0/1/2. Unknown keys map to -1.
    self.assertAllEqual([0, -1, -1, 2], imported.lookup1(keys).numpy())
    self.assertAllEqual([2, 0, 1, -1], imported.lookup2(keys).numpy())

  def test_table_collections_untouched_eager(self, cycles, use_cpp_bindings):
    """Eager save/load of tables must not grow default-graph collections."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def _gather_nonempty_collections():
      graph = ops.get_default_graph()
      snapshot = {
          name: graph.get_collection(name) for name in graph.collections
      }
      return {name: items for name, items in snapshot.items() if items}

    root = self._make_model_with_tables()
    # One warm-up cycle populates collections that are created once and then
    # stay fixed (e.g. the __varscope collection), so they are not reported
    # as a difference below.
    cycle(root, 1, use_cpp_bindings=use_cpp_bindings)
    before = _gather_nonempty_collections()
    cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    after = _gather_nonempty_collections()
    self.assertEqual(before, after)

  def test_table_in_graph(self, cycles, use_cpp_bindings):
    """Tables from a SavedModel loaded in a TF1 graph work in a session."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = self._make_model_with_tables()

    # Perform all but the last round trip eagerly; the last one is the
    # graph-mode load below.
    if cycles > 1:
      root = cycle(root, cycles - 1, use_cpp_bindings=use_cpp_bindings)
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, path)

    with ops.Graph().as_default():
      imported = test_load(path, use_cpp_bindings=use_cpp_bindings)
      keys = constant_op.constant(["brain", "test", "foo", "surgery"])
      output1 = imported.lookup1(keys)
      output2 = imported.lookup2(keys)
      # MonitoredSession is used here to run the graph outputs.
      with monitored_session.MonitoredSession() as sess:
        self.assertAllEqual([0, -1, -1, 2], sess.run(output1))
        self.assertAllEqual([2, 0, 1, -1], sess.run(output2))

  def test_preserve_argspec(self, cycles, use_cpp_bindings):
    """The restored function reports the same full argspec as the original."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(a, b, c):  # pylint: disable=unused-argument
      return None

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(f)
    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertEqual(
        tf_inspect.getfullargspec(f), tf_inspect.getfullargspec(imported.f)
    )

  def test_canonicalize_inputs(self, cycles, use_cpp_bindings):
    """Defaults and keyword args are matched against the saved traces."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(autograph=False)
    def func(a=1, b=2, c=3, training=True):
      if not training:
        return [c, b, a, training]
      return [a, b, c, training]

    # TODO(b/123501567): Work-around to trigger generic traces of a function
    # with extra non tensor args.
    float_specs = [tensor_spec.TensorSpec(None, dtypes.float32)] * 3

    @def_function.function(input_signature=float_specs)
    def trigger(a, b, c):
      func(a, b, c, True)
      func(a, b, c, False)

    trigger.get_concrete_function()

    root = autotrackable.AutoTrackable()
    root.f = func
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(root.f(), [1.0, 2.0, 3.0, True])
    self.assertAllEqual(root.f(-1.0, training=False), [3.0, 2.0, -1.0, False])

    # A nested list input matches none of the saved traces.
    with self.assertRaisesRegex(
        ValueError, "Could not find matching concrete function"
    ):
      root.f(["hello", 1.0])

  def test_prefer_specific_trace(self, cycles, use_cpp_bindings):
    """A Python-value trace is preferred over the generic tensor trace."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(autograph=False)
    def func(a):
      return a if isinstance(a, int) else a + 1

    self.assertAllEqual(2, func(2).numpy())
    self.assertAllEqual(3, func(constant_op.constant(2)).numpy())

    root = autotrackable.AutoTrackable()
    root.f = func
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    # The literal 2 hits its own specialized trace; other Python ints and
    # tensors go through the tensor trace, which adds 1.
    cases = [
        (2, 2),
        (3, 4),
        (constant_op.constant(2), 3),
        (constant_op.constant(3), 4),
    ]
    for arg, expected in cases:
      self.assertAllEqual(expected, root.f(arg).numpy())

  def test_partial_with_non_tensor_defaults(self, cycles, use_cpp_bindings):
    """A partial overriding a Python default keeps that value after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x, y=3):
      return x + y

    root = autotrackable.AutoTrackable()
    # The partial replaces y's default of 3 with 5, so f(1) == 6.
    root.f = def_function.function(functools.partial(f, y=5))
    self.assertAllEqual(root.f(1), 6)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(root.f(1), 6)

  def test_partial_with_positional(self, cycles, use_cpp_bindings):
    """A partial binding a constant tensor positionally round-trips."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x, y):
      return x + y

    bound_x = constant_op.constant(5)
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(functools.partial(f, bound_x))
    self.assertAllEqual(root.f(1), 6)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(root.f(1), 6)

  def test_partial_with_positional_captured_tensors(
      self, cycles, use_cpp_bindings
  ):
    """A partial binding a computed tensor (5 + 7) positionally round-trips."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x, y):
      return x + y

    twelve = constant_op.constant(5) + constant_op.constant(7)
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(functools.partial(f, twelve))
    self.assertAllEqual(root.f(1), 13)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(root.f(1), 13)

  def test_partial_keyword_hiding_default(self, cycles, use_cpp_bindings):
    """A keyword bound by a partial coexists with other Python defaults."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x=3, training=True, y=7):
      return x + y if training else x + y + 2

    def _check(fn):
      # x=3 and y=6 (from the partial): 9 when training, 11 otherwise.
      self.assertEqual(fn().numpy(), 9)
      self.assertEqual(fn(training=False).numpy(), 11)

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(functools.partial(f, y=6))
    _check(root.f)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    _check(root.f)

  def test_partial_with_kwargs(self, cycles, use_cpp_bindings):
    """A partial binding positional, *args and **kwargs values round-trips."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(a, b, *args, **kwargs):
      extra = sum(args)
      scaled = kwargs["some_tensor"] * kwargs["learning_rate"]
      return a + b + scaled + extra

    bound = functools.partial(
        f, 7, 1, 2, learning_rate=3, some_tensor=constant_op.constant(10)
    )
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(bound)
    # 7 + 1 + 10 * 3 + (2 + input)
    self.assertEqual(root.f(constant_op.constant(4)).numpy(), 44)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(root.f(constant_op.constant(5)).numpy(), 45)

  def test_partial_bind_only_first_argument(self, cycles, use_cpp_bindings):
    """A partial binding `x` by keyword is callable with only `y`.

    Relies on Python 3 inspection of partials. This module is Python 3 only
    (it uses `nonlocal`), so no interpreter-version guard is needed.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x, y):
      return x + y

    partial_func = functools.partial(f, x=5)
    tf_func = def_function.function(partial_func)

    root = autotrackable.AutoTrackable()
    root.f = tf_func
    self.assertAllEqual(root.f(y=constant_op.constant(7)), 12)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(root.f(y=constant_op.constant(9)), 14)

  def test_partial_with_passed_fn_as_default(self, cycles, use_cpp_bindings):
    """A partial that binds a Python callable as an argument round-trips."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def f(x, y):
      return x(3) + y

    def my_func(a):
      return 2 * a

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(functools.partial(f, my_func))
    three = constant_op.constant(3)
    # my_func(3) + 3 == 9
    self.assertEqual(root.f(three).numpy(), 9)

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(root.f(three).numpy(), 9)

  def test_partial_with_input_signature(self, cycles, use_cpp_bindings):
    """A partial whose remaining argument is covered by an input signature."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def full_function(a, b, c=3.0):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    self.assertAllEqual((1, 2.0, 4), partial(2.0))

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        partial, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)]
    )

    def _as_numpy(outputs):
      a, b, c = outputs
      return [a.numpy(), b.numpy(), c.numpy()]

    self.assertAllEqual(_as_numpy(root.f(2.0)), (1, 2.0, 4))

    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllEqual(_as_numpy(root.f(3.0)), (1, 3.0, 4))

  def test_convert_to_input_signature(self, cycles, use_cpp_bindings):
    """Python lists are converted to match a restored input signature."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    int_vector = tensor_spec.TensorSpec([None], dtypes.int32)

    @def_function.function(input_signature=[int_vector])
    def func(x):
      return x

    root = autotrackable.AutoTrackable()
    root.f = func
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    # The plain list [2] is accepted as an int32 vector.
    result = root.f([2]).numpy()
    self.assertEqual([2], result)

  def test_named_tuple(self, cycles, use_cpp_bindings):
    """A function traced on a namedtuple subclass accepts one after loading.

    The old Python 3.4 skip was dead code on every supported interpreter,
    and it ran only after the setup work, so it has been dropped.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])):
      pass

    @def_function.function
    def f(x):
      return x.a + x.b

    f.get_concrete_function(
        NamedTupleType(
            a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"),
            b=tensor_spec.TensorSpec(None, dtypes.float32, name="b"),
        )
    )
    obj = autotrackable.AutoTrackable()
    obj.__call__ = f
    imported = cycle(obj, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertAllClose(
        3.0,
        imported(
            NamedTupleType(
                a=constant_op.constant(1.0), b=constant_op.constant(2.0)
            )
        ),
    )

  def test_extra_args(self, cycles, use_cpp_bindings):
    """Inputs with structure not seen at trace time are rejected after load."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def f(x):
      return math_ops.add(x["a"], 1.0)

    f({"a": constant_op.constant(2.0)})  # Trigger a trace.

    obj = autotrackable.AutoTrackable()
    obj.__call__ = f
    imported = cycle(obj, cycles, use_cpp_bindings=use_cpp_bindings)

    self.assertEqual(4.0, imported({"a": 3.0}).numpy())

    # The extra "b" key matches no saved trace.
    with self.assertRaisesRegex(
        ValueError, "Could not find matching concrete function to call"
    ):
      imported({"a": 2.0, "b": 3.0})

  def test_shapes_available(self, cycles, use_cpp_bindings):
    """Static shapes survive on both the function graph and the signature."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function(
        input_signature=[
            tensor_spec.TensorSpec([None, 3], dtypes.int32),
            tensor_spec.TensorSpec([None, 2], dtypes.int32),
        ]
    )
    def func(x, y):
      return array_ops.concat([x, y], axis=1)

    root = autotrackable.AutoTrackable()
    root.f = func
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)

    imported_graph = root.f.get_concrete_function().graph
    input_x, input_y = imported_graph.inputs
    (output,) = imported_graph.outputs
    signature = root.signatures["serving_default"]
    expectations = [
        ([None, 3], input_x),
        ([None, 2], input_y),
        ([None, 5], output),
        ([None, 3], signature.inputs[0]),
        ([None, 2], signature.inputs[1]),
        ([None, 5], signature.outputs[0]),
    ]
    for expected_shape, tensor in expectations:
      self.assertEqual(expected_shape, tensor.shape.as_list())

  def test_variables_destroyed(self, cycles, use_cpp_bindings):
    """Original and restored variables are freed once unreferenced.

    Uses weak references to check that `v1` is collectable after the round
    trip and that the restored `root.v` is collectable after `del root`.
    The assertions depend on each `del` dropping the last strong reference,
    so do not add locals that keep `v1` or `root.v` alive.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    v1 = variables.Variable(1.0)
    weak_v1 = weakref.ref(v1)
    root = checkpoint.Checkpoint(v=v1)
    # Rebinding `root` drops this function's reference to the Checkpoint
    # that held v1.
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    del v1
    self.assertIsNone(weak_v1())
    weak_v2 = weakref.ref(root.v)
    del root
    self.assertIsNone(weak_v2())

  def test_variable_attributes_preserved(self, cycles, use_cpp_bindings):
    """trainable/synchronization/aggregation survive save/load cycles."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    sync = variables.VariableSynchronization.NONE
    agg = variables.VariableAggregation.ONLY_FIRST_REPLICA
    v = variables.Variable(
        1.0, trainable=False, synchronization=sync, aggregation=agg
    )
    self.assertEqual(sync, v.synchronization)
    self.assertEqual(agg, v.aggregation)

    root = autotrackable.AutoTrackable()
    root.v = v
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    restored = root.v
    self.assertEqual(False, restored.trainable)
    self.assertEqual(sync, restored.synchronization)
    self.assertEqual(agg, restored.aggregation)

  def test_captured_dataset(self, cycles, use_cpp_bindings):
    """A tf.function iterating a captured Dataset works after loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class HasDataset(module.Module):

      def __init__(self):
        super().__init__()
        self.dataset = dataset_ops.Dataset.range(5).map(lambda x: x**2)

      @def_function.function
      def __call__(self, x):
        current_sum = array_ops.zeros([], dtype=dtypes.int64)
        for element in self.dataset:
          current_sum += x * element
        return current_sum

    # x times the sum of squares 0..4.
    expected = 3 * (1 + 4 + 9 + 16)
    root = HasDataset()
    self.assertEqual(
        expected, root(constant_op.constant(3, dtype=dtypes.int64)).numpy()
    )
    root = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(
        expected, root(constant_op.constant(3, dtype=dtypes.int64)).numpy()
    )

  def test_tuple_signature(self, cycles, use_cpp_bindings):
    """A signature returning a tuple yields one `output_<i>` per element."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = checkpoint.Checkpoint()
    root.f = def_function.function(
        lambda: (array_ops.ones([]), array_ops.zeros([])), input_signature=()
    )
    root = cycle(
        root, cycles, signatures=root.f, use_cpp_bindings=use_cpp_bindings
    )
    outputs = self.evaluate(root.signatures["serving_default"]())
    self.assertEqual({"output_0": 1.0, "output_1": 0.0}, outputs)

  def test_version_info(self, cycles, use_cpp_bindings):
    """Loaded objects expose the TF version and git version they were saved with."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    loaded = cycle(
        checkpoint.Checkpoint(), cycles, use_cpp_bindings=use_cpp_bindings
    )
    self.assertEqual(versions.__version__, loaded.tensorflow_version)
    self.assertEqual(versions.__git_version__, loaded.tensorflow_git_version)

  def test_load_grad_save(self, cycles, use_cpp_bindings):
    """Gradients through nested tf.functions hold across repeated cycles."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = checkpoint.Checkpoint()
    root.v = variables.Variable(2.0)
    # `f` reads `root.v` when first traced, which happens in the first loop
    # iteration, before `root` is rebound to a loaded object.
    root.f = def_function.function(lambda x: root.v * x)
    root.g = def_function.function(root.f)
    for _ in range(cycles):
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.0)
        tape.watch(x)
        y = root.g(x)
      # y = v * x with v == 2.0, so y == 4.0 and dy/dx == 2.0.
      self.assertAllClose(4.0, y)
      self.assertAllClose(2.0, tape.gradient(y, x))
      root = cycle(root, 1, use_cpp_bindings=use_cpp_bindings)

  def test_destroy_resource(self, cycles, use_cpp_bindings):
    """Deleting a loaded model destroys the TrackableResource it restored.

    After `del imported`, destroying the saved handle again with
    `ignore_lookup_error=False` is expected to raise NotFoundError. The test
    depends on `del imported` dropping the last reference to the loaded
    object, so do not keep extra references to it.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    def get_handle():
      # Fixed container/shared_name: every call refers to the same variable
      # resource, which lets _destroy_resource rebuild the handle.
      return resource_variable_ops.var_handle_op(
          shape=tensor_shape.as_shape([]),
          dtype=dtypes.float32,
          shared_name="my_var_name",
          name="my_var",
          container="my_container",
      )

    class MyResource(resource.TrackableResource):

      def _create_resource(self):
        return get_handle()

      def _initialize(self):
        # Initial value 1.0.
        resource_variable_ops.assign_variable_op(
            self.resource_handle, 1.0, name="assign"
        )

      def _destroy_resource(self):
        handle = get_handle()
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=True
        )

    class MyModel(autotrackable.AutoTrackable):

      def __init__(self):
        super(MyModel, self).__init__()
        self.resource = MyResource()

      @def_function.function(input_signature=[])
      def increase(self):
        # Adds 10.0 to the variable and returns the new value.
        handle = self.resource.resource_handle
        resource_variable_ops.assign_add_variable_op(
            handle, 10.0, name="assign_add"
        )
        return resource_variable_ops.read_variable_op(handle, dtypes.float32)

    root = MyModel()
    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(11, imported.increase().numpy())  # Create the resource.

    handle = imported.resource.resource_handle

    # Delete the imported SavedModel. MyResource defines _destroy_resource,
    # so dropping the loaded object is expected to destroy the resource.
    del imported

    # Try to destroy the resource again, should fail.
    with self.assertRaisesRegex(
        errors.NotFoundError, r"Resource .* does not exist."
    ):
      resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=False
      )

  def test_function_called_as_operation(self, cycles, use_cpp_bindings):
    """A legacy `Defun` called from a tf.function survives save/load."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @framework_function.Defun(dtypes.float32)
    def inner(x):
      return x + 1.0

    scalar_float = tensor_spec.TensorSpec([], dtypes.float32)

    @def_function.function(input_signature=[scalar_float])
    def outer(x):
      return inner(x)

    root = module.Module()
    root.f = outer
    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    one = constant_op.constant(1.0)
    self.assertAllClose(2.0, imported.f(one))

  def test_ragged(self, cycles, use_cpp_bindings):
    """RaggedTensor traces for several Python constants round-trip."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @def_function.function
    def f(x, c=1):
      """Returns Tensor x incremented by Python constant c."""
      return math_ops.add(x, c)

    for c in (1, 2, 3):
      f.get_concrete_function(
          ragged_tensor.RaggedTensorSpec([None, None], dtype=dtypes.int32), c
      )

    obj = autotrackable.AutoTrackable()
    obj.f = f

    rt = ragged_factory_ops.constant([[1, 2], [3]])
    expected = {1: [[2, 3], [4]], 2: [[3, 4], [5]], 3: [[4, 5], [6]]}

    # Round trip without signatures; c is omitted for the first call.
    imported1 = cycle(
        obj, cycles, signatures={}, use_cpp_bindings=use_cpp_bindings
    )
    self.assertAllEqual(imported1.f(rt), expected[1])
    for c in (2, 3):
      self.assertAllEqual(imported1.f(rt, c), expected[c])

    # Round trip with default signatures; c is always passed explicitly.
    imported2 = cycle(obj, cycles, use_cpp_bindings=use_cpp_bindings)
    for c in (1, 2, 3):
      self.assertAllEqual(imported2.f(rt, c), expected[c])

  def test_accepts_io_device(self, cycles, use_cpp_bindings):
    """LoadOptions.experimental_io_device defaults to None and is settable."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    default_options = load_options.LoadOptions()
    self.assertIsNone(default_options.experimental_io_device)
    io_device = "/job:localhost"
    custom_options = load_options.LoadOptions(experimental_io_device=io_device)
    self.assertEqual(io_device, custom_options.experimental_io_device)

  def _custom_saveable_object(self, cycles, use_cpp_bindings):
    """Round-trips two MutableHashTables and checks the restored contents.

    Args:
      cycles: Number of save/load round trips passed to `cycle`.
      use_cpp_bindings: Forwarded to `cycle`.
    """
    if context.is_tfrt_enabled():
      self.skipTest("Disable due to b/190539415.")
    root = autotrackable.AutoTrackable()
    root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
    root.table.insert("foo", 15)
    root.table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
    root.table2.insert("idk", 21)

    string_input = tensor_spec.TensorSpec(None, dtypes.string)

    @def_function.function(input_signature=[string_input])
    def lookup(key):
      return root.table.lookup(key)

    root.lookup = lookup

    imported = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    found = self.evaluate(imported.lookup("foo"))
    # "idk" lives only in table2, so `table` returns its default of -1.
    missing = self.evaluate(imported.lookup("idk"))
    self.assertEqual(found, 15)
    self.assertEqual(missing, -1)

    conversion_forced = saveable_compat.force_checkpoint_conversion_enabled()
    if not conversion_forced:
      self.assertEqual(
          {"table"}, imported.table._self_saveable_object_factories.keys()
      )

  def test_load_custom_saveable_object(self, cycles, use_cpp_bindings):
    """Runs the MutableHashTable round trip without forcing ckpt conversion."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    self._custom_saveable_object(cycles, use_cpp_bindings=use_cpp_bindings)

  def test_load_custom_saveable_object_ckpt_conversion(
      self, cycles, use_cpp_bindings
  ):
    """Runs the MutableHashTable round trip with checkpoint conversion forced.

    The conversion flag stays set after `force_checkpoint_conversion()`
    returns. Its previous value is restored on cleanup. Otherwise the flag
    would leak into later tests, such as test_load_custom_saveable_object,
    and silently skip their non-conversion assertions.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    # Tests custom saveable object with checkpoint conversion enabled (forces
    # Trackable-based checkpoint implementation).
    previous = saveable_compat.force_checkpoint_conversion_enabled()
    saveable_compat.force_checkpoint_conversion()
    # addCleanup also runs if the helper skips (TFRT) or fails.
    self.addCleanup(saveable_compat.force_checkpoint_conversion, previous)
    self._custom_saveable_object(cycles, use_cpp_bindings=use_cpp_bindings)

  def test_load_resource_with_dependency(self, cycles, use_cpp_bindings):
    """A restored StaticHashTable works after its source vocab file is gone.

    StaticHashTable tracks its initializer, which in turn tracks the vocab
    Asset. The lookup after deleting the original file presumably succeeds
    because the asset was copied into the SavedModel.
    """
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class MyLookupModel(autotrackable.AutoTrackable):

      def __init__(self, vocab_file):
        # Each whole line is a key; its line number is the value.
        initializer = lookup_ops.TextFileInitializer(
            vocab_file,
            key_dtype=dtypes.string,
            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
            value_dtype=dtypes.int64,
            value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
        )
        self._vocab_table = lookup_ops.StaticHashTable(
            initializer, default_value=-1
        )

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec((None,), dtypes.string)]
      )
      def __call__(self, inputs):
        return self._vocab_table.lookup(inputs)

    vocab_file = self._make_asset("a\nb\nc\nd")
    imported = cycle(
        MyLookupModel(vocab_file), cycles, use_cpp_bindings=use_cpp_bindings
    )
    file_io.delete_file(vocab_file)
    self.assertAllEqual(imported(constant_op.constant(["d", "b"])), [3, 1])

  def test_custom_gradients(self, cycles, use_cpp_bindings):
    """First- and second-order custom gradients are kept when saved."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    @custom_gradient.custom_gradient
    def log1pexp(x):
      e = math_ops.exp(x)

      def grad(dy):
        # Deliberately wrong (the true gradient is dy * e / (1 + e)), so a
        # match after loading shows the custom gradient was respected.
        return dy * e

      return math_ops.log(1 + e), grad

    @def_function.function
    def g(x):
      y = log1pexp(x)

      @def_function.function
      def g_nest():
        return log1pexp(y)

      return g_nest()

    @def_function.function
    def f(x):
      return log1pexp(g(x * x))

    v = variables.Variable(1.)

    def _first_and_second_grads(fn):
      with backprop.GradientTape() as outer_tape:
        with backprop.GradientTape() as inner_tape:
          inner_tape.watch(v)
          out = fn(v)
          first = inner_tape.gradient(out, v)
        second = outer_tape.gradient(first, v)
      return first, second

    expected_grads, expected_grad_grads = _first_and_second_grads(f)

    root = autotrackable.AutoTrackable()
    root.f = f
    loaded = cycle(
        root,
        cycles,
        save_option=save_options.SaveOptions(
            experimental_custom_gradients=True
        ),
        use_cpp_bindings=use_cpp_bindings,
    )
    grads, grad_grads = _first_and_second_grads(loaded.f)

    self.assertAllClose(grads, expected_grads)
    self.assertAllClose(grad_grads, expected_grad_grads)

  def test_custom_gradients_with_none_grad(self, cycles, use_cpp_bindings):
    """Custom gradients that yield None for an int input survive loading."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    # https://github.com/google/jax/issues/7123

    @custom_gradient.custom_gradient
    def f(params, state):
      def grad_fn(*args):
        return args

      return (params, state), grad_fn

    @def_function.function(
        input_signature=[
            tensor_spec.TensorSpec([], dtypes.float32),
            tensor_spec.TensorSpec([], dtypes.int32),
        ]
    )
    def predict(params, state):
      return f(params, state)

    params = variables.Variable(1.0)
    # None grads only appear when state is an int.
    state = constant_op.constant(3, dtype=dtypes.int32)

    def _grads(fn):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = fn(params, state)
        return tape.gradient(outputs, params)

    expected_grads = _grads(predict)

    root = autotrackable.AutoTrackable()
    root.fn = predict
    loaded = cycle(
        root,
        cycles,
        save_option=save_options.SaveOptions(
            experimental_custom_gradients=True
        ),
        use_cpp_bindings=use_cpp_bindings,
    )

    self.assertAllClose(_grads(loaded.fn), expected_grads)

  def test_custom_gradients_with_none_grad_and_partial_shape(
      self, cycles, use_cpp_bindings
  ):
    """Like the None-grad case, but with unknown-shape input signatures."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    # https://github.com/google/jax/issues/7123

    @custom_gradient.custom_gradient
    def f(params, state):
      def grad_fn(*args):
        return args

      return (params, state), grad_fn

    @def_function.function(
        input_signature=[
            tensor_spec.TensorSpec(None, dtypes.float32),
            tensor_spec.TensorSpec(None, dtypes.int32),
        ]
    )
    def predict(params, state):
      return f(params, state)

    params = variables.Variable(1.0)
    # None grads only appear when state is an int.
    state = constant_op.constant(3, dtype=dtypes.int32)

    def _grads(fn):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = fn(params, state)
        return tape.gradient(outputs, params)

    expected_grads = _grads(predict)

    root = autotrackable.AutoTrackable()
    root.fn = predict
    loaded = cycle(
        root,
        cycles,
        save_option=save_options.SaveOptions(
            experimental_custom_gradients=True
        ),
        use_cpp_bindings=use_cpp_bindings,
    )

    self.assertAllClose(_grads(loaded.fn), expected_grads)

  def test_signature_propagates_experimental_attr(
      self, cycles, use_cpp_bindings
  ):
    """`experimental_attributes` reach the serving signature's FunctionDef."""
    # TODO(b/264869228) Fix LoadTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = autotrackable.AutoTrackable()
    experimental_attributes = {"disable_summaries_at_runtime": ["x", True]}

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
        experimental_attributes=experimental_attributes,
    )
    def f(x):
      return x * 2.0

    root.f = f
    one = constant_op.constant(1.0)
    self.assertEqual(root.f(one).numpy(), 2.0)
    loaded = cycle(root, cycles, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(loaded.f(one).numpy(), 2.0)
    serving_attrs = loaded.signatures["serving_default"].function_def.attr
    self.assertProtoEquals(
        r"""
        list {
            s: 'x',
            b: True
        }
        """,
        serving_attrs["disable_summaries_at_runtime"],
    )


@parameterized.named_parameters(*_test_params())
class SingleCycleTests(test.TestCase, parameterized.TestCase):
  """Save/load tests that need exactly one save -> load round trip.

  Every test receives `use_cpp_bindings` from `_test_params()`; most skip
  the cpp variant because it is not implemented yet.
  """

  def test_load_with_tags(self, use_cpp_bindings):
    """Only the SERVING tag (in any accepted container form) loads."""
    if use_cpp_bindings:
      self.skipTest("Cpp bindings do not support Tags.")
    root = autotrackable.AutoTrackable()
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, path)

    with self.assertRaises(ValueError):
      load.load(path, tags=[tag_constants.EVAL])
    load.load(path, tags=[tag_constants.SERVING])
    load.load(path, tags=tag_constants.SERVING)
    load.load(path, tags=set([tag_constants.SERVING]))

  def test_save_load_contains_with_fspath(self, use_cpp_bindings):
    """`pathlib.Path` objects are accepted by save, contains and load."""
    if use_cpp_bindings:
      self.skipTest("Cpp bindings cannot work with pathlib object.")
    root = autotrackable.AutoTrackable()
    path = pathlib.Path(tempfile.mkdtemp(prefix=self.get_temp_dir()))
    save.save(root, path)
    self.assertTrue(loader_impl.contains_saved_model(path))

    test_load(path, use_cpp_bindings=use_cpp_bindings)

  def test_single_restore_op_used(self, use_cpp_bindings):
    """Restoring several variables runs exactly one RestoreV2 op."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = module.Module()
    root.v1 = variables.Variable(1.0)
    root.v2 = variables.Variable(2.0)
    root.v3 = variables.Variable(3.0)
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, path)
    restore_count = 0

    def _count_restores(op_type, *unused_args, **unused_kwargs):
      nonlocal restore_count
      if op_type == b"RestoreV2":
        restore_count += 1

    op_callbacks.add_op_callback(_count_restores)
    try:
      save.save(root, path)
      test_load(path, use_cpp_bindings=use_cpp_bindings)
    finally:
      # Always unregister: a failure above must not leave this callback
      # installed for the remaining tests in the process.
      op_callbacks.remove_op_callback(_count_restores)
    self.assertEqual(1, restore_count)

  def test_docstring_examples(self, use_cpp_bindings):
    """Mirrors the usage examples from the load/save docstrings."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    exported = checkpoint.Checkpoint(v=variables.Variable(3.0))
    exported.f = def_function.function(
        lambda x: exported.v * x,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
        ],
    )
    save.save(exported, path)
    imported = test_load(path, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(3.0, imported.v.numpy())
    self.assertEqual(6.0, imported.f(x=constant_op.constant(2.0)).numpy())

    # Re-save with an explicit serving signature and call it by keyword.
    save.save(exported, path, exported.f.get_concrete_function())
    imported = test_load(path, use_cpp_bindings=use_cpp_bindings)
    f = imported.signatures["serving_default"]
    self.assertAllEqual(
        [[-3.0]], f(x=constant_op.constant([[-1.0]]))["output_0"].numpy()
    )

  def test_object_with_extra_dependencies(self, use_cpp_bindings):
    """Children added via `_trackable_children` are saved and restored."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class Extra(autotrackable.AutoTrackable):

      def _trackable_children(self, save_type, **kwargs):
        children = super(Extra, self)._trackable_children(save_type, **kwargs)
        children["a"] = variables.Variable(5.0)
        return children

    root = Extra()
    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, path)
    imported = test_load(path, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(5, self.evaluate(imported.a))

  def test_save_cached_variable(self, use_cpp_bindings):
    """A variable with a caching_device round-trips in graph and eager mode."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    with ops.Graph().as_default(), session_lib.Session() as session:
      obj = autotrackable.AutoTrackable()
      obj.v = variables.Variable(2.0, caching_device=lambda op: op.device)
      obj.w = variables.Variable(3.0)
      session.run([obj.v.initializer, obj.w.initializer])

      @def_function.function
      def total():
        return obj.v + obj.w

      @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
      def wrapped_total(x):
        return total() + x

      @def_function.function
      def increment_v(x):
        obj.v.assign_add(x)
        return x

      session.run(increment_v(constant_op.constant(3.0)))  # generate signatures
      self.assertAllClose(8, total())
      self.assertAllClose(13, wrapped_total(constant_op.constant(5.0)))

      obj.total = total
      obj.wrapped_total = wrapped_total.get_concrete_function()
      obj.increment_v = increment_v

      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(obj, save_dir, signatures=total.get_concrete_function())
      imported = test_load(save_dir, use_cpp_bindings=use_cpp_bindings)
      session.run(variables.global_variables_initializer())
      self.assertAllClose(8, imported.total())
      session.run(imported.increment_v(4))
      self.assertAllClose(12, imported.total())
      self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.0)))
      self.assertAllClose(
          {"output_0": 12}, imported.signatures["serving_default"]()
      )

    # Try loading and running the function in eager mode
    imported = test_load(save_dir, use_cpp_bindings=use_cpp_bindings)
    self.assertAllClose(8, imported.total())
    imported.increment_v(5)
    self.assertAllClose(13, imported.total())
    self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(0.5)))
    self.assertAllClose(
        {"output_0": 13}, imported.signatures["serving_default"]()
    )

  # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
  # iterations took hundreds of seconds). It would be really nice to check
  # allocations at a lower level.
  @test_util.assert_no_new_pyobjects_executing_eagerly()
  def test_functions_cleaned(self, use_cpp_bindings):
    """A save/load cycle does not leave new Python objects behind."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    if sys.version_info.major < 3:
      self.skipTest("Not working in Python 2")
    root = module.Module()
    root.v = variables.Variable(1.0)
    root.f = def_function.function(
        lambda x: x + root.v,
        input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)
        ],
    )
    cycle(root, 1, use_cpp_bindings=use_cpp_bindings)

  def test_load_partial_object(self, use_cpp_bindings):
    """`load_partial` restores selected nodes and checks their dependencies."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = module.Module()
    root.variables_holder = module.Module()
    root.variables_holder.v = variables.Variable(1.0)

    class Adder(module.Module):

      @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=[])])
      def __call__(self, y):
        root.variables_holder.v.assign_add(y)
        return 1

    root.adder = Adder()

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)

    imported = load.load_partial(
        save_dir, ["root.variables_holder.v", "root.adder"]
    )
    v = imported["root.variables_holder.v"]
    adder = imported["root.adder"]
    self.assertEqual(self.evaluate(v), 1)
    adder(5)
    self.assertEqual(self.evaluate(v), 6)

    # `root.adder` captures the variable, so loading it alone must fail.
    with self.assertRaisesRegex(
        ValueError, "does not include all required objects for loading"
    ):
      load.load_partial(save_dir, ["root.adder"])

  def test_load_partial_checkpoint(self, use_cpp_bindings):
    """Partial checkpoints are only accepted with `allow_partial_checkpoint`."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    root = module.Module()
    root.variables_holder = module.Module()
    root.variables_holder.v = variables.Variable(1.0)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)

    # `loaded.v` has no counterpart in the saved checkpoint.
    loaded = module.Module()
    loaded.v = variables.Variable(2.0)

    load.load_partial(
        save_dir,
        {"root": loaded},
        options=load_options.LoadOptions(allow_partial_checkpoint=True),
    )
    self.assertEqual(loaded.variables_holder.v.numpy(), 1)
    with self.assertRaisesRegex(AssertionError, "were not bound"):
      load.load_partial(save_dir, {"root": loaded})

  def test_call_untraced_function_raises_error(self, use_cpp_bindings):
    """Untraced functions log a warning on save and raise when called."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class ObjWithFunction(module.Module):

      @def_function.function
      def foo(self, a):
        return a

    root = ObjWithFunction()
    with self.assertLogs(level="INFO") as logs:
      loaded = cycle(root, 1, use_cpp_bindings=use_cpp_bindings)

    expected_save_message = (
        "INFO:absl:Found untraced functions such as foo while saving "
        "(showing 1 of 1). These functions will not be directly callable after "
        "loading."
    )
    self.assertIn(expected_save_message, logs.output)

    with self.assertRaisesRegex(
        ValueError, "Found zero restored functions for caller function."
    ):
      loaded.foo(1)

  def test_restored_function_execute_eagerly(self, use_cpp_bindings):
    """Restored functions can be called with run_functions_eagerly(True)."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    try:
      def_function.run_functions_eagerly(True)

      class MyModel(module.Module):

        @def_function.function
        def __call__(self, inputs, training=False):
          return math_ops.multiply(0.5, inputs)

      model = MyModel()
      model.__call__.get_concrete_function(
          tensor_spec.TensorSpec([None], dtypes.float32)
      )
      loaded = cycle(model, 1, use_cpp_bindings=use_cpp_bindings)

      # Calling the function should not throw an exception.
      loaded(constant_op.constant([1.0]))

    finally:
      def_function.run_functions_eagerly(False)

  def test_restored_model_concrete_function_is_deterministic(
      self, use_cpp_bindings
  ):
    """Repeated save/load cycles yield the same concrete function signature."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    previous_concrete_function = None
    for _ in range(100):

      class MyModel(module.Module):

        @def_function.function
        def __call__(self, x):
          return x * constant_op.constant(3.0)

      model = MyModel()
      model(array_ops.ones((7, 3), dtype=dtypes.float32))
      model.__call__.get_concrete_function(
          tensor_spec.TensorSpec([None, 3], dtypes.float32)
      )
      loaded = cycle(model, 1, use_cpp_bindings=use_cpp_bindings)

      # Ensure the newly loaded concrete function is the same as the previous
      # after a cycle of serialization / deserialization.
      new_concrete_function = loaded.__call__.get_concrete_function(
          tensor_spec.TensorSpec([None, 3], dtypes.float32)
      )
      if previous_concrete_function is not None:
        self.assertEqual(
            previous_concrete_function.pretty_printed_signature(),
            new_concrete_function.pretty_printed_signature(),
        )

      previous_concrete_function = new_concrete_function

  def test_garbage_collection_capturable_resource_doesnt_raise_exception(
      self, use_cpp_bindings
  ):
    """Collecting a loaded model that owns a lookup table raises nothing."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    model = module.Module()
    model.mapping = lookup_ops.StaticHashTable(
        lookup_ops.KeyValueTensorInitializer(
            keys=math_ops.range(1, dtype=dtypes.int32), values=["foo"]
        ),
        "default_value",
    )
    loaded = cycle(model, 1, use_cpp_bindings=use_cpp_bindings)
    del model
    del loaded
    # Exceptions raised during garbage collection are simply printed to stderr
    # and ignored, and we have no way to access them. We capture stderr
    # during the garbage collection process and inspect it to see if any
    # exceptions were raised.
    stderr = io.StringIO()
    with contextlib.redirect_stderr(stderr):
      gc.collect()
    if "Exception ignored in" in stderr.getvalue():
      self.fail(stderr.getvalue())

  def test_captured_dataset_with_asset(self, use_cpp_bindings):
    """A dataset's source file is saved as an asset and used after loading."""
    # TODO(b/264869753) Fix SingleCycleTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class HasDataset(module.Module):

      def __init__(self, temp_dir, file_name):
        super(HasDataset, self).__init__()
        file = os.path.join(temp_dir, file_name)
        with tf_record.TFRecordWriter(file, "GZIP") as f:
          for v in ["a", "aa", "aaa"]:
            f.write(str(v))
        self.dataset = readers.TFRecordDataset([file], compression_type="GZIP")

      @def_function.function
      def __call__(self, x):
        current_sum = array_ops.zeros([], dtype=dtypes.int32)
        for element in self.dataset:
          current_sum += x * string_ops.string_length(element)
        return current_sum

    temp_dir = self.get_temp_dir()
    file_name = "tf_record_asset.tfrecord.gz"
    root = HasDataset(temp_dir, file_name)
    self.assertEqual(
        18,  # 3 * (1 + 2 + 3)
        root(constant_op.constant(3, dtype=dtypes.int32)).numpy(),
    )

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, save_dir)

    # Remove the original file and move the SavedModel so that loading can
    # only succeed by reading the copied asset.
    file_io.delete_file(os.path.join(temp_dir, file_name))
    asset_path = os.path.join(save_dir, "assets/{}".format(file_name))
    self.assertTrue(file_io.file_exists(asset_path))
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    loaded = test_load(load_dir, use_cpp_bindings=use_cpp_bindings)
    self.assertEqual(
        18,  # 3 * (1 + 2 + 3)
        loaded(constant_op.constant(3, dtype=dtypes.int32)).numpy(),
    )

  def test_function_aliases(self, use_cpp_bindings):
    """An aliased saved function is exposed via `function_aliases`."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2 * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)],
    )
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(function_aliases={
        "my_func": root.f,
    })
    save.save(root, save_dir, root.f, options=options)
    loaded = test_load(
        save_dir,
        use_cpp_bindings=use_cpp_bindings,
        options=load_options.LoadOptions(
            experimental_load_function_aliases=True
        ),
    )
    self.assertLen(loaded.function_aliases, 1)
    self.assertIn("my_func", loaded.function_aliases)
    self.assertEqual(loaded.function_aliases["my_func"](1.0).numpy(), 2.0)

  def test_function_aliases_with_non_saved_function(self, use_cpp_bindings):
    """Aliasing an untracked function exposes each of its traces."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    # `f` below will be aliased but not saved because it is not tracked
    f = def_function.function(lambda x: 2 * x)
    root = autotrackable.AutoTrackable()
    root.g = def_function.function(lambda x: 2 * f(x))
    # Create two traces
    root.g(constant_op.constant(1))
    root.g(constant_op.constant(1.0, dtype=dtypes.float32))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(
        function_aliases={
            "my_func": f,
        }
    )
    save.save(root, save_dir, options=options)
    loaded = test_load(
        save_dir,
        use_cpp_bindings=use_cpp_bindings,
        options=load_options.LoadOptions(
            experimental_load_function_aliases=True
        ),
    )
    self.assertLen(loaded.function_aliases, 1)
    self.assertIn("my_func", loaded.function_aliases)
    self.assertLen(loaded.function_aliases["my_func"], 2)
    self.assertIsInstance(
        loaded.function_aliases["my_func"][0], types_core.ConcreteFunction
    )
    self.assertIsInstance(
        loaded.function_aliases["my_func"][1], types_core.ConcreteFunction
    )

  @unittest.skip("skip until unexpected retracing is fixed/handled b/280121368")
  def test_function_aliases_with_concrete_function(self, use_cpp_bindings):
    """Aliasing a single concrete function exposes only that trace."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    # `f` below will be aliased but not saved because it is not tracked
    f = def_function.function(lambda x: 2 * x)
    root = autotrackable.AutoTrackable()
    root.g = def_function.function(lambda x: 2 * f(x))
    # Create two traces
    root.g(constant_op.constant(1))
    root.g(constant_op.constant(1.0, dtype=dtypes.float32))
    self.assertLen(f._list_all_concrete_functions(), 2)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(
        function_aliases={
            "my_func": f.get_concrete_function(
                tensor_spec.TensorSpec([], dtypes.float32)
            ),
        }
    )
    self.assertLen(f._list_all_concrete_functions(), 2)
    save.save(root, save_dir, options=options)
    loaded = test_load(
        save_dir,
        use_cpp_bindings=use_cpp_bindings,
        options=load_options.LoadOptions(
            experimental_load_function_aliases=True
        ),
    )
    self.assertLen(loaded.function_aliases, 1)
    self.assertIn("my_func", loaded.function_aliases)
    self.assertLen(loaded.function_aliases["my_func"], 1)
    self.assertIsInstance(
        loaded.function_aliases["my_func"][0], types_core.ConcreteFunction
    )

  @unittest.skip("skip until unexpected retracing is fixed/handled b/280121368")
  def test_function_aliases_with_concrete_functions(self, use_cpp_bindings):
    """Aliasing a list of concrete functions exposes exactly those traces."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    # `f` below will be aliased but not saved because it is not tracked
    f = def_function.function(lambda x: 2 * x)
    root = autotrackable.AutoTrackable()
    root.g = def_function.function(lambda x: 2 * f(x))
    # Create 3 traces for g, which will in turn create 3 traces for f.
    root.g(x=constant_op.constant(1))
    root.g(x=constant_op.constant(1.0, dtype=dtypes.float32))
    root.g(x=constant_op.constant(1.0, dtype=dtypes.float16))
    self.assertLen(f._list_all_concrete_functions(), 3)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(
        function_aliases={
            # Alias 2 out of 3 traces of f
            "my_func": [
                f.get_concrete_function(
                    x=tensor_spec.TensorSpec([], dtypes.int32)
                ),
                f.get_concrete_function(
                    x=tensor_spec.TensorSpec([], dtypes.float32)
                ),
            ],
        }
    )
    self.assertLen(f._list_all_concrete_functions(), 3)
    save.save(root, save_dir, options=options)
    loaded = test_load(
        save_dir,
        use_cpp_bindings=use_cpp_bindings,
        options=load_options.LoadOptions(
            experimental_load_function_aliases=True
        ),
    )
    self.assertLen(loaded.function_aliases, 1)
    self.assertIn("my_func", loaded.function_aliases)
    self.assertLen(loaded.function_aliases["my_func"], 2)
    self.assertIsInstance(
        loaded.function_aliases["my_func"][0], types_core.ConcreteFunction
    )
    self.assertIsInstance(
        loaded.function_aliases["my_func"][1], types_core.ConcreteFunction
    )

  def test_function_aliases_name_collision(self, use_cpp_bindings):
    """Loading aliases fails if `function_aliases` is already an attribute."""
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.function_aliases = variables.Variable(1.0)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(function_aliases={
        "my_func": root.f,
    })
    save.save(root, save_dir, root.f, options=options)
    with self.assertRaisesRegex(
        ValueError, "Could not load with experimental_load_function_aliases"
    ):
      test_load(
          save_dir,
          use_cpp_bindings=use_cpp_bindings,
          options=load_options.LoadOptions(
              experimental_load_function_aliases=True
          ),
      )


# TODO(b/264882754) Support Cpp bindings DeferredInitModuleVariablesTest
class DeferredInitModuleVariablesTest(test.TestCase, parameterized.TestCase):
  """Tests that save with `experimental_skip_checkpoint=True`.

  Nothing is restored from a checkpoint on load, so state must come from
  somewhere else: separately exported initializer SavedModels for variables,
  or a SavedModel asset for a lookup table.
  """

  def test_deferred_init_module_variables(self):
    """Defer initialization of variables in a module to the load stage.

    The module is saved without a checkpoint, and each variable's initial
    value is exported as its own SavedModel under
    `<export_dir>/variables/<name>`. On load, a variable creator reloads that
    SavedModel and uses its `call` function as the variable's initial value.
    """

    class MyModule(module.Module):
      # Holds three float vectors of length `size`, each created a different
      # way: from a tensor, from a callable, and lazily inside `call`.

      def __init__(self, size):
        super().__init__()
        self.size = size
        # variable initialized by a Tensor-compatible value
        self.w1 = variables.Variable(
            constant_op.constant(1., shape=[self.size]), trainable=False)
        # variable initialized by a function
        self.w2 = variables.Variable(
            lambda: constant_op.constant(2., shape=[self.size]))
        # variable instantiated lazily in call()
        self.w3 = None

      def call(self):
        # Creates w3 on first use, then adds 1 to every variable and returns
        # all three.
        if self.w3 is None:
          self.w3 = variables.Variable(
              constant_op.constant(3., shape=[self.size]))
        for w in (self.w1, self.w2, self.w3):
          w.assign_add(constant_op.constant(1., shape=[self.size]))
        return self.w1, self.w2, self.w3

    def export_initializer(initial_value, export_dir):
      # Saves a SavedModel to `export_dir` whose no-argument `call` returns
      # `initial_value`, invoking it first when it is callable.

      class Initializer(module.Module):

        @def_function.function(input_signature=[])
        def call(self):
          if callable(initial_value):
            return initial_value()
          return initial_value

      save.save(Initializer(), export_dir)

    def create_and_save_module(weight_size):
      # Builds MyModule in a fresh graph, saves it (without a checkpoint) to
      # `<export_dir>/module`, and exports each recorded initial value to
      # `<export_dir>/variables/<variable name>`. Returns `export_dir`.

      initial_values = {}  # For storing initial_value of created variables

      def variable_creator(next_creator, **kwargs):
        # Delegates creation, but records the initial value keyed by the
        # variable's name with any ":..." suffix removed.
        variable = next_creator(**kwargs)
        variable_name = variable.name
        if ":" in variable_name:
          variable_name = variable_name[:variable_name.index(":")]
        initial_values[variable_name] = kwargs["initial_value"]
        return variable

      export_dir = self.create_tempdir().full_path

      with ops.Graph().as_default():
        with variable_scope.variable_creator_scope(variable_creator):
          exported = MyModule(weight_size)
          exported.call = def_function.function(input_signature=[])(
              exported.call)

          module_dir = f"{export_dir}/module"
          file_io.recursive_create_dir(module_dir)
          save.save_and_return_nodes(
              exported, module_dir, experimental_skip_checkpoint=True)

      # Save the initializer of the created variables.
      for variable_name, initial_value in initial_values.items():
        export_initializer(initial_value,
                           f"{export_dir}/variables/{variable_name}")

      return export_dir

    def load_and_run_module(export_dir, weight_size):
      # Loads the module in a fresh graph, re-creating each variable from its
      # exported initializer, then runs `call` three times and checks values.

      # pylint: disable=unused-argument
      def layer_variable_creator(next_creator, **kwargs):
        # Does not delegate to `next_creator`: builds the variable directly
        # with its initial value taken from the initializer SavedModel.
        # NOTE(review): assumes kwargs['name'] matches the name recorded by
        # `variable_creator` at save time.
        variable_dir = f"{export_dir}/variables/{kwargs['name']}"
        initializer = load.load(variable_dir)
        kwargs["initial_value"] = initializer.call
        variable = resource_variable_ops.ResourceVariable(**kwargs)
        return variable

      with ops.Graph().as_default():
        with variable_scope.variable_creator_scope(layer_variable_creator):
          imported = load.load(
              f"{export_dir}/module",
              options=load_options.LoadOptions(
                  experimental_skip_checkpoint=True))
        outputs = imported.call()

        with self.cached_session() as sess:
          variables.global_variables_initializer().run()
          # Check if variables work as expected across multiple iterations.
          # w1, w2, w3 start at 1, 2, 3 and gain 1 per run, so after run i
          # output j equals i + j + 2.
          for i in range(3):
            np_outputs = sess.run(outputs)
            for j, np_output in enumerate(np_outputs):
              self.assertAllClose(np_output, np.full(weight_size, i + j + 2))

    # The size of the serialized content (both module and variables) stays
    # small even with a large weight_size as the initial values are not stored
    # in checkpoints.
    weight_size = 1024
    export_dir = create_and_save_module(weight_size)
    load_and_run_module(export_dir, weight_size)

  def _make_asset(self, contents):
    """Writes `contents` to a new temporary file and returns its path."""
    fd, filename = tempfile.mkstemp(prefix=self.get_temp_dir())
    with os.fdopen(fd, "w") as f:
      f.write(contents)
    return filename

  @parameterized.named_parameters(*_test_params())
  def test_assets(self, use_cpp_bindings):
    """A vocab-file lookup table works after a checkpoint-free save/load.

    The vocab file is deleted and the SavedModel directory renamed before
    loading, so the lookup can only work from the copied asset.
    """
    # TODO(b/264882754) Fix DeferredInitModuleVariablesTest
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")

    class MyLookupModel(autotrackable.AutoTrackable):
      # Maps each whole line of `vocab_file` to its line number (-1 if absent).

      def __init__(self, vocab_file):
        vocab_initializer = lookup_ops.TextFileInitializer(
            vocab_file,
            key_dtype=dtypes.string,
            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
            value_dtype=dtypes.int64,
            value_index=lookup_ops.TextFileIndex.LINE_NUMBER,
        )
        self._vocab_table = lookup_ops.StaticHashTable(
            vocab_initializer, default_value=-1
        )

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec((None,), dtypes.string)]
      )
      def __call__(self, inputs):
        return self._vocab_table.lookup(inputs)

    vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"]))
    root = MyLookupModel(vocab_file)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save_and_return_nodes(
        root, save_dir, experimental_skip_checkpoint=True
    )
    file_io.delete_file(vocab_file)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = test_load(
        load_dir,
        options=load_options.LoadOptions(experimental_skip_checkpoint=True),
        use_cpp_bindings=use_cpp_bindings,
    )
    # "d" is on line 3 and "b" on line 1 (0-based) of the vocab file.
    self.assertAllEqual(imported(constant_op.constant(["d", "b"])), [3, 1])


class _TestModel(module.Module):
  """Small module whose single variable is created on the CPU.

  Every call (re)creates `table`, a `[rows, cols]` variable filled with ones
  under the "/cpu:0" device scope, and returns `table @ x` summed over
  axis 0.
  """

  def __init__(self, rows, cols):
    super().__init__()
    self.rows = rows
    self.cols = cols
    # Populated by __call__.
    self.table = None

  def __call__(self, x):
    with ops.device("/cpu:0"):
      ones = constant_op.constant(1.0, shape=[self.rows, self.cols])
      self.table = variables.Variable(ones)
      product = math_ops.matmul(self.table, x)
      summed = math_ops.reduce_sum(product, axis=0)
    return summed


@parameterized.named_parameters(*_test_params())
class SavedModelLoadMemoryTests(test.TestCase, parameterized.TestCase):
  """Checks where restored variables are placed when a model is loaded."""

  @test_util.run_gpu_only
  def test_no_oom_loading_large_tenor(self, use_cpp_bindings):
    """Variable devices are kept only when SAVE_VARIABLE_DEVICES is requested.

    Loading with the SAVE_VARIABLE_DEVICES policy must keep the table on the
    CPU, while a default load is expected to place it on the GPU.
    """
    # TODO(b/264882686): cpp bindings are not yet supported by this test.
    if use_cpp_bindings:
      self.skipTest("Not implemented for cpp.")
    if not config.get_soft_device_placement():
      self.skipTest("This test only works for soft device placement is on")

    device_policy = save_options.VariablePolicy.SAVE_VARIABLE_DEVICES
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    nrows, ncols = 32, 16

    model = _TestModel(rows=nrows, cols=ncols)
    # Call the model once so that `model.table` exists before saving.
    model(array_ops.zeros(shape=(ncols, 2), dtype=dtypes.float32))

    save.save(
        model,
        save_dir,
        options=save_options.SaveOptions(
            experimental_variable_policy=device_policy
        ),
    )

    loaded_on_cpu = test_load(
        path=save_dir,
        options=load_options.LoadOptions(
            experimental_variable_policy=device_policy
        ),
        use_cpp_bindings=use_cpp_bindings,
    )
    loaded_on_gpu = test_load(save_dir)

    self.assertIn("CPU", loaded_on_cpu.table.device)
    self.assertIn("GPU", loaded_on_gpu.table.device)


if __name__ == "__main__":
  # Entry point when this file is executed directly: hand off to the
  # TensorFlow test runner.
  test.main()