# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convert_to_constants.py."""

import os
import re

import numpy as np

from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_case
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import ref_variable
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
from tensorflow.python.ops import while_v2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import compat
from tensorflow.python.util import nest
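
# The tests below exercise three freezing entry points:
#   * convert_variables_to_constants_v2, for TF2 ConcreteFunctions;
#   * convert_var_to_const_function_in_v1, for ConcreteFunctions evaluated
#     under a Session;
#   * convert_variables_to_constants_from_session_graph, for raw GraphDefs.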


class _GraphMerger(object):
  """GraphDef merging methods for testing purposes."""

  @staticmethod
  def merge_any(x1, x2, empty_fn):
    """Merges two values using the message's CopyFrom/MergeFrom methods."""
    merged = empty_fn()
    merged.CopyFrom(x1)
    merged.MergeFrom(x2)
    return merged

  @staticmethod
  def merge_nodes(node1, node2):
    """Merges two NodeDef messages."""
    merged = _GraphMerger.merge_any(node1, node2, node_def_pb2.NodeDef)
    merged_inputs = node1.input[:]
    merged_inputs.extend([i for i in node2.input[:] if i not in merged_inputs])
    merged.input[:] = merged_inputs
    return merged

  @staticmethod
  def merge_lists(repeated1, repeated2, empty_fn, key_fn, merge_fn):
    """Merges two lists representing maps."""
    merged = {}
    xs1 = {key_fn(x): x for x in repeated1}
    xs2 = {key_fn(x): x for x in repeated2}
    for name in set().union(xs1.keys(), xs2.keys()):
      x1 = empty_fn() if name not in xs1 else xs1[name]
      x2 = empty_fn() if name not in xs2 else xs2[name]
      merged[name] = merge_fn(x1, x2)
    return sorted(merged.values(), key=key_fn)

  @staticmethod
  def merge_node_lists(repeated_nodes1, repeated_nodes2):
    """Merges two repeated node fields."""
    return _GraphMerger.merge_lists(repeated_nodes1, repeated_nodes2,
                                    node_def_pb2.NodeDef, lambda n: n.name,
                                    _GraphMerger.merge_nodes)
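
  # A tiny worked example (hypothetical nodes): merging
  #   [NodeDef(name="A", input=["x"])]
  # with
  #   [NodeDef(name="A", input=["y"]), NodeDef(name="B")]
  # yields
  #   [NodeDef(name="A", input=["x", "y"]), NodeDef(name="B")],
  # i.e. same-named nodes have their inputs unioned (first-list order wins)
  # and the result is sorted by name.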

  @staticmethod
  def merge_functions(fn1, fn2):
    """Merges two FunctionDefs."""
    merged = _GraphMerger.merge_any(fn1, fn2, function_pb2.FunctionDef)

    del merged.signature.input_arg[:]
    merged.signature.input_arg.extend(
        _GraphMerger.merge_lists(
            fn1.signature.input_arg[:], fn2.signature.input_arg[:],
            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))

    del merged.signature.output_arg[:]
    merged.signature.output_arg.extend(
        _GraphMerger.merge_lists(
            fn1.signature.output_arg[:], fn2.signature.output_arg[:],
            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))

    del merged.node_def[:]
    merged.node_def.extend(
        _GraphMerger.merge_node_lists(fn1.node_def[:], fn2.node_def[:]))

    return merged

  @staticmethod
  def merge_graphs(graph1, graph2):
    """Merges two GraphDef messages."""
    merged = graph_pb2.GraphDef()
    merged.node.extend(
        _GraphMerger.merge_node_lists(graph1.node[:], graph2.node[:]))

    merged.library.function.extend(
        _GraphMerger.merge_lists(graph1.library.function,
                                 graph2.library.function,
                                 function_pb2.FunctionDef,
                                 lambda f: f.signature.name,
                                 _GraphMerger.merge_functions))

    return merged
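
# A minimal sketch of how the merger is used for containment checks: with this
# merge semantic, a subgraph S is contained in a graph C exactly when merging
# leaves C unchanged, i.e.
#
#   merged = _GraphMerger.merge_graphs(graph, subgraph)
#   # subgraph is contained in graph iff merged == graph
#
# (see _assertGraphContains below, which asserts exactly this).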


def has_stateful_partitioned_call_op(graph_def):
  """Determines if a StatefulPartitionedCall op exists in the graph."""
  for node in graph_def.node:
    if node.op == "StatefulPartitionedCall":
      return True
  return False


def get_num_variables(graph_def):
  """Returns the number of ReadVariableOp in the graph."""
  return sum(node.op == "ReadVariableOp" for node in graph_def.node)


class VariablesToConstantsTest(test.TestCase):

  def _freezeModel(self, func):
    """Freezes the function.

    Args:
      func: Function.

    Returns:
      root: AutoTrackable object with original ConcreteFunction.
      output_func: frozen ConcreteFunction.
    """
    root = autotrackable.AutoTrackable()
    root.f = func
    input_func = root.f.get_concrete_function()

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func, lower_control_flow=False)
    return root, output_func
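
  # A minimal usage sketch of the v2 API under test (assuming a trackable
  # `root` whose `root.f` is a tf.function):
  #
  #   concrete = root.f.get_concrete_function()
  #   frozen = convert_to_constants.convert_variables_to_constants_v2(concrete)
  #   frozen(**inputs)  # same outputs, with variables folded into constants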

  def _testConvertedFunction(self, obj, func, converted_concrete_func,
                             input_data):
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(constant_graph_def))
    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))

    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = autotrackable.AutoTrackable()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

  @test_util.run_v2_only
  def testConstSavedModel(self):
    """Test a basic model with constants while saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(variable_graph_def))
    self.assertTrue(variable_graph_def.library.function)

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableModel(self):
    """Test a basic model with Variables."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testScalarModel(self):
    """Test a basic model with Variables."""
    input_data = {"x": constant_op.constant(1., shape=[])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Test a basic model with multiple tf.functions."""

    class BasicModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = BasicModel()
    input_func = root.add.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(1, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.add, output_func, input_data)

  def _singleMetaGraphSavedModel(self):
    export_graph = ops.Graph()
    with export_graph.as_default():
      start = array_ops.placeholder(
          shape=[1, 1], dtype=dtypes.float32, name="start")
      distractor = ref_variable.RefVariable(-1., name="distractor")
      v = ref_variable.RefVariable(3., name="v")
      local_variable = variable_v1.VariableV1(
          1.,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          trainable=False,
          use_resource=True)
      output = array_ops.identity(start * v * local_variable, name="output")
      with session_lib.Session() as session:
        session.run([v.initializer, distractor.initializer,
                     local_variable.initializer])
        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
        simple_save.simple_save(
            session,
            path,
            inputs={"start": start},
            outputs={"output": output},
            legacy_init_op=local_variable.initializer)
    return path

  @test_util.run_v2_only
  def testRefVariableImport(self):
    """Test a model with 1.X ReferenceVariables."""
    input_data = {"start": constant_op.constant(1., shape=[1, 1])}

    saved = self._singleMetaGraphSavedModel()
    imported = load(saved)
    fn = imported.signatures["serving_default"]

    output_func = convert_to_constants.convert_variables_to_constants_v2(fn)
    root = autotrackable.AutoTrackable()
    self._testConvertedFunction(root, fn, output_func, input_data)

  @test_util.run_v2_only
  def testIf(self):
    """Test a model with the If op."""
    input_data = {
        "x": constant_op.constant([1., 2.], shape=[1, 2]),
        "b": constant_op.constant(True)
    }

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def true_fn(x):
      return math_ops.matmul(x, weights)

    def false_fn(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
    ])
    def model(x, b):
      return cond.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessIf(self):
    """Test a model with the StatelessIf op."""
    input_data = {"b": constant_op.constant(True)}

    x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 2

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)])
    def model(b):
      return cond_v2.cond_v2(b, true_fn, false_fn)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStaticRnn(self):
    """Test a StaticRnn containing If ops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(np.random.random_sample((3, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
    ])
    def model(x):
      seq = array_ops.split(x, 3, 0)
      return rnn.static_rnn(
          cell, seq, dtype=dtypes.float32, sequence_length=[1])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testWhile(self):
    """Test a While loop."""
    input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def condition(x):
      return math_ops.reduce_sum(x) < 100

    def body(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
    ])
    def model(x):
      return while_loop.while_loop(condition, body, [x])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessWhile(self):
    """Test a StatelessWhile loop."""
    input_data = {"x": constant_op.constant(2.)}

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
    ])
    def model(x):
      return while_v2.while_loop(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    """Test a DynamicRnn containing While loops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(
                    np.random.random_sample((3, 10, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
    ])
    def model(x):
      return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement."""
    input_data = {
        "i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
        "x": constant_op.constant(
            np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
    }

    w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

    def branch0(x):
      return math_ops.matmul(x, w0)

    def branch1(x):
      return math_ops.matmul(x, w1)

    def branch2(x):
      x = array_ops.pad(x, [[0, 0], [0, 1]])
      return x + w2

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
    ])
    def model(i, x):
      return control_flow_switch_case.switch_case(
          i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)


class ConvertVariablesToConstantsV2SessionTest(test.TestCase):

  def _freezeModel(self, func):
    """Freezes the function.

    Args:
      func: Function.

    Returns:
      root: AutoTrackable object with original ConcreteFunction.
      output_func: frozen ConcreteFunction.
    """
    root = autotrackable.AutoTrackable()
    root.f = func
    input_func = root.f.get_concrete_function()

    output_func = convert_to_constants.convert_var_to_const_function_in_v1(
        input_func, lower_control_flow=False)
    return root, output_func
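
  # A minimal usage sketch (assumption: unlike the v2 converter, this one must
  # run under a graph/session, as testRaiseErrorInEagerMode verifies):
  #
  #   with ops.Graph().as_default(), session_lib.Session() as sess:
  #     concrete = root.f.get_concrete_function()
  #     frozen = convert_to_constants.convert_var_to_const_function_in_v1(
  #         concrete)
  #     sess.run(frozen(**inputs))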

  def _testConvertedFunction(self, sess, obj, func, converted_concrete_func,
                             input_data):
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(constant_graph_def))
    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))

    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = autotrackable.AutoTrackable()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))

  def testRaiseErrorInEagerMode(self):
    """Test the raised exception in Eager mode."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    with self.assertRaisesRegex(RuntimeError,
                                "must be carried out in a Session"):
      convert_to_constants.convert_var_to_const_function_in_v1(
          input_func)

  def testConvertVariables(self):
    """Test a basic model with Variables."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        input_func = root.f.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(2, get_num_variables(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testConvertVariablesWithAssignments(self):
    """Test a basic model with Variables and assignment ops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        input_func = root.f.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(2, get_num_variables(variable_graph_def))

        assign_op_1 = root.v1.assign(1.5)
        assign_op_2 = root.v2.assign(3.0)
        assign_op_3 = root.v1.assign(4.0)
        ops.get_default_graph().add_to_collection(
            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_1)
        ops.get_default_graph().add_to_collection(
            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_2)
        ops.get_default_graph().add_to_collection(
            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_3)

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testConstSavedModel(self):
    """Test a basic model with constants while saving/loading the SavedModel."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.f = def_function.function(lambda x: 2. * x)
        to_save = root.f.get_concrete_function(input_data["x"])

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        input_func = saved_model.signatures["serving_default"]

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(0, get_num_variables(variable_graph_def))
        self.assertTrue(variable_graph_def.library.function)

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        to_save = root.f.get_concrete_function(input_data["x"])
        sess.run(variables.global_variables_initializer())

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        input_func = saved_model.signatures["serving_default"]

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testMultiFunctionModel(self):
    """Test a basic model with multiple tf.functions."""

    class BasicModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = BasicModel()
        input_func = root.add.get_concrete_function(input_data["x"])

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(1, get_num_variables(variable_graph_def))

        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
            input_func)
        self._testConvertedFunction(sess, root, root.add, output_func,
                                    input_data)

  def testIf(self):
    """Test a model with the If op."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x": constant_op.constant([1., 2.], shape=[1, 2]),
            "b": constant_op.constant(True)
        }

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        def true_fn(x):
          return math_ops.matmul(x, weights)

        def false_fn(x):
          return math_ops.add(x, weights)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
        ])
        def model(x, b):
          return cond.cond(
              b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStatelessIf(self):
    """Test a model with the StatelessIf op."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"b": constant_op.constant(True)}

        x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

        def true_fn():
          return x

        def false_fn():
          return x + 2

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
        ])
        def model(b):
          return cond_v2.cond_v2(b, true_fn, false_fn)

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStaticRnn(self):
    """Test a StaticRnn containing If ops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x":
                constant_op.constant(
                    np.array(
                        np.random.random_sample((3, 10)), dtype=np.float32))
        }

        cell = rnn_cell_impl.LSTMCell(10)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
        ])
        def model(x):
          seq = array_ops.split(x, 3, 0)
          return rnn.static_rnn(
              cell, seq, dtype=dtypes.float32, sequence_length=[1])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testWhile(self):
    """Test a While loop."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        def condition(x):
          return math_ops.reduce_sum(x) < 100

        def body(x):
          return math_ops.add(x, weights)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
        ])
        def model(x):
          return while_loop.while_loop(condition, body, [x])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testStatelessWhile(self):
    """Test a StatelessWhile loop."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(2.)}

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
        ])
        def model(x):
          return while_v2.while_loop(
              lambda v: v < 4.,
              lambda v: v * v, [x],
              return_same_structure=False,
              name="while_1")  # x**2

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  def testDynamicRnn(self):
    """Test a DynamicRnn containing While loops."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x":
                constant_op.constant(
                    np.array(
                        np.random.random_sample((3, 10, 10)), dtype=np.float32))
        }

        cell = rnn_cell_impl.LSTMCell(10)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
        ])
        def model(x):
          return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)

  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "i":
                constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
            "x":
                constant_op.constant(
                    np.asarray(
                        np.random.random_sample((10, 3)), dtype=np.float32)),
        }

        w0 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w1 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

        def branch0(x):
          return math_ops.matmul(x, w0)

        def branch1(x):
          return math_ops.matmul(x, w1)

        def branch2(x):
          x = array_ops.pad(x, [[0, 0], [0, 1]])
          return x + w2

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
            tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
        ])
        def model(i, x):
          return control_flow_switch_case.switch_case(
              i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)


class ConvertVariablesToConstantsSessionTest(test.TestCase):

  def _assertGraphContains(self, graph, subgraph):
    """Asserts that the given subgraph is contained within the given graph."""

    def normalize_uids(msg):
      """Replace auto-id function names with something consistent."""
      # These functions have non-deterministic names, the non-determinism coming
      # from having an ops.uid() suffix in their names. We're replacing these
      # with new sequential IDs starting from 0 for each prefix, which is
      # sufficient for these tests.
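      # For example, "case_cond_true_frozen_17" and "case_cond_true_frozen_42"
      # would be renumbered to "case_cond_true_frozen_0" and
      # "case_cond_true_frozen_1" respectively.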
      if isinstance(msg, graph_pb2.GraphDef):
        msg = text_format.MessageToString(msg)
      name_prefixes = ["case_cond_true.*", "case_cond_false.*"]
      name_regex = r"\b(" + "|".join(name_prefixes) + r")_([0-9]+)\b"
      names = {}
      for (name, index) in re.findall(name_regex, msg):
        names.setdefault(name, set()).add(int(index))
      for name, indices in names.items():
        for new_index, old_index in enumerate(sorted(list(indices))):
          msg = re.sub(r"\b" + name + "_" + str(old_index) + r"\b",
                       name + "_" + str(new_index), msg)
      return msg

    norm_graph = text_format.Parse(normalize_uids(graph), graph_pb2.GraphDef())
    norm_subgraph = text_format.Parse(
        normalize_uids(subgraph), graph_pb2.GraphDef())

    # Graph S is contained in C if and only if merge(C,S) == C.
    # We merge the input graph with an empty graph to normalize repeated fields:
    # assertProtoEquals is sensitive to ordering.
    norm_graph = _GraphMerger.merge_graphs(norm_graph, graph_pb2.GraphDef())
    merged_graph = _GraphMerger.merge_graphs(norm_graph, norm_subgraph)
    self.assertProtoEquals(norm_graph, merged_graph)

  def _ensure_no_variables_in_graph(self, graph_def):
    """Ensures there are no variables in the graph."""
    for node in graph_def.node:
      self.assertNotIn(
          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

  def _test_variable_to_const_conversion(self, use_resource):
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=use_resource):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        variable_scope.get_variable("unused_variable_node", initializer=1.0)
        output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
        with session_lib.Session() as sess:
          self.evaluate(variable_node.initializer)
          output = self.evaluate(output_node)
          self.assertNear(2.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()
          constant_graph_def = (
              convert_to_constants
              .convert_variables_to_constants_from_session_graph(
                  session=sess,
                  graph_def=variable_graph_def,
                  output_node_names=["output_node"]))

          self._ensure_no_variables_in_graph(constant_graph_def)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      self.assertEqual(4, len(constant_graph_def.node))
      self._ensure_no_variables_in_graph(constant_graph_def)
      with session_lib.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)

  def test_resource_variable_can_be_written_after_denylisting(self):
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        another_variable = variable_scope.get_variable(
            "unused_variable_node", initializer=2.0)
        with ops.control_dependencies(
            [variable_node.assign(another_variable + variable_node)]):
          output_node = array_ops.identity(variable_node, name="output_node")
        initializer_name = variable_node.initializer.name
        with session_lib.Session() as sess:
          self.evaluate(variable_node.initializer)
          self.evaluate(another_variable.initializer)
          output = self.evaluate(output_node)
          self.assertNear(3.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()

          # Test the variable name denylist. This should result in the
          # variable not being converted to a constant. Furthermore, the paths
          # that read from and assign to the denylisted variable should
          # continue to be valid.
          constant_graph_def_with_denylist = (
              convert_to_constants
              .convert_variables_to_constants_from_session_graph(
                  session=sess,
                  graph_def=variable_graph_def,
                  output_node_names=["output_node", initializer_name],
                  variable_names_denylist=set(["variable_node"])))

          variable_node = None
          for node in constant_graph_def_with_denylist.node:
            if node.name == "variable_node":
              variable_node = node
          self.assertIsNotNone(variable_node)
          self.assertEqual(variable_node.op, "VarHandleOp")

    # Now we make sure another_variable is a constant, but the original
    # variable is not, and that the graph can be executed and the variable
    # updated with each execution.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def_with_denylist, name="")
      with session_lib.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        self.evaluate(sess.graph.get_operation_by_name(initializer_name))
        output = self.evaluate(output_node)
        self.assertNear(3.0, output, 0.00001)
        output = self.evaluate(output_node)
        self.assertNear(5.0, output, 0.00001)

  def _inline_functions(self, graph_def, arrays):
    meta_graph = export_meta_graph(graph_def=graph_def)
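    # Grappler preserves only fetched nodes; registering `arrays` under the
    # "train_op" collection marks them as fetches so inlining keeps them.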
    fetch_collection = meta_graph_pb2.CollectionDef()
    for name in arrays:
      fetch_collection.node_list.value.append(name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with everything disabled except function
    # inlining.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)

  def _test_convert_variables_with_functions(self, inline_functions):
    """Freezes a graph with functions."""

    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      _ = math_ops.multiply(defun_node, 2.0, name="output_node")

      with session_lib.Session() as sess:
        self.evaluate(variables.variables_initializer([variable_node]))
        variable_graph_def = sess.graph.as_graph_def()

        if inline_functions:
          # Run Grappler to create the VarHandleOp --> Placeholder -->
          # ResourceVariable pattern.
          variable_graph_def = self._inline_functions(
              variable_graph_def, ["variable_node", "output_node"])

        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                session=sess,
                graph_def=variable_graph_def,
                output_node_names=["output_node"]))

    self._ensure_no_variables_in_graph(constant_graph_def)

  def testReferenceVariables(self):
    """Freezes a graph with reference variables."""
    self._test_variable_to_const_conversion(use_resource=False)

  def testResourceVariables(self):
    """Freezes a graph with resource variables."""
    self._test_variable_to_const_conversion(use_resource=True)

  def testWithFunctions(self):
    """Freezes a graph with functions."""
    self._test_convert_variables_with_functions(inline_functions=False)

  def testWithInlinedFunctions(self):
    """Freezes a graph with functions that have been inlined using Grappler."""
    self._test_convert_variables_with_functions(inline_functions=True)

  def testGraphWithSwitch(self):
    """Freezes a graph which contains a Switch with type RESOURCE_DT."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("var_x", initializer=1.0)
        y = variable_scope.get_variable("var_y", initializer=2.0)
        f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
        f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
        cond_node = control_flow_case.case([(gen_math_ops.less(x, y), f1)],
                                           default=f2)
        _ = math_ops.multiply(cond_node, 2.0, name="output_node")

        with session_lib.Session() as sess:
          sess.run(variables.global_variables_initializer())
          variable_graph_def = sess.graph.as_graph_def()

          constant_graph_def = (
              convert_to_constants
              .convert_variables_to_constants_from_session_graph(
                  session=sess,
                  graph_def=variable_graph_def,
                  output_node_names=["output_node"]))

    self._ensure_no_variables_in_graph(constant_graph_def)

  def testConvertSingleVariable(self):
    """Tests that a single variable is properly converted to a constant."""

    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        _ = variable_scope.get_variable("x", initializer=1.0)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["x/read"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/read" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")

  def testConvertSingleResourceVariable(self):
    """Tests that a resource variable is properly converted to a constant."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        _ = variable_scope.get_variable("x", initializer=1.0)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["x/Read/ReadVariableOp"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/Read/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")

  def testConvertOneVariableOfTwo(self):
    """Tests that one variable can be kept unconverted."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=1.0)
        _ = math_ops.multiply(x, y, name="out")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess,
                variable_graph_def, ["out"],
                variable_names_denylist=["y"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/read" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "y" op: "VariableV2"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "y/read" op: "Identity" input: "y"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "out" op: "Mul" input: "x/read" input: "y/read"
              attr {key: "T" value {type: DT_FLOAT}}
            }""")

  def testConvertOneResourceVariableOfTwo(self):
    """Tests that one variable can be kept unconverted."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=1.0)
        _ = math_ops.multiply(x, y, name="out")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess,
                variable_graph_def, ["out"],
                variable_names_denylist=["y"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y" op: "VarHandleOp"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "out/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "out/ReadVariableOp_1" op: "ReadVariableOp" input: "y"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "out" op: "Mul"
              input: "out/ReadVariableOp" input: "out/ReadVariableOp_1"
              attr {key: "T" value {type: DT_FLOAT}}
            }""")

  def testConvertIdentityChain(self):
    """Tests that a chain of Identity ops is converted properly."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = array_ops.identity(x, name="y")
        _ = array_ops.identity(y, name="z")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["z"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "y" op: "Identity" input: "y/ReadVariableOp"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "z" op: "Identity" input: "y"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")

  def testConvertCase(self):
    """Tests that a v1 case() construction converts properly."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        control_flow_v2_toggles.disable_control_flow_v2()
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=2.0)
        _ = control_flow_case.case([(gen_math_ops.less(x, y), lambda: x)],
                                   default=lambda: y)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["case/cond/Merge"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
            }
            node {name: "x/read" op: "Identity" input: "x"}
            node {name: "y/read" op: "Identity" input: "y"}
            node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
            node {name: "case/cond/pred_id" op: "Identity" input: "Less"}
            node {
              name: "case/cond/Switch_1" op: "Switch"
              input: "case/cond/pred_id" input: "x/read"
            }
            node {
              name: "case/cond/Switch_2" op: "Switch"
              input: "case/cond/pred_id" input: "y/read"
            }
            node {
              name: "case/cond/Merge" op: "Merge"
              input: "case/cond/Switch_2" input: "case/cond/Switch_1:1"
              attr {key: "T" value {type: DT_FLOAT}}
            }""")

  def testConvertV2Case(self):
    """Tests that a v2 case() converts properly."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        control_flow_v2_toggles.enable_control_flow_v2()
        a = variable_scope.get_variable("a", initializer=2.0)
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=2.0)
        _ = control_flow_case.case([(gen_math_ops.less(x, y), lambda: a)],
                                   default=lambda: y)
        control_flow_v2_toggles.disable_control_flow_v2()
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["case/cond"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
            }
            node {name: "x/read" op: "Identity" input: "x"}
            node {name: "y/read" op: "Identity" input: "y"}
            node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
            node {
              name: "case/cond" op: "StatelessIf"
              input: "Less" input: "a/read" input: "y/read"
              attr {key: "Tcond" value {type: DT_BOOL}}
              attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
              attr {key: "Tout" value {list {type: DT_FLOAT}}}
            }
            library {
              function {
                signature {
                  name: "case_cond_false_frozen_0"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "y_read_0" type: DT_FLOAT}
                  output_arg {name: "y_read" type: DT_FLOAT}
                }
              }
              function {
                signature {
                  name: "case_cond_true_frozen_0"
                  input_arg {name: "a_read_0" type: DT_FLOAT}
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  output_arg {name: "a_read" type: DT_FLOAT}
                }
              }
            }""")

  def testConvertV2ResourceCase(self):
    """Tests that a v2 case() with resource variables converts properly."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        control_flow_v2_toggles.enable_control_flow_v2()
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=2.0)
        _ = control_flow_case.case([(gen_math_ops.less(x, y), lambda: x)],
                                   default=lambda: y)
        control_flow_v2_toggles.disable_control_flow_v2()
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["case/cond"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {name: "x" op: "Const"}
            node {name: "y" op: "Const"}
            node {
              name: "case/cond" op: "If" input: "Less" input: "x" input: "y"
              attr {key: "Tcond" value {type: DT_BOOL}}
              attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
              attr {key: "Tout" value {list {type: DT_FLOAT}}}
            }
            library {
              function {
                signature {
                  name: "case_cond_false_frozen_0"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "readvariableop_y" type: DT_FLOAT}
                  output_arg {name: "readvariableop" type: DT_FLOAT}
                }
              }
              function {
                signature {
                  name: "case_cond_true_frozen_0"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "readvariableop_x" type: DT_FLOAT}
                  output_arg {name: "readvariableop" type: DT_FLOAT}
                }
              }
            }""")

  def testConvertV2UnconvertedResourceNestedCase(self):
    """Tests unconverted variable propagation through nested functions."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        control_flow_v2_toggles.enable_control_flow_v2()
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=2.0)
        z = variable_scope.get_variable("z", initializer=3.0)
        # pylint: disable=g-long-lambda
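        # The default branch builds a second case(), so the converter must
        # recurse into the cond function nested inside the outer else branch.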
        _ = control_flow_case.case(
            [(gen_math_ops.less(x, y), lambda: x)],
            default=lambda: control_flow_case.case(
                [(gen_math_ops.less(z, y), lambda: z)], default=lambda: y))
        # pylint: enable=g-long-lambda
        control_flow_v2_toggles.disable_control_flow_v2()
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
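        # "y" is denylisted, so it must stay a resource variable and be
        # threaded through both levels of the nested If as a DT_RESOURCE
        # input, while "x" and "z" are folded to constants.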
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess,
                variable_graph_def, ["case/cond"],
                variable_names_denylist=["y"]))
        self._assertGraphContains(
            constant_graph_def, """
            node {name: "x" op: "Const"}
            node {name: "y" op: "VarHandleOp"}
            node {name: "z" op: "Const"}

            node {name: "Less/ReadVariableOp" op: "Identity" input: "x"}
            node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"}

            node {
              name: "case/cond" op: "If"
              input: "x" input: "z" input: "y"
              attr {
                key: "Tin"
                value {list
                  {type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}}
              attr {
                key: "_read_only_resource_inputs"
                value {list {i: 1 i: 2 i: 3}}}
              attr {key: "then_branch"
                    value {func {name: "case_cond_true_frozen_0"}}}
              attr {key: "else_branch"
                    value {func {name: "case_cond_false_frozen_0"}}}
              attr {key: "output_shapes" value {list {shape {}}}}
            }
            library {
              function {
                signature {
                  name: "case_cond_true_frozen_0"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "placeholder_1" type: DT_RESOURCE}
                  input_arg {name: "readvariableop_x" type: DT_FLOAT}
                  output_arg {name: "readvariableop" type: DT_FLOAT}
                  is_stateful: true
                }

                node_def {name: "ReadVariableOp" op: "Identity"
                  input: "readvariableop_x"}}

              function {
                signature {
                  name: "case_cond_false_frozen_0"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE}
                  input_arg {name: "less_readvariableop_z" type: DT_FLOAT}
                  output_arg {name: "case_cond_identity" type: DT_FLOAT}
                  is_stateful: true
                }

                node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp"
                  input: "less_readvariableop_1_y"}

                node_def {name: "Less/ReadVariableOp" op: "Identity"
                  input: "less_readvariableop_z"}

                node_def {name: "case/cond" op: "If"
                  input: "less_readvariableop_z"
                  input: "less_readvariableop_1_y"
                  attr {
                    key: "Tin"
                    value {list {type: DT_FLOAT type: DT_RESOURCE}}}
                  attr {key: "then_branch"
                        value {func {name: "case_cond_true_frozen_1"}}}
                  attr {key: "else_branch"
                        value {func {name: "case_cond_false_frozen_1"}}}
                  attr {
                    key: "_read_only_resource_inputs"
                    value {list {i: 1 i: 2}}}}}

              function {
                signature {
                  name: "case_cond_false_frozen_1"
                  input_arg {name: "placeholder" type: DT_FLOAT}
                  input_arg {name: "readvariableop_y" type: DT_RESOURCE}
                  output_arg {name: "readvariableop" type: DT_FLOAT}
                  is_stateful: true
                }

                node_def {name: "ReadVariableOp" op: "ReadVariableOp"
                  input: "readvariableop_y"}}

              function {
                signature {
                  name: "case_cond_true_frozen_1"
                  input_arg {name: "placeholder" type: DT_RESOURCE}
                  input_arg {name: "readvariableop_z" type: DT_FLOAT}
                  output_arg {name: "readvariableop" type: DT_FLOAT}
                  is_stateful: true
                }

                node_def {name: "ReadVariableOp" op: "Identity"
                  input: "readvariableop_z"}}}""")

  def _addNoinlineAttributeToFunction(self, saved_model_dir, func_name):
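    """Marks matching functions in a SavedModel with the _noinline attr.

    Copies the SavedModel proto at `saved_model_dir`, sets `_noinline=True`
    on every library function whose name (minus the "__inference_" prefix)
    starts with `func_name`, and rewrites the proto on disk.
    """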
    saved_model_proto = loader_impl.parse_saved_model(saved_model_dir)
    new_saved_model = saved_model_pb2.SavedModel()
    new_saved_model.CopyFrom(saved_model_proto)
    new_meta_graph_def = new_saved_model.meta_graphs[0]
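    # Library function names carry an "__inference_" prefix; strip it before
    # matching against `func_name`.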
    prefix_len = len("__inference_")
    for func_def in new_meta_graph_def.graph_def.library.function:
      func_name_without_prefix = func_def.signature.name[prefix_len:]
      if func_name_without_prefix.startswith(func_name):
        func_def.attr["_noinline"].CopyFrom(attr_value_pb2.AttrValue(b=True))
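    # Replace the serialized saved_model.pb on disk with the modified proto.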
    old_saved_model_file = os.path.join(saved_model_dir,
                                        constants.SAVED_MODEL_FILENAME_PB)
    if os.path.exists(old_saved_model_file):
      os.remove(old_saved_model_file)
    path = os.path.join(
        compat.as_bytes(saved_model_dir),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
    file_io.write_string_to_file(
        path, new_saved_model.SerializeToString(deterministic=True))

  @test_util.run_v2_only
  def testVariableModelWithFunctionAndFunctionInliningDisabled(self):
    """Test a model with Variables and disable function inlining."""

    class BasicModel:

      def __init__(self):
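        # v1 is created lazily on the first call to add_all().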
        self.v1 = None
        self.v2 = variables.Variable(2.)

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[1], dtype=dtypes.float32)
      ])
      def add_all(self, x):
        if self.v1 is None:
          self.v1 = variables.Variable(3.)
        return x + self.v1 + self.v2

      def run(self, x):
        y = self.add_all(x)
        return y

    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    with ops.Graph().as_default():
      model = BasicModel()
      a = array_ops.placeholder(dtypes.float32, shape=[1])
      b = model.run(a)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        simple_save.simple_save(sess, save_dir, {"myinput": a}, {"myoutput": b})

    # Add _noinline to the SavedModel.
    self._addNoinlineAttributeToFunction(
        saved_model_dir=save_dir, func_name="add_all")

    saved_model = load(save_dir)
    func = saved_model.signatures["serving_default"]
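    # Conversion should fold the variables to constants even though the
    # function bodies cannot be inlined.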
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    constant_graph_def = frozen_func.graph.as_graph_def()
    self._ensure_no_variables_in_graph(constant_graph_def)


if __name__ == "__main__":
  test.main()