# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-long-lambda
"""Tests for tensorflow.ops.control_flow_ops."""

import collections
import math
import re
import sys
import time

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as eager_def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond as tf_cond
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import control_flow_case
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop as while_loop_tf
from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
# pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
# pylint: enable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import nest
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import variable_v1


def check_consumers(graph):
  """Sanity check on the consumer list of the tensors."""

  consumer_count = {}
  for op in graph.get_operations():
    for v in op.inputs:
      cnt = consumer_count.get(v, 0)
      consumer_count[v] = cnt + 1
  for k, v in consumer_count.items():
    if len(k.consumers()) != v:
      return False
  return True


def all_fetchables():
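  """Returns the names of all fetchable tensors in the default graph."""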
  tensor_names = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.outputs:
      if graph.is_fetchable(t):
        tensor_names.append(t.name)
  return tensor_names


def all_feedables():
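  """Returns the feedable tensors used as op inputs in the default graph."""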
  feedable_tensors = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.inputs:
      if graph.is_feedable(t):
        feedable_tensors.append(t)
  return feedable_tensors


def opt_cfg(do_constant_folding=True):
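  """Returns a ConfigProto enabling L1 optimizations and function inlining."""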
  return config_pb2.ConfigProto(
      allow_soft_placement=True,
      graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L1,
              do_function_inlining=True,
              do_constant_folding=do_constant_folding)))


def isum(s, maximum_iterations=None):
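  """Adds the integers 0 through 9 to `s` via a while_loop and returns it."""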
  i = constant_op.constant(0, name="i")
  c = lambda i, s: math_ops.less(i, 10)
  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
  _, r_s = while_loop_tf.while_loop(
      c, b, [i, s], maximum_iterations=maximum_iterations)
  return r_s


def enqueue_print_op(s):
  """Enqueues an op that prints a message to be captured in the test."""
  return logging_ops.print_v2("ControlFlowOpsTest: " + s)


def filter_test_messages(s):
  """Returns a list of messages printed by enqueue_print_op."""
  prefix = "ControlFlowOpsTest: "
  return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]


def tf_function_in_tf2(f):
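  """Wraps `f` in a tf.function in TF2; returns `f` unchanged in TF1."""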
  if tf2.enabled():
    return eager_def_function.function(f)
  # In TF1, do not wrap with tf.function so that the v1 control flow code
  # path is tested.
  return f


@test_util.with_eager_op_as_function
@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase, parameterized.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testRefIdentity(self):
    with self.cached_session():
      v = variable_v1.VariableV1(7)

      v = control_flow_ops._Identity(v)
      op = state_ops.assign(v, 9)
      v2 = control_flow_ops.with_dependencies([op], v)

      self.assertTrue(isinstance(v2, tensor_lib.Tensor))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  @test_util.run_v1_only("b/120545219")
  def testRefEnter(self):
    with self.cached_session():
      v = variable_v1.VariableV1(7)

      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
      nine = constant_op.constant(9)
      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
      op = state_ops.assign(enter_v, enter_nine)
      v2 = control_flow_ops.with_dependencies([op], enter_v)
      v3 = control_flow_ops.exit(v2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v3))

  @test_util.run_v1_only("b/120545219")
  def testRefSwitch(self):
    with self.cached_session():
      v = variable_v1.VariableV1(7)

      p = constant_op.constant(True)
      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
      v2 = state_ops.assign(v1[1], 9)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  def testEnterMulExit(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
      five = constant_op.constant(5)
      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
      mul_op = math_ops.multiply(enter_data, enter_five)
      exit_op = control_flow_ops.exit(mul_op)

      result = self.evaluate(exit_op)
    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_deprecated_v1
  def testEnterShapePropagation(self):
    with self.cached_session():
      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)

      # If is_constant=True, the shape information should be propagated.
      enter_v_constant = gen_control_flow_ops.enter(
          v, "frame1", is_constant=True)
      self.assertEqual(enter_v_constant.shape, [2])

      # Otherwise, the shape should be unknown.
      enter_v_non_constant = gen_control_flow_ops.enter(
          v, "frame2", is_constant=False)
      self.assertEqual(enter_v_non_constant.shape, None)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([1, 2, 3, 4, 5, 6])
      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
      data = indexed_slices.IndexedSlices(values, indices)
      pred = ops.convert_to_tensor(True)
      switch_op = control_flow_ops.switch(data, pred)
      merge_op = control_flow_ops.merge(switch_op)[0]

      val = merge_op.values
      ind = merge_op.indices
    self.assertAllEqual(np.arange(1, 7), val)
    self.assertAllEqual(np.arange(0, 12, 2), ind)

  @test_util.run_v1_only("b/120545219")
  def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeLess(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      zero = ops.convert_to_tensor(0)
      one = ops.convert_to_tensor(1)
      less_op = math_ops.less(zero, one)
      switch_op = control_flow_ops.switch(data, less_op)
      merge_op = control_flow_ops.merge(switch_op)[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.arange(1, 7), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddIdentity(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(False, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      id_op = array_ops.identity(switch_op[1])
      merge_op = control_flow_ops.merge([add_op, id_op])[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddMul(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      five = constant_op.constant(5)
      mul_op = math_ops.multiply(switch_op[1], five)
      merge_op = control_flow_ops.merge([add_op, mul_op])[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testLoop_false(self):
    with self.cached_session():
      false = ops.convert_to_tensor(False)
      n = constant_op.constant(10)

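      # Hand-wire a loop from Enter/Merge/Switch/Exit/NextIteration ops. The
      # loop condition is always False, so Switch routes the merged value
      # straight to Exit and the loop returns the initial value of n.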
      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)

      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
      switch_n = control_flow_ops.switch(merge_n, enter_false)
      exit_n = control_flow_ops.exit(switch_n[0])
      next_n = control_flow_ops.next_iteration(switch_n[0])
      merge_n.op._update_input(1, next_n)

      result = self.evaluate(exit_n)
    self.assertAllEqual(10, result)

  @test_util.run_deprecated_v1
  def testLoop_1(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

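      # Hand-wire the loop "i = 0; while i < n: i += 1", placing the Merge
      # node on the GPU (if one is available).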
      with ops.device(test.gpu_device_name()):
        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

      next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
    self.assertAllEqual(10, result)

  @test_util.run_v1_only("b/120545219")
  def testLoop_2(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

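      # Same hand-wired loop as testLoop_1, but with the NextIteration node
      # placed on the GPU instead of the Merge node.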
      with ops.device(test.gpu_device_name()):
        next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
    self.assertAllEqual(10, result)

  @test_util.run_v1_only("b/120545219")
  def testDifferentFrame(self):
    with self.cached_session():
      data = array_ops.placeholder(dtypes.float32, shape=[])
      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
      res = math_ops.add(enter_1, enter_2)
      with self.assertRaisesOpError("has inputs from different frames"):
        res.eval(feed_dict={data: 1.0})

  @test_util.run_deprecated_v1
  def testCondBool(self):
    values = constant_op.constant(10)
    fn1 = lambda: math_ops.add(values, 1)
    fn2 = lambda: math_ops.subtract(values, 1)
    with self.assertRaisesRegex(TypeError, "must not be a Python bool"):
      _ = tf_cond.cond(False, fn1, fn2)

  @test_util.run_deprecated_v1
  def testCondInt(self):
    p = array_ops.placeholder(dtypes.bool, shape=[])
    v = constant_op.constant(10)
    fn1 = lambda: math_ops.add(v, 1)
    fn2 = lambda: math_ops.subtract(v, 1)
    y = tf_cond.cond(p, fn1, fn2)
    grad = gradients_impl.gradients(y, [v])
    self.assertAllEqual([None], grad)

  def testCondOutputShape(self):
    x = constant_op.constant(1.0)
    b = tf_cond.cond(
        constant_op.constant(True), lambda: math_ops.square(x),
        lambda: math_ops.subtract(x, 1.))
    self.assertEqual(b.shape, tensor_shape.TensorShape([]))

  @test_util.run_v1_only("b/120545219")
  def testFetchable(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      tf_cond.cond(
          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if graph.is_fetchable(t.op):
            sess.run(t, feed_dict={x: 3})
          else:
            with self.assertRaisesRegex(ValueError,
                                        "has been marked as not fetchable"):
              sess.run(t, feed_dict={x: 3})

  @test_util.disable_control_flow_v2("Not relevant")
  @test_util.run_v1_only("b/120545219")
  def testFeedable(self):
    with self.cached_session() as sess:
      c = constant_op.constant(2)
      i0 = constant_op.constant(0)
      r = while_loop_tf.while_loop(
          lambda i: i < 1000, lambda i: math_ops.square(c) + i, [i0])
      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
      feedable_tensors = all_feedables()
      for t in feedable_tensors:
        sess.run(r, feed_dict={t: 3})
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if t not in feedable_tensors and t.dtype is dtypes.int32:
            with self.assertRaisesRegex(ValueError, "may not be fed"):
              sess.run(r, feed_dict={t: 3})

  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([10])
      indices = constant_op.constant([0])
      x = indexed_slices.IndexedSlices(values, indices)
      pred = math_ops.less(1, 2)
      fn1 = lambda: indexed_slices.IndexedSlices(
          math_ops.add(x.values, 1), indices)
      fn2 = lambda: indexed_slices.IndexedSlices(
          math_ops.subtract(x.values, 1), indices)
      r = tf_cond.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
    self.assertAllEqual([11], val)
    self.assertAllEqual([0], ind)

  def testCondMismatchedIndexedSlices(self):
    @eager_def_function.function
    def foo():
      values = constant_op.constant([10])
      indices = constant_op.constant([0])
      x = indexed_slices.IndexedSlices(values, indices)
      with self.assertRaisesRegex(TypeError,
                                  "Cannot reconcile tf.cond 0-th outputs"):
        tf_cond.cond(
            constant_op.constant(True), lambda: indexed_slices.IndexedSlices(
                math_ops.add(x.values, 1), indices),
            lambda: math_ops.add(x.values, 1))
    foo()

  def testCondSparseTensor(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    pred = math_ops.less(1, 2)
    fn1 = lambda: sparse_tensor.SparseTensor(
        indices + 1, x.values + 1, dense_shape=shape)
    fn2 = lambda: sparse_tensor.SparseTensor(
        indices, x.values - 1, dense_shape=shape)
    r = tf_cond.cond(pred, fn1, fn2)
    self.assertAllEqual([3.0, 5.0], r.values)
    self.assertAllEqual([[1], [4]], r.indices)
    self.assertAllEqual(r.values.get_shape(), (2,))

  def testCondRaggedTensor(self):
    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
    pred = math_ops.less(1, 2)
    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
    fn2 = lambda: rt[:2] - 2
    result = tf_cond.cond(pred, fn1, fn2)
    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)

  @test_util.run_v1_only("b/120545219")
  def testCondResource(self):

    with self.cached_session():
      rv = resource_variable_ops.ResourceVariable(True)
      self.evaluate(variables.global_variables_initializer())
      t = ops.convert_to_tensor(1.0)

      def case():
        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
        with ops.control_dependencies([assign]):
          return array_ops.identity(t)

      self.assertEqual(
          1.0, self.evaluate(tf_cond.cond(rv, case, lambda: t)))

  @test_util.run_deprecated_v1
  def testCondResourceGradShape(self):
    rv1 = resource_variable_ops.ResourceVariable([1.0, 2.0])
    rv2 = resource_variable_ops.ResourceVariable([3.0, 4.0])
    pred = constant_op.constant(True)
    result = tf_cond.cond(pred, lambda: rv1, lambda: rv2)
    grads = gradients_impl.gradients(result, [rv1, rv2])
    self.assertAllEqual(grads[0].shape.as_list(), [2])
    self.assertAllEqual(grads[1].shape.as_list(), [2])

  @test_util.run_v1_only("b/120545219")
  def testCondWithTensorArrayGrad(self):
    with self.cached_session() as sess:
      with ops.device(test.gpu_device_name()):
        pred = array_ops.placeholder(dtypes.bool, [])
        x = constant_op.constant([1.0, 2.0, 3.0])
        y = tf_cond.cond(
            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
            lambda: constant_op.constant([1.0, 1.0, 1.0]))
        g = gradients_impl.gradients(y, x)[0]

      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])

  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlicesDifferentTypes(self):
    with self.cached_session():
      values = constant_op.constant([10])
      i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32)
      i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64)
      x = indexed_slices.IndexedSlices(values, i_32)
      pred = math_ops.less(1, 2)
      fn1 = lambda: indexed_slices.IndexedSlices(
          math_ops.add(x.values, 1), i_32)
      fn2 = lambda: indexed_slices.IndexedSlices(
          math_ops.subtract(x.values, 1), i_64)
      r = tf_cond.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
    self.assertAllEqual([11], val)
    self.assertAllEqual([0], ind)
    self.assertTrue(ind.dtype == np.int64)

  @test_util.run_v1_only("b/120545219")
  def testCondColocation(self):
    with self.session():
      with ops.device("/cpu:0"):
        v = variables.Variable(7.0)

      x = constant_op.constant(10.0)
      pred = math_ops.less(1.0, 2.0)
      fn1 = lambda: math_ops.add(v, 1.0)
      fn2 = lambda: math_ops.subtract(x, 1.0)
      r = tf_cond.cond(pred, fn1, fn2)

      for op in x.graph.get_operations():
        if op.name == "cond/Add/Switch":
          self.assertDeviceEqual(op.device, "/cpu:0")

  def _testCond_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      r = tf_cond.cond(pred, fn1, fn2)

      result = self.evaluate(r)
    self.assertAllEqual(11, result)

  def testCond_1(self):

    self._testCond_1(use_gpu=False)
    # TODO(b/116526896): Enable GPU tests.
    # self._testCond_1(use_gpu=True)

  def testCond_2(self):

    with self.cached_session():
      x = constant_op.constant(10)
      r = tf_cond.cond(
          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
          lambda: math_ops.subtract(x, 1))
      result = self.evaluate(r)
    self.assertAllEqual(9, result)

  def testCond_3(self):

    with self.cached_session():
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      fn3 = lambda: math_ops.add(tf_cond.cond(pred, fn1, fn2), 1)
      r = tf_cond.cond(pred, fn3, fn2)

      result = self.evaluate(r)
    self.assertAllEqual(12, result)

  @test_util.run_in_graph_and_eager_modes
  def testCondPruning(self):
    v1 = variables.Variable(7)
    v2 = variables.Variable(7)
    v3 = variables.Variable(7)

    def f():
      age = constant_op.constant(3)
      max_age = constant_op.constant(2)
      pred = math_ops.greater(age, max_age)
      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
      r = tf_cond.cond(pred, fn1, fn2)
      self.assertEqual(len(r), 2)
      return r[1]

    f_defun = eager_def_function.function(f)

    if not context.executing_eagerly():
      with self.cached_session():
        self.evaluate(variables.global_variables_initializer())
        result = self.evaluate(f())
        self.assertEqual(True, result)
        # Only second cond result was fetched, so v1 assign shouldn't run.
        self.assertEqual(7, self.evaluate(v1))
        self.assertEqual(2, self.evaluate(v2))
        self.assertEqual(7, self.evaluate(v3))

    result = f_defun()
    self.assertEqual(True, self.evaluate(result))
    # Both v1 and v2 branch assignments should be run in defun.
    self.assertEqual(1, self.evaluate(v1))
    self.assertEqual(2, self.evaluate(v2))
    self.assertEqual(7, self.evaluate(v3))

  def testCond_5(self):
    with self.cached_session():
      alive = constant_op.constant(True, name="alive")
      count = constant_op.constant(0, name="count")

      def body(i):
        return tf_cond.cond(
            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
            lambda: [alive, count])

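      # The Python loop unrolls ten conds. `alive` becomes False once i
      # reaches 3, so `count` stops incrementing at 4.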
      for i in range(10):
        alive, count = body(i)
      self.assertAllEqual(4, self.evaluate(count))

  @test_util.run_v1_only("b/120545219")
  def testCond_6(self):
    with self.cached_session():
      v1 = variables.Variable([7])

      age = constant_op.constant(3)
      pred = math_ops.greater(age, 4)
      fn1 = lambda: age
      fn2 = lambda: v1
      r = tf_cond.cond(pred, fn1, fn2)

      self.evaluate(variables.global_variables_initializer())
      result = self.evaluate(r)
      self.assertAllEqual(np.array([7]), result)

  def testCond_7(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
      fn2 = lambda: [y, y]
      r = tf_cond.cond(pred, fn1, fn2)
      self.assertAllEqual([11, 12], self.evaluate(r))

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  @test_util.run_v1_only("Uses tf.gradients")
  def testCondResourceGrad(self, dtype):
    init = constant_op.constant([7.], dtype=dtype)
    v1 = variables.Variable(init)

    age = constant_op.constant(3., dtype=dtype)
    pred = math_ops.greater(age, 4.)
    fn1 = lambda: age
    fn2 = lambda: v1
    r = tf_cond.cond(pred, fn1, fn2)

    grad = gradients_impl.gradients(r, v1)[0]
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(grad, [1.])

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCond_Device(self):
    x = constant_op.constant(-10.)

    # True branch function defined outside of device scope
    def true_fn():
      return math_ops.exp(x)

    with ops.device("CPU:0"):
      r = tf_cond.cond(
          constant_op.constant(True), true_fn, lambda: 0.)
      self.assertIn("cpu", r.device.lower())

    with session.Session() as sess:
      options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(r, options=options, run_metadata=run_metadata)
      # We expect that everything runs on CPU, even if GPU is available.
      self.assertEqual(len(run_metadata.partition_graphs), 1)

  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
                                             dtype):
    """Returns the number of Switch nodes of type `dtype` on `device_str`."""
    device_graphs = [
        g for g in run_metadata.partition_graphs
        if device_str in g.node[0].device
    ]
    self.assertLen(device_graphs, 1)
    switch_nodes = [
        n for n in device_graphs[0].node
        if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum
    ]
    return len(switch_nodes)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch so a Switch node is created for it.
    # We test that the Switch node gets placed on the same device as `arg`.
    # We force `arg` to be on CPU here.
    with ops.device("CPU:0"):
      arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = tf_cond.cond(constant_op.constant(True), true_fn, lambda: 0.)

    # Disable Loop_optimizer grappler pass for this test because it replaces
    # Switch with Identity when it's part of a dead branch.
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.loop_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF)

    with self.session(config=config) as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertLen(run_metadata.partition_graphs, 2)
      # Check that the Switch for `arg` gets placed on CPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
                                                      dtypes.float32), 1)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
                                                      dtypes.float32), 0)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch so a Switch node is created for it.
    # We test that the Switch node gets placed on the same device as `arg`.
    # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU
    # by placer.
    arg = dataset_ops.Dataset.range(8)

    def true_fn():
      return cardinality.cardinality(arg)

    r = tf_cond.cond(
        constant_op.constant(True), true_fn,
        lambda: constant_op.constant(0, dtypes.int64))

    # Disable Loop_optimizer grappler pass for this test because it replaces
    # Switch with Identity when it's part of a dead branch.
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.loop_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF)

    with session.Session(config=config) as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertLen(run_metadata.partition_graphs, 2)
      # Check that the Switch for `arg` gets placed on CPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
                                                      dtypes.variant), 1)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
                                                      dtypes.variant), 0)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch so a Switch node is created for it.
    # We test that the Switch node gets placed on the same device as `arg`.
    # Note: `arg` gets placed on GPU by default by the placer.
    arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = tf_cond.cond(constant_op.constant(True), true_fn, lambda: 0.)

    # Disable Loop_optimizer grappler pass for this test because it replaces
    # Switch with Identity when it's part of a dead branch.
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.loop_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF)

    with session.Session(config=config) as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertEqual(len(run_metadata.partition_graphs), 2)
      # Check that the Switch for `arg` gets placed on GPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
                                                      dtypes.float32), 0)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
                                                      dtypes.float32), 1)

  def testCondAccessTrueBranchTensorInFalseBranchRaises(self):

    @eager_def_function.function
    def f():
      c = constant_op.constant(1.)
      inputs = {"c": c}

      def true_fn(inputs):
        inputs["c"] = array_ops.identity(inputs["c"], name="true_branch")
        return inputs["c"]

      def false_fn(inputs):
        return array_ops.identity(inputs["c"])

      pred = constant_op.constant(True)
      return tf_cond.cond(
          pred, lambda: true_fn(inputs), lambda: false_fn(inputs))

    prefix = "cond/" if context.executing_eagerly() else ""

    with self.assertRaisesRegex(
        ValueError,
        "Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
        prefix):
      f()

  def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):

    @eager_def_function.function
    def f():
      c = constant_op.constant(1.)
      inputs = {"c": c}

      def br1_fn(inputs):
        inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
        return inputs["c"]

      def br4_fn(inputs):
        return array_ops.identity(inputs["c"])

      def other_fn():
        return array_ops.identity(c)

      return control_flow_switch_case.switch_case(
          constant_op.constant(2), [
              other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
              lambda: br4_fn(inputs)
          ])

    prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
    with self.assertRaisesRegex(
        ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
        "accessed from branch 4." % prefix):
      f()

  def testCondListOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
      fn2 = lambda: [y, y]
      r = tf_cond.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertListEqual([210, 210], test_result)

  def testTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
      fn2 = lambda: (y, y)
      r = tf_cond.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual((210, 210), test_result)

  def testDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"a": y, "b": y}
      r = tf_cond.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": 210, "b": 210}, test_result)

  def testEmbeddedListOutput(self):
    x = constant_op.constant(10)
    y = constant_op.constant(200)
    pred = math_ops.less(1, 2)
    fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
    fn2 = lambda: [[y, y]]
    # Pass strict=True because cond_v2 allows singleton tensors inside
    # nested output structures.
    r = tf_cond.cond(pred, fn1, fn2, strict=True)
    test_result = self.evaluate(r)
    self.assertListEqual([[210, 210]], test_result)

  def testEmbeddedTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
      fn2 = lambda: ((y, y))
      r = tf_cond.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual(((210, 210)), test_result)

  def testEmbeddedDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
                     "b": {"d": math_ops.add(x, y)}}
      fn2 = lambda: {"a": {"c": y},
                     "b": {"d": y}}
      r = tf_cond.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)

  @test_util.run_v1_only("b/120545219")
  def testCheckNestedOutputStruct(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"c": y, "d": y}
      v1_msg = "The two structures don't have the same nested structure"
      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
                "number, type, and overall structure of return values.")
      with self.assertRaisesRegex(
          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
        tf_cond.cond(pred, fn1, fn2)

  @test_util.run_v1_only("b/120545219")
  def testCondWithControl(self):
    with self.cached_session() as sess:
      control_holder = array_ops.placeholder(dtypes.float32, shape=())
      a = constant_op.constant(3)

      def true_branch():
        with ops.control_dependencies([control_holder]):
          _ = a + 1
        return a + 2

      r = tf_cond.cond(
          constant_op.constant(True), true_branch,
          lambda: constant_op.constant(1))
      result = sess.run(r, feed_dict={control_holder: 5.})
      self.assertEqual(5, result)

  @test_util.run_v1_only("b/120545219")
  def testUninitializedRefIdentity(self):
    with self.cached_session() as sess:
      v = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="v",
          container="",
          shared_name="")
      inited = state_ops.is_variable_initialized(v)
      v_f, v_t = control_flow_ops.ref_switch(v, inited)
      # Both v_f and v_t are uninitialized references. However, an actual use
      # of the reference in the 'true' branch via the 'tf.identity' op will
      # not 'fire' while v is uninitialized, so this is a valid construction.
      # This test verifies that ref_identity accepts an uninitialized ref as
      # input, allowing this construction.
      v_f_op = gen_array_ops.ref_identity(v_f)
      v_t_op = gen_array_ops.ref_identity(v_t)
      with ops.control_dependencies([v_f_op]):
        assign_v = state_ops.assign(v, [1.0])
      with ops.control_dependencies([v_t_op]):
        orig_v = array_ops.identity(v)
      merged_op = control_flow_ops.merge([assign_v, orig_v])
      self.assertAllEqual([1.0], self.evaluate(merged_op.output))

  def testCondSwitchIdentity(self):
    # Make sure the switch identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        return control_flow_assert.Assert(False, ["Wrong branch!!!"])

      r = tf_cond.cond(pred, fn1, fn2)
      self.evaluate(r)

  def testCondRecvIdentity(self):
    # Make sure the recv identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      with ops.device(test.gpu_device_name()):
        pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        with ops.device("/cpu:0"):
          return control_flow_assert.Assert(False, ["Wrong branch!!!"])

      r = tf_cond.cond(pred, fn1, fn2)
      self.evaluate(r)

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDisableLoweringSwitchMerge(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Single threaded executor doesn't support partitioned graphs.  "
          "Skipping GPU test.")
    run_opts = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata_no_lowering = config_pb2.RunMetadata()
    run_metadata_with_lowering = config_pb2.RunMetadata()

    config = opt_cfg(do_constant_folding=False)

    # Make pred feedable (via placeholder_with_default) to ensure it is not
    # constant-folded out.
    pred = array_ops.placeholder_with_default(
        constant_op.constant(True), shape=())
    r = tf_cond.cond(pred, lambda: True, lambda: False)

    with session.Session(config=config) as sess:
      r_value = sess.run(
          r, options=run_opts, run_metadata=run_metadata_with_lowering)
      self.assertEqual(r_value, True)

    # Use the single threaded executor, which disables control flow lowering.
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    with session.Session(config=config) as sess:
      r_value = sess.run(
          r, options=run_opts, run_metadata=run_metadata_no_lowering)
      self.assertEqual(r_value, True)

    self.assertTrue(  # pylint: disable=g-complex-comprehension
        any("switch" in ns.node_name
            for dev_stat in run_metadata_with_lowering.step_stats.dev_stats
            for ns in dev_stat.node_stats))

    self.assertTrue(  # pylint: disable=g-complex-comprehension
        all("switch" not in ns.node_name
            for dev_stat in run_metadata_no_lowering.step_stats.dev_stats
            for ns in dev_stat.node_stats))

  @test_util.run_v1_only("b/120545219")
  def testCondGrad_1(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)
      fn1 = lambda: array_ops.identity(x)
      fn2 = lambda: array_ops.identity(x)
      r = tf_cond.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(1.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testCondComputeGradAfterSessRunFails(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)

      def true_fn():
        a = x * x
        return a * a

      def false_fn():
        return x * x

      r = tf_cond.cond(pred, true_fn, false_fn)

      self.assertAllEqual(r, 10000.)
      grad = gradients_impl.gradients(r, [x])[0]
      with self.assertRaisesRegex(
          errors_impl.InvalidArgumentError,
          r"Connecting to invalid output 1 of source node cond which has 1 "
          r"outputs. Try using "
          "tf.compat.v1.experimental.output_all_intermediates\(True\)."):
        self.evaluate(grad)

  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testCondComputeGradAfterSessRun(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)

      def true_fn():
        a = x * x
        return a * a

      def false_fn():
        return x * x

      r = tf_cond.cond(pred, true_fn, false_fn)

      self.assertAllEqual(r, 10000.)
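      # true_fn computes (x*x)*(x*x) = x**4, so dr/dx = 4*x**3 = 4000 at
      # x = 10.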
      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(grad, 4000.)

  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testNestedCondComputeGradAfterSessRun(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)

      def true_fn():

        def inner_true_fn():
          a = x * x
          return a * a

        def inner_false_fn():
          return x * x

        return tf_cond.cond(
            constant_op.constant(True), inner_true_fn, inner_false_fn)

      def false_fn():
        return x * x

      r = tf_cond.cond(pred, true_fn, false_fn)

      self.assertAllEqual(r, 10000.)
      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(grad, 4000.)

  @test_util.run_deprecated_v1
  def testCondGrad_2(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      x = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)
      fn1 = lambda: math_ops.multiply(x, 42.0)
      fn2 = lambda: math_ops.multiply(x, 3.0)
      r = tf_cond.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))

  @test_util.disable_control_flow_v2(
      "b/110550782 (gradient w.r.t external variable)")
  @test_util.run_deprecated_v1
  def testCondGrad_3(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      ox = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)

      def fn1(x):
        m = x * x
        return gradients_impl.gradients(m, [ox])[0]

      fn2 = lambda: math_ops.multiply(ox, 3.0)
      y = math_ops.multiply(7.0, ox)
      r = tf_cond.cond(pred, lambda: fn1(y), fn2)

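      # With c < 2: r = d((7*ox)**2)/d(ox) = 98*ox = 980. Otherwise:
      # r = 3*ox = 30.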
      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))

  @test_util.run_deprecated_v1
  def testCondGradMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                    allow_soft_placement=True)
    with self.cached_session(config=config) as sess:
      pred = array_ops.placeholder(dtypes.bool, [])
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)

      with ops.device("/cpu:0"):
        z = tf_cond.cond(pred, lambda: x * y * 2.0, lambda: 2.0)

      with ops.device("/cpu:1"):
        grad = gradients_impl.gradients(z, x)[0]

      with ops.device("/cpu:0"):
        grad_grad = gradients_impl.gradients(grad, x)[0]

      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

      # v1 control flow gets None second derivative for some reason.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertIsNone(grad_grad)
        return

      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

  @test_util.run_v1_only("b/120545219")
  def testNestedCond_Simple(self):
    with self.cached_session():
      x = constant_op.constant(0., name="X")
      y = tf_cond.cond(
          constant_op.constant(True), lambda: x,
          lambda: tf_cond.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(y, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

      z = tf_cond.cond(
          constant_op.constant(False), lambda: x,
          lambda: tf_cond.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(z, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

  @test_util.run_v1_only("b/120545219")
  def testCondGrad_Gather(self):
    with self.cached_session() as sess:
      v1 = variables.Variable([1.0, 42.0])
      c = array_ops.placeholder(dtypes.int32, shape=[])
      pred = math_ops.less(c, 2)
      fn1 = lambda: array_ops.identity(v1)
      fn2 = lambda: array_ops.gather(v1, [1, 1])
      r = tf_cond.cond(pred, fn1, fn2)
      # The following `grad` is a Tensor, since it is the aggregation of an
      # IndexedSlices and a Tensor. With control flow v2 it is an
      # `IndexedSlices`.
      grad = gradients_impl.gradients(r, [v1])[0]
      self.evaluate(variables.global_variables_initializer())

      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertIsInstance(grad, indexed_slices.IndexedSlices)

      grad_value = sess.run(grad, feed_dict={c: 1})
      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [1.0, 1.0])

      grad_value = sess.run(grad, feed_dict={c: 3})
      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [0.0, 2.0])

  @test_util.run_deprecated_v1
  def testCondGrad_ResourceVarSparseRead(self):
    # NOTE(skyewm): this test is interesting because the
    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x = constant_op.constant(1.0)
    r = tf_cond.cond(
        constant_op.constant(True),
        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
        lambda: constant_op.constant(np.zeros((2, 3)),
                                     dtype=dtypes.float32))
    grad = gradients_impl.gradients(r, var)[0]

    self.evaluate(variables.global_variables_initializer())
    grad_val = self.evaluate(grad)
    self.assertIsInstance(grad_val, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0.]])

  def testCondGrad_MultiGather(self):
    # NOTE(skyewm): this test is interesting because the array_ops.gather and
    # ResourceVariable.sparse_read gradient functions return IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
    x2 = constant_op.constant(2.0)

    def true_fn():
      y1 = var.sparse_read([1, 2])
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = x2 * [1., 1., 1.]
      return y1, y2, y3

    def false_fn():
      y1 = np.zeros((2, 2), dtype=np.float32)
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = array_ops.gather(x1, [2])
      return y1, y2, y3

    @eager_def_function.function
    def foo():
      r = tf_cond.cond(constant_op.constant(True), true_fn, false_fn)
      return gradients_impl.gradients(r, [var, x1, x2])

    grad = foo()
    self.evaluate(variables.global_variables_initializer())
    var_grad, x1_grad, x2_grad = self.evaluate(grad)
    self.assertIsInstance(var_grad, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0.]])
    self.assertIsInstance(x1_grad, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
                                                                 [0., 0., 0.],
                                                                 [2., 2., 2.]])
    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)

  @test_util.run_v1_only("b/120545219")
  def testCondPredicateTensor(self):
    """Regression test for lowering predicate from non-first output of an op."""

    @eager_def_function.function
    def foo():
      return constant_op.constant("foo"), constant_op.constant(True)

    r = tf_cond.cond(foo()[1], lambda: 1.0, lambda: 2.0)
    self.assertEqual(self.evaluate(r), 1.0)

  @test_util.run_v1_only("Tests Session.run() pruning logic.")
  def testCondFeedConstantPredicate(self):
    with self.cached_session() as sess:
      value = constant_op.constant(37.0)
      predicate = constant_op.constant(True)
      cond_output = tf_cond.cond(
          predicate, lambda: constant_op.constant(0.0), lambda: value)
      result = array_ops.identity(cond_output)
      self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
      self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
      self.assertEqual(0.0, sess.run(result))

  @test_util.run_v1_only("Tests Session.run() pruning logic.")
  def testCondFeedPlaceholderWithDefaultPredicate(self):
    with self.cached_session() as sess:
      value = constant_op.constant(37.0)
      predicate = array_ops.placeholder_with_default(
          constant_op.constant(True), [])
      cond_output = tf_cond.cond(
          predicate, lambda: constant_op.constant(0.0), lambda: value)
      result = array_ops.identity(cond_output)
      self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
      self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
      self.assertAllEqual(0.0, sess.run(result))

  def testCondTensorDeps(self):
    t = array_ops.identity(1.)

    @eager_def_function.function
    def f():
      with ops.control_dependencies([t]):
        return array_ops.identity(2.)

    f.get_concrete_function()

  @test_util.run_in_graph_and_eager_modes
  def testCondAutoControlDeps(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128676188 causes OOM on opensource gpu tests")

    def branch_fn():
      enqueue_print_op("A")
      enqueue_print_op("B")
      with ops.control_dependencies([enqueue_print_op("C")]):
        return constant_op.constant(10)

    def build_cond():
      return tf_cond.cond(
          constant_op.constant(True), branch_fn, lambda: 0)

    def build_nested_cond():
      return tf_cond.cond(
          constant_op.constant(True), build_cond, lambda: 0)

    # In v1 graph mode, pruning should make only "C" print.
    if not context.executing_eagerly():
      with self.cached_session():
        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_cond()), 10)
        self.assertEqual(["C"], filter_test_messages(printed.contents()))

        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_nested_cond()), 10)
        self.assertEqual(["C"], filter_test_messages(printed.contents()))

    # In defuns, all prints should execute in program order.
    # This doesn't work with legacy control flow.
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:

      @eager_def_function.function
      def cond():
        return build_cond()

      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(cond()), 10)
      self.assertEqual(["A", "B", "C"],
                       filter_test_messages(printed.contents()))

      @eager_def_function.function
      def nested_cond():
        return build_nested_cond()

      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(nested_cond()), 10)
      self.assertEqual(["A", "B", "C"],
                       filter_test_messages(printed.contents()))

    # wrap_function should prune.
    def pruned_cond():
      return build_cond()
    pruned_cond = wrap_function.wrap_function(pruned_cond, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_cond()), 10)
    self.assertEqual(["C"], filter_test_messages(printed.contents()))

    def pruned_nested_cond():
      return build_nested_cond()
    pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
    self.assertEqual(["C"], filter_test_messages(printed.contents()))

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_tfrt("b/179459136")
  def testWhileAutoControlDeps(self):
    # Legacy while_loop fails this test because it produces deprecation notices
    # in stderr.
    if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return

    def cond(i, unused_x):
      enqueue_print_op("A")
      return i < 2

    def body(i, x):
      enqueue_print_op("B")
      with ops.control_dependencies([enqueue_print_op("C")]):
        x = array_ops.identity(x)
      with ops.control_dependencies([enqueue_print_op("D")]):
        return i + 1, x

    def build_while():
      return while_loop_tf.while_loop(
          cond, body, [constant_op.constant(0),
                       constant_op.constant(0)])

    def build_nested_while():
      return tf_cond.cond(
          constant_op.constant(True), build_while, lambda: [0, 0])

    # In v1 graph mode, pruning should make only "D" print.
    if not context.executing_eagerly():
      with self.cached_session():
        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_while()[0]), 2)
        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))

        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))

    # In defuns, all prints should execute in program order.
    @eager_def_function.function
    def while_loop():
      return build_while()[0]

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(while_loop()), 2)
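    # The body runs twice (i = 0, 1), printing "B", "C", "D" each pass; cond
    # prints "A" before each body pass and once more on the exiting check.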
    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
                     filter_test_messages(printed.contents()))

    @eager_def_function.function
    def nested_while_loop():
      return build_nested_while()[0]

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(nested_while_loop()), 2)
    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
                     filter_test_messages(printed.contents()))

    # wrap_function should prune.
    def pruned_while():
      return build_while()[0]
    pruned_while = wrap_function.wrap_function(pruned_while, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_while()), 2)
    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))

    def pruned_nested_while():
      return build_nested_while()[0]
    pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_nested_while()), 2)
    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))

  # Microbenchmark: 256,000 iterations/s.
  def testWhile_1(self):
    with self.cached_session():
      n = constant_op.constant(0)
      c = lambda x: math_ops.less(x, 10000)
      b = lambda x: math_ops.add(x, 1)
      r = while_loop_tf.while_loop(c, b, [n], parallel_iterations=20)
      self.assertEqual(10000, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileExternalControlDependencies(self):
    with self.cached_session():
      v = variables.Variable(0.0)
      self.evaluate(v.initializer)
      increment = v.assign_add(1.0).read_value()

      def body_fn(i):
        with ops.control_dependencies([increment]):
          return i + 1

      result = while_loop_tf.while_loop(
          cond=lambda i: i < 2, body=body_fn, loop_vars=[1])
      self.assertAllEqual(result, 2)
      self.assertAllEqual(v.read_value(), 1.0)

  @test_util.run_v1_only("b/120545219")
  def testWhileExternalControlDependenciesNoInput(self):
    with self.cached_session():
      v = variables.Variable(0.0)
      self.evaluate(v.initializer)
      # TODO(apassos): figure out why the reading is necessary here.
      increment = v.assign_add(1.0).read_value()

      def body_fn(unused_i):
        with ops.control_dependencies([increment]):
          return constant_op.constant(5, name="five")

      result = while_loop_tf.while_loop(
          cond=lambda i: i < 5, body=body_fn, loop_vars=[0])
      self.evaluate(result)
      self.assertAllEqual(self.evaluate(v), 1.0)

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileWithRefs_1(self):
    with self.cached_session() as sess:
      x = variable_v1.VariableV1(0)._ref()  # pylint: disable=protected-access
      i = constant_op.constant(0)
      c = lambda i, x: math_ops.less(i, 100)

      self.assertEqual(x.dtype, dtypes.int32_ref)

      def b(i, x):
        self.assertEqual(x.dtype, dtypes.int32_ref)
        return (i + 1, gen_array_ops.ref_identity(x))

      r = while_loop_tf.while_loop(c, b, [i, x], parallel_iterations=5)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(r[0].dtype, dtypes.int32)
      self.assertEqual(r[1].dtype, dtypes.int32_ref)

      value_i, value_x = self.evaluate(r)

    self.assertEqual(100, value_i)
    self.assertEqual(0, value_x)

  def testWhile_2(self):
    with self.cached_session():
      s = constant_op.constant(0)
      r = isum(s)
      self.assertAllEqual(45, self.evaluate(r))

  def testWhileWithMaximumIterations(self):
    with self.cached_session():
      s = constant_op.constant([1, 2, 3, 4, 5])
      r = isum(s, maximum_iterations=3)
      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
    with self.cached_session():
      r = while_loop_tf.while_loop(
          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
      self.assertEqual(1, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testXLAGradInLoop(self):
    # We have an optimization that moves certain reduction ops. This test makes
    # sure we don't do that for XLA ops.

    # Use dynamic inputs, which trigger the creation of "BroadcastGradientArgs"
    # and "Shape" ops.
    input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
    input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
    def cond(unused_i1, unused_i2):
      return False

    def body(i1, i2):
      return math_ops.add(i1, i2), math_ops.add(i1, i2)

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    out1, _ = while_loop_tf.while_loop(
        cond, body, (input1, input2), maximum_iterations=2)
    g = gradients_impl.gradients(out1, [input1])

    for op in out1.graph.get_operations():
      # Test that the "Shape" is directly passed to BroadcastGradientArgs
      # instead of being pushed to the stack.
      if op.type == "BroadcastGradientArgs":
        self.assertEqual(op.inputs[0].op.type, "Shape")
        self.assertEqual(op.inputs[1].op.type, "Shape")
    xla_context.Exit()

  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
  @test_util.run_v1_only("b/120545219")
  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
    v = constant_op.constant(1.0)

    def training_loop_with_gradient(i):
      out = while_loop_tf.while_loop(
          lambda i_, _: i_ < 3,
          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
          maximum_iterations=i)
      g = gradients_impl.gradients(out, v)
      with ops.control_dependencies(g):
        return i + 1

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    # Create training loop, ensure we can call gradient() of
    # while_loop inside the training loop.
    loop = while_loop_tf.while_loop(
        lambda i: i < 3, training_loop_with_gradient, [0])
    xla_context.Exit()

    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.

    # Should execute without issue.
    self.assertEqual(3, self.evaluate(loop_execute))

  @test_util.run_v1_only("b/120545219")
  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
    v = constant_op.constant(1.0)

    def inner_body(i, x):
      out = while_loop_tf.while_loop(
          lambda i, _: i < 3,
          lambda i, j: [i + 1, j * v], [0, x],
          maximum_iterations=i)
      return out

    def create_while_loop(maximum_iterations=None):
      return while_loop_tf.while_loop(
          lambda i, _: i < 3,
          inner_body, [0, 1.0],
          maximum_iterations=maximum_iterations)

    loop_no_xla = create_while_loop(maximum_iterations=5)
    # maximum_iterations is fine outside of an XLA scope
    gs = gradients_impl.gradients(loop_no_xla, v)
    self.evaluate(gs)  # This should execute without error.

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    loop_no_maxiter = create_while_loop()
    loop_with_maxiter = create_while_loop(maximum_iterations=2)
    xla_context.Exit()

    with self.assertRaisesRegex(
        ValueError,
        r"Cannot create a gradient accumulator for tensor '.+' inside "
        r"XLA while_loop because maximum_iterations was not passed to "
        r"the tf.while_loop call \('.+'\)."):
      _ = gradients_impl.gradients(loop_no_maxiter, v)

    with self.assertRaisesRegex(
        ValueError,
        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
        r"'.+' must be statically known \(e.g. a constant value or known "
        r"shape dimension\), or be defined at or outside the while loop "
        r"context '.*' \(currently defined in '.*'\)"):
      _ = gradients_impl.gradients(loop_with_maxiter, v)

  @test_util.run_v1_only("b/120545219")
  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
    v = constant_op.constant(1.0)

    def create_while_loop():
      max_iter_holder = []

      def create_mi():
        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
        return 1.0

      _ = tf_cond.cond(
          constant_op.constant(True), create_mi, create_mi)

      return while_loop_tf.while_loop(
          lambda i, _: i < 3,
          lambda i, x: (i + 1, v * x), (0, 1.0),
          maximum_iterations=max_iter_holder[0])

    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      with self.assertRaisesRegex(ValueError, r"must be from the same graph.*"):
        loop = create_while_loop()
      xla_context.Exit()
    else:
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      loop = create_while_loop()
      xla_context.Exit()
      with self.assertRaisesRegex(
          ValueError,
          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
          r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
          r"while_loop context '.+' must be statically known \(e.g. a constant "
          r"value or known shape dimension\), or be defined at or outside the "
          r"while loop context '' \(currently defined in 'cond/.+'\)"):
        _ = gradients_impl.gradients(loop, v)

  @test_util.run_v1_only("b/120545219")
  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646372, b/128645947 fails in opensource build")

    v = constant_op.constant(1.0)

    p = array_ops.placeholder(dtype=dtypes.int32)

    def mid_body_builder(iterations):

      def mid_body(i, x):
        r = while_loop_tf.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, v * x), (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return mid_body

    def outer_body(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + while_loop_tf.while_loop(
          lambda *_: True,
          mid_body_builder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def create_while_loop():
      with ops.device("/cpu:0"):
        r = while_loop_tf.while_loop(
            lambda *_: True,
            outer_body, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    final_with_xla_context = create_while_loop()
    xla_context.Exit()

    final_without_xla_context = create_while_loop()

    with self.session(use_gpu=False) as sess:
      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata_without_xla_context = config_pb2.RunMetadata()
      run_metadata = config_pb2.RunMetadata()

      final_value_without_xla_context = sess.run(
          final_without_xla_context,
          feed_dict={p: [0, 0, 0]},
          options=opts,
          run_metadata=run_metadata_without_xla_context)

      final_value_with_xla_context = sess.run(
          final_with_xla_context,
          feed_dict={p: [0, 0, 0]},
          options=opts,
          run_metadata=run_metadata)

      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
        # With while_v2 on XLA, run_metadata only contains the unlowered While
        # op, so node_stats has no statistics for the pushes. As a loose check,
        # we count the pushes in the lowered (non-XLA) version instead.
        for dev in run_metadata_without_xla_context.step_stats.dev_stats:
          if "/device:CPU" in dev.device:
            node_stats = dev.node_stats
        stack_push_count = len([
            x for x in node_stats
            if re.match(r".*TensorListPushBack_?\d*", x.node_name)
        ])
      else:
        for dev in run_metadata.step_stats.dev_stats:
          if "/device:CPU" in dev.device:
            node_stats = dev.node_stats
        stack_push_op = "StackPushV2"
        stack_push_count = len(
            [x for x in node_stats if x.node_name.endswith("StackPushV2")])
      # Pushes to the stack = product of maximum_iterations values;
      # the last two "3"s come from size(p), when p == [0, 0, 0].
      self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))

      self.assertAllClose(final_value_with_xla_context,
                          final_value_without_xla_context)

  # Have more than 10 parallel iterations and hence exercise k-bound
  # most of the time.
  @test_util.run_deprecated_v1
  def testWhile_3(self):
    with self.cached_session():

      def compute(i, m, c, o):
        m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
        o = math_ops.add(o, m)
        o = math_ops.add(o, c)
        i = math_ops.add(i, 1)
        return [i, m, c, o]

      i = ops.convert_to_tensor(0)
      m = ops.convert_to_tensor(0)
      c = ops.convert_to_tensor(0)
      o = ops.convert_to_tensor(0)
      d = ops.convert_to_tensor(100)
      r = while_loop_tf.while_loop(
          lambda i, m, c, o: math_ops.less(i, d), compute, [i, m, c, o])
      result = r[3]
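    # Each iteration k adds m + c = 2 * k to o, so after 100 iterations
    # o == 2 * (1 + 2 + ... + 100) == 10100.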
    self.assertAllEqual(10100, result)

  @test_util.run_deprecated_v1
  def testWhile_4(self):
    with self.cached_session():

      def compute(i, m, c, o):
        m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
        o = math_ops.add(o, m)
        o = math_ops.add(o, c)
        i = math_ops.add(i, 1)
        return [i, m, c, o]

      i = ops.convert_to_tensor(0)
      m = ops.convert_to_tensor(0)
      c = ops.convert_to_tensor(0)
      o = ops.convert_to_tensor(0)
      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
      s = array_ops.size(x)
      r = while_loop_tf.while_loop(
          lambda i, m, c, o: math_ops.less(i, s), compute, [i, m, c, o])
      result = r[3]
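    # o accumulates 2 * x[i] per iteration, so o == 2 * (1 + ... + 6) == 42.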
    self.assertAllEqual(42, result)

  @test_util.run_v1_only("b/120545219")
  def testWhile_5(self):
    with self.cached_session():

      def compute(i, c, o):
        c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
                                    [1] + array_ops.expand_dims(i, 0))
        o = array_ops.concat([o, c], 0)
        i = math_ops.add(i, 1)
        return [i, c, o]

      i = ops.convert_to_tensor(0)
      c = ops.convert_to_tensor([0])
      o = ops.convert_to_tensor([0])
      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
      s = array_ops.size(x)
      r = while_loop_tf.while_loop(
          lambda i, c, o: math_ops.less(i, s),
          compute,
          [i, c, o],
          [
              i.get_shape(),
              tensor_shape.unknown_shape(),
              tensor_shape.unknown_shape(),
          ],
      )
      result = r[2]
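    # Each iteration appends the slice x[i:i + 1] to o, growing it from [0]
    # to [0, 1, 2, 3, 4, 5, 6].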
    self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testWhile_Device(self):

    # Body function defined outside of device scope
    def body(x):
      return math_ops.exp(x)

    with ops.device("CPU:0"):
      r = while_loop_tf.while_loop(
          lambda x: x < 10, body, [constant_op.constant(-10.0)])
      self.assertIn("cpu", r.device.lower())

    with session.Session() as sess:
      options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(r, options=options, run_metadata=run_metadata)
      # We expect that everything runs on CPU, even if GPU is available.
      self.assertEqual(len(run_metadata.partition_graphs), 1)

  @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
  @test_util.run_v1_only("b/120545219")
  def testBufferForwarding(self):
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()

    with self.cached_session() as sess:
      with ops.device("/cpu:0"):
        c = constant_op.constant(2)
        i0 = constant_op.constant(0)
        r = while_loop_tf.while_loop(
            lambda i: i < 1000, lambda i: math_ops.square(c) + i, [i0])
      r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
      self.assertEqual(1000, r_val)
      self.assertTrue(run_metadata.HasField("step_stats"))
      unique_allocs = set()
      for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
        for output in node_stat.output:
          unique_allocs.add(
              output.tensor_description.allocation_description.ptr)
      # Prior to cl/147536680, the number of unique allocations was about 1005.
      self.assertLess(len(unique_allocs), 756)

  def _testWhile_Gpu_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      n = constant_op.constant(1.0)
      c = lambda x: math_ops.less(x, 10.0)
      b = lambda x: math_ops.add(x, 1.0)
      r = while_loop_tf.while_loop(c, b, [n])
      self.assertAllClose(10.0, self.evaluate(r))

  def testWhile_Gpu_1(self):
    self._testWhile_Gpu_1(use_gpu=False)
    self._testWhile_Gpu_1(use_gpu=True)

  def _testWhile_Gpu_2(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      n = constant_op.constant(1.0)
      c = lambda x: math_ops.less(x, 10.0)

      def b(x):
        with ops.device("/cpu:0"):
          return math_ops.add(x, 1.0)

      r = while_loop_tf.while_loop(c, b, [n])
      self.assertAllClose(10.0, self.evaluate(r))

  def testWhile_Gpu_2(self):
    self._testWhile_Gpu_2(use_gpu=False)
    self._testWhile_Gpu_2(use_gpu=True)

  def testWhileShape(self):
    with self.cached_session():
      i = constant_op.constant(0)
      m = array_ops.ones([2, 2])
      c = lambda i, j: math_ops.less(i, 2)

      def _b(i, j):
        new_i = math_ops.add(i, 1)
        new_j = array_ops.tile(j, [2, 2])
        return [new_i, new_j]

      r = while_loop_tf.while_loop(
          c, _b, [i, m],
          [i.get_shape(), tensor_shape.unknown_shape()])
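      # Two iterations of tile(j, [2, 2]) grow m from [2, 2] to [8, 8].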
      r = r[1] * array_ops.ones([8, 8])
      self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))

  @test_util.disable_control_flow_v2("b/131265085")
  @test_util.run_v1_only("b/131265085")
  def testWhileBadShape(self):
    x = constant_op.constant([2.0, 4.0], name="values")
    i = constant_op.constant(0)
    c = lambda i, _: math_ops.less(i, 10)
    b = lambda i, x: [i + 1, x + 1]
    with self.assertRaisesRegex(ValueError, "is not compatible with"):
      # Shape of x is [2], but we specify a shape of [5].
      while_loop_tf.while_loop(
          c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])

  @test_util.run_in_graph_and_eager_modes
  def testWhileBadBodyReturn(self):
    x = constant_op.constant([2.0, 4.0], name="values")
    i = constant_op.constant(0)
    c = lambda i, *x: math_ops.less(i, 10)

    # body accepts N values and returns N+1 values.
    b = lambda i, *x: (i, i) + x

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      while_loop_tf.while_loop(c, b, [i, x])

  @test_util.run_deprecated_v1
  def testWhileWithNonTensorInput_Scalar(self):
    with self.cached_session():
      n = 0
      c = lambda x: x < 10000
      b = lambda x: x + 1
      r = while_loop_tf.while_loop(c, b, [n], parallel_iterations=20)
      self.assertEqual(10000, self.evaluate(r))

  def testWhileWithNonTensorInput_Vector(self):
    with self.cached_session():
      n = np.array([0])  # Note: [0] would not work here; that is a list.
      c = lambda x: x[0] < 10000
      b = lambda x: array_ops_stack.stack([x[0] + 1])
      r = while_loop_tf.while_loop(c, b, [n], parallel_iterations=20)
      self.assertEqual([10000], self.evaluate(r))

  def testWhileShapeInference(self):
    with self.cached_session():
      i = constant_op.constant(0)
      m = array_ops.ones([2, 2])
      c = lambda i, j: math_ops.less(i, 2)

      def b(i, j):
        new_i = math_ops.add(i, 1)
        new_j = array_ops.concat([j, j], 0)
        return [new_i, new_j]

      r = while_loop_tf.while_loop(
          c, b, [i, m],
          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
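      # Two concats along axis 0 double the rows: [2, 2] -> [4, 2] -> [8, 2].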
      self.assertTrue(r[1].shape.is_compatible_with([8, 2]))

  @test_util.run_v1_only("b/120545219")
  def testWhileShapeInferenceBadShape(self):
    with self.cached_session():
      i = constant_op.constant(0)
      m = array_ops.ones([2, 2])
      c = lambda i, j: math_ops.less(i, 2)
      b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
      with self.assertRaisesRegex(
          ValueError,
          r".*\(2, 2\).*\(4, 2\) after one iteration\. To allow the shape to "
          r"vary across iterations, use the `shape_invariants` argument of "
          r"tf.while_loop to specify a less-specific shape\."):
        while_loop_tf.while_loop(c, b, [i, m])

  def testWhileShapeInferenceSparseTensor(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    i = constant_op.constant(0)
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)

    def c(i, _):
      return i < 10

    def b1(i, x):  # modifies values.  (shape of components is not changed.)
      return [
          i + 1,
          sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
      ]

    def b2(i, x):  # adds new values.  (shape of components is changed.)
      return [
          i + 1,
          sparse_ops.sparse_add(
              x,
              sparse_tensor.SparseTensor(
                  indices=math_ops.cast(
                      array_ops.fill([1, 1], i), dtypes.int64),
                  values=array_ops.fill([1], 1.0),
                  dense_shape=x.dense_shape))
      ]

    def b3(i, x):  # modifies rank.  (shape of all components is changed.)
      return [
          i + 1,
          sparse_tensor.SparseTensor(
              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
              array_ops.concat([x.dense_shape, [10]], axis=0))
      ]

    def check_shapes(r, indices, values, dense_shape):
      self.assertTrue(r.indices.shape.is_compatible_with(indices))
      self.assertTrue(r.values.shape.is_compatible_with(values))
      self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))

    # Default shape invariant; b1 only modifies values.
    _, r = while_loop_tf.while_loop(c, b1, [i, x])
    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])

    # Default shape invariant; b2 adds new values
    _, r = while_loop_tf.while_loop(c, b2, [i, x])
    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])

    # Explicit shape invariant, allowing any rank; b1 only modifies values.
    _, r = while_loop_tf.while_loop(
        c, b1, [i, x],
        [i.get_shape(), tensor_shape.TensorShape([None])])
    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

    # Explicit shape invariant, allowing any rank; b3 modifies rank.
    _, r = while_loop_tf.while_loop(
        c, b3, [i, x],
        [i.get_shape(), tensor_shape.TensorShape([None])])
    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

    # Shape invariant with ndims=None.  Technically, this isn't supported
    # according to the docs, but we support it for backwards compatibility.
    _, r = while_loop_tf.while_loop(
        c, b1, [i, x],
        [i.get_shape(), tensor_shape.TensorShape(None)])
    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
    _, r = while_loop_tf.while_loop(
        c, b3, [i, x],
        [i.get_shape(), tensor_shape.TensorShape(None)])
    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

  @test_util.disable_control_flow_v2("b/131265085")
  @test_util.run_v1_only("b/131265085")
  def testWhileBadShapeSparseTensor(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    i = constant_op.constant(0)
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    c = lambda i, _: i < 10
    b1 = lambda i, x: [i+1, x]
    def b2(i, x):  # modifies rank.  (shape of all components is changed.)
      return [
          i + 1,
          sparse_tensor.SparseTensor(
              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
              array_ops.concat([x.dense_shape, [10]], axis=0))
      ]

    # Explicit shape invariant, with a specific (incompatible) rank.
    with self.assertRaisesRegex(ValueError, "is not compatible with"):
      while_loop_tf.while_loop(
          c, b1, [i, x],
          [i.get_shape(), tensor_shape.TensorShape([5])])

    # Default shape invariant, but b2 modifies rank (which is not allowed).
    with self.assertRaises(ValueError):
      while_loop_tf.while_loop(c, b2, [i, x])

  def testWhileShapeInferenceIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
      indices = constant_op.constant([0, 3], name="indices")
      shape = constant_op.constant([10, 2], name="dense_shape")
      i = constant_op.constant(0)
      x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape)

      def c(i, _):
        return i < 10

      def b(i, x):
        return [
            i + 1,
            indexed_slices.IndexedSlices(x.values * 2.0, x.indices,
                                         x.dense_shape)
        ]

      _, r = while_loop_tf.while_loop(c, b, [i, x])
      self.assertEqual(r.dense_shape.get_shape()[0], 2)
      self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))

      _, r = while_loop_tf.while_loop(
          c, b, [i, x],
          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
      self.assertEqual(r.dense_shape.get_shape()[0], 2)
      self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))

  @test_util.disable_control_flow_v2("b/131265085")
  @test_util.run_v1_only("b/131265085")
  def testWhileBadShapeIndexedSlices(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    i = constant_op.constant(0)
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    c = lambda i, _: i < 10
    b = lambda i, x: [i+1, x]

    # Explicit shape invariant, with a specific (incompatible) rank.
    with self.assertRaisesRegex(ValueError, "is not compatible with"):
      while_loop_tf.while_loop(
          c, b, [i, x],
          [i.get_shape(), tensor_shape.TensorShape([5])])

  def testWhileShapeInferenceRaggedTensor(self):
    i = constant_op.constant(0)
    x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
    c = lambda i, _: i < 10

    def b1(i, x):  # Adds new values to rows (but doesn't create new rows)
      return [
          i + 1,
          array_ops.concat([x, x], axis=1)
      ]

    def b2(i, x):  # Adds new rows.
      return [
          i + 1,
          array_ops.concat([x, x], axis=0)
      ]

    def check_shapes(r, values, splits):
      self.assertTrue(r.values.shape.is_compatible_with(values))
      self.assertTrue(r.row_splits.shape.is_compatible_with(splits))

    # Default shape invariant; b1 adds new values to rows.
    _, r = while_loop_tf.while_loop(c, b1, [i, x])
    check_shapes(r, values=[None], splits=[4])

    # Default shape invariant; b2 adds new rows (not allowed).
    if not context.executing_eagerly():
      with self.assertRaises(ValueError):
        _, r = while_loop_tf.while_loop(c, b2, [i, x])

    # Explicit shape invariant; b1 adds new values to rows.
    # (deprecated: uses a TensorShape invariant instead of a RaggedTensorSpec)
    _, r = while_loop_tf.while_loop(
        c, b1, [i, x],
        [i.get_shape(), tensor_shape.TensorShape([None, None])])
    check_shapes(r, values=[None], splits=[None])

    # Explicit shape invariant; b1 adds new values to rows.
    _, r = while_loop_tf.while_loop(c, b1, [i, x], [
        i.get_shape(),
        ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
    ])
    check_shapes(r, values=[None], splits=[None])

    # Explicit shape invariant; b2 adds new rows.
    _, r = while_loop_tf.while_loop(c, b2, [i, x], [
        i.get_shape(),
        ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
    ])
    check_shapes(r, values=[None], splits=[None])

  def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
    i = constant_op.constant(0)
    x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
                                     [[], [8, 9, 10]]])
    c = lambda i, _: i < 10
    def b(i, x):
      return [
          i + 1,
          array_ops.concat([x, x[..., i:i+1]], axis=-1)
      ]

    _, r = while_loop_tf.while_loop(c, b, [i, x])
    self.assertEqual(r.row_splits.shape.as_list(), [3])
    self.assertIn(r.values.row_splits.shape.as_list(), ([6], [None]))
    self.assertIn(r.values.values.shape.as_list(), ([49], [None]))

  def testWhileShapeInvariantTensorSpec(self):
    i = constant_op.constant(0)
    x = constant_op.constant([1])
    c = lambda i, _: i < 10
    b = lambda i, x: (i + 1, array_ops_stack.stack([x, x]))
    shape_invariants = [
        tensor_lib.TensorSpec([], dtype=dtypes.int32),
        tensor_lib.TensorSpec(None, dtype=dtypes.int32)]
    while_loop_tf.while_loop(c, b, [i, x], shape_invariants)

  # TODO(b/131265085) Remove this decorator when bug is fixed.
  @test_util.build_as_function_and_v1_graph
  def testWhileShapeInvariantWrongTypeSpecType(self):
    c = lambda i, _: i < 10
    b = lambda i, x: (i + 1, x)
    i = constant_op.constant(0)
    x = sparse_tensor.SparseTensor([[0]], [1.0], [10])
    shape_invariants = [
        tensor_lib.TensorSpec([], dtype=dtypes.int32),
        sparse_tensor.SparseTensorSpec([None])]
    while_loop_tf.while_loop(c, b, [i, x], shape_invariants)

    x2 = constant_op.constant([1])
    with self.assertRaises(TypeError):
      while_loop_tf.while_loop(c, b, [i, x2], shape_invariants)

    x3 = ragged_factory_ops.constant([[1, 2], [3]])
    with self.assertRaises(TypeError):
      while_loop_tf.while_loop(c, b, [i, x3], shape_invariants)

    i2 = constant_op.constant(0.0)
    with self.assertRaises(TypeError):
      while_loop_tf.while_loop(c, b, [i2, x], shape_invariants)

  # TODO(b/131265085) Remove this decorator when bug is fixed.
  @test_util.build_as_function_and_v1_graph
  def testWhileShapeInvariantBadType(self):
    i = constant_op.constant(0)
    x = constant_op.constant([1])
    c = lambda i, _: i < 10
    b = lambda i, x: (i + 1, x)
    with self.assertRaises((ValueError, TypeError)):
      while_loop_tf.while_loop(c, b, [i, x], ["foo", "bar"])

  def _testNestedWhile_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      n = constant_op.constant(0)

      def cpu_sum(s):
        c = lambda i, s: math_ops.less(i, 10)

        def b(i, s):
          i1 = math_ops.add(i, 1)
          with ops.device("/cpu:0"):
            s1 = math_ops.add(i, s)
          return i1, s1

        _, r_s = while_loop_tf.while_loop(c, b, [n, s])
        return r_s

      c = lambda x: math_ops.less(x, 200)
      b = lambda x: math_ops.add(x, cpu_sum(n))
      r = while_loop_tf.while_loop(c, b, [n])
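      # cpu_sum(0) == 0 + 1 + ... + 9 == 45, so x steps 0, 45, ..., 225.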
      self.assertEqual(225, self.evaluate(r))

  def testNestedWhile_1(self):
    self._testNestedWhile_1(use_gpu=False)
    self._testNestedWhile_1(use_gpu=True)

  def _testNestedWhile_2(self, use_gpu):
    # Test the cases where the edges A -> Enter and Exit -> A are partitioned.
    with self.cached_session(use_gpu=use_gpu):
      s0 = constant_op.constant(2.0)

      def inner_loop(s):
        c = lambda s: math_ops.less(s, 20.0)

        def b(s):
          s1 = math_ops.add(s, s)
          return s1

        r_s = while_loop_tf.while_loop(c, b, [s], parallel_iterations=1)
        return r_s

      outer_c = lambda x: math_ops.less(x, 3000.0)

      def outer_b(x):
        x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
        x = inner_loop(x)
        with ops.device("/cpu:0"):
          x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
        return x

      r = while_loop_tf.while_loop(
          outer_c, outer_b, [s0], parallel_iterations=1)
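      # Pass 1: 2 doubles to 32, then squares to 1024; pass 2 squares again
      # to 1048576, which ends the outer loop.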
      self.assertEqual(1048576.0, self.evaluate(r))

  def testNestedWhile_2(self):
    self._testNestedWhile_2(use_gpu=False)
    self._testNestedWhile_2(use_gpu=True)

  @test_util.run_v1_only("b/120545219")
  def testWhileWithControl_1(self):
    with self.cached_session():
      n = constant_op.constant(0)
      r = constant_op.constant(0)
      condition = lambda n_, r_: math_ops.less(n_, 10)

      def body(n_, r_):
        n_ = math_ops.add(n_, 1)
        with r_.graph.control_dependencies([r_]):
          r_ = constant_op.constant(12)
        return [n_, r_]

      res = while_loop_tf.while_loop(
          condition, body, [n, r], parallel_iterations=1)
      self.assertAllEqual(12, res[1])

  @test_util.run_deprecated_v1
  def testWhileWithControl_2(self):
    with self.cached_session():
      r = constant_op.constant(0)
      condition = lambda r_: math_ops.less(r_, 10)

      def body(r_):
        with r_.graph.control_dependencies([r_]):
          r_ = constant_op.constant(12)
        return [r_]

      res = while_loop_tf.while_loop(
          condition, body, [r], parallel_iterations=1)
      self.assertAllEqual(12, self.evaluate(res))

  @test_util.run_v1_only("b/120545219")
  def testWhileWithControl_3(self):
    with self.cached_session() as sess:
      b = array_ops.placeholder(dtypes.bool)
      c = constant_op.constant(1)
      x0 = constant_op.constant(0)
      with ops.control_dependencies([b]):
        r = while_loop_tf.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
      self.assertEqual(10, sess.run(r, {b: True}))

  @test_util.run_v1_only("b/120545219")
  def testWhileWithControl_4(self):
    with self.cached_session() as sess:
      b = array_ops.placeholder(dtypes.bool)
      c = constant_op.constant(1)
      x0 = constant_op.constant(0)
      with ops.control_dependencies([b]):
        r = while_loop_tf.while_loop(
            lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
      self.assertEqual(10, sess.run(r, {b: True}))

  @test_util.run_v1_only("b/120545219")
  def testWhileWithControl_5(self):
    with self.cached_session() as sess:
      b = array_ops.placeholder(dtypes.bool)
      c = constant_op.constant(1)
      x0 = constant_op.constant(0)

      def body(x):
        with ops.control_dependencies([b]):
          return x + c

      r = while_loop_tf.while_loop(lambda x: x < 10, body, [x0])
      self.assertEqual(10, sess.run(r, {b: True}))

  def testWhileCondWithControl(self):
    # Ensure that no control edges from an outer control dependency context are
    # added to nodes inside cond/while contexts.
    with self.cached_session() as sess:
      const_true = lambda: constant_op.constant(True)
      const_false = lambda: constant_op.constant(False)
      cond = lambda i: tf_cond.cond(i > 0, const_true, const_false)
      body = lambda i: tf_cond.cond(i > 0, lambda: i - 1, lambda: i)

      with ops.control_dependencies([control_flow_ops.no_op()]):
        loop = while_loop_tf.while_loop(cond, body, (constant_op.constant(5),))
      self.assertEqual(0, self.evaluate(loop))

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testWhileCondWithControl_1(self):
    with self.cached_session():
      v = variable_scope.get_variable(
          "v", [], initializer=init_ops.constant_initializer(2))
      i0 = constant_op.constant(0)
      with ops.control_dependencies([i0]):

        def loop_condition(i):
          return i < 4

        def loop_body(i):
          some_cond = tf_cond.cond(
              constant_op.constant(True),
              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
          with ops.control_dependencies([some_cond]):
            return i + 1

      r = while_loop_tf.while_loop(loop_condition, loop_body, (i0,))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(4, self.evaluate(r))
      self.assertAllClose(65536.0, self.evaluate(v))

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testWhileCondExitControl(self):

    with self.cached_session():
      v = variables.Variable(1)

      def false_branch():
        cond = lambda i: i < 100

        def body(i):
          x = state_ops.assign(v, i)
          return x + 1

        loop = while_loop_tf.while_loop(cond, body, [0])
        # Make sure to handle correctly control edge from Exit to a node.
        with ops.control_dependencies([loop]):
          return constant_op.constant(6.0)

      r = tf_cond.cond(
          constant_op.constant(False), lambda: constant_op.constant(1.0),
          false_branch)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(6.0, self.evaluate(r))
      self.assertEqual(99, self.evaluate(v))

  def testCondWhile_1(self):

    with self.cached_session():
      n = ops.convert_to_tensor(0, name="n")
      c = lambda x: math_ops.less(x, 10)
      b = lambda x: math_ops.add(x, 1)
      r = tf_cond.cond(
          math_ops.less(0, 1), lambda: while_loop_tf.while_loop(c, b, [n]),
          lambda: n)
      self.assertAllEqual(10, self.evaluate(r))

  def testCondWhile_2(self):

    with self.cached_session():
      n = ops.convert_to_tensor(0)
      c = lambda x: math_ops.less(x, 10)
      b = lambda x: math_ops.add(x, 1)
      r = tf_cond.cond(
          math_ops.less(1, 0), lambda: math_ops.add(n, 1),
          lambda: while_loop_tf.while_loop(c, b, [n]))
      self.assertAllEqual(10, self.evaluate(r))

  def _testCondWhile_3(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu) as sess:
      p = array_ops.placeholder(dtypes.bool)
      n = constant_op.constant(0.0)

      def c(x):
        return math_ops.less(x, 10.0)

      def b(x):
        with ops.device("/cpu:0"):
          x1 = math_ops.add(x, 1.0)
        return x1

      r = tf_cond.cond(
          p,
          lambda: while_loop_tf.while_loop(c, b, [n]),
          lambda: math_ops.multiply(n, 2.0),
      )
      r1 = gradients_impl.gradients(r, [n])
      self.assertEqual(10., sess.run(r, {p: True}))
      self.assertEqual([1.0], sess.run(r1, {p: True}))
      self.assertEqual(0.0, sess.run(r, {p: False}))
      self.assertEqual([2.0], sess.run(r1, {p: False}))

  @test_util.run_deprecated_v1
  def testCondWhile_3(self):
    self._testCondWhile_3(use_gpu=False)
    self._testCondWhile_3(use_gpu=True)

  def testWhileCond_1(self):

    with self.cached_session():
      i = ops.convert_to_tensor(0, name="i")
      n = ops.convert_to_tensor(10, name="n")
      one = ops.convert_to_tensor(1, name="one")
      c = lambda x: math_ops.less(x, n)
      # pylint: disable=undefined-variable
      # for OSS build
      b = lambda x: tf_cond.cond(
          constant_op.constant(True),
          lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
      # pylint: enable=undefined-variable
      r = while_loop_tf.while_loop(c, b, [i])
      self.assertAllEqual(10, self.evaluate(r))

  def testWhileCond_2(self):

    with self.cached_session():
      n = ops.convert_to_tensor(0, name="n")
      c = lambda x: math_ops.less(x, 10)
      b = lambda x: tf_cond.cond(
          constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
      r = while_loop_tf.while_loop(c, b, [n])
      self.assertAllEqual(10, self.evaluate(r))

  def testWhileCond_3(self):

    with self.cached_session():
      n = ops.convert_to_tensor(0)
      c = lambda x: math_ops.less(x, 10)
      # pylint: disable=undefined-variable
      # for OSS build
      b = lambda x: tf_cond.cond(math_ops.less(0, 1),
                                 lambda: math_ops.add(x, 1),
                                 lambda: math_ops.subtract(x, 1))
      # pylint: enable=undefined-variable
      r = while_loop_tf.while_loop(c, b, [n])
      self.assertAllEqual(10, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testWhileCondGradMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                    allow_soft_placement=True)
    with self.cached_session(config=config) as sess:
      pred = array_ops.placeholder(dtypes.bool, [])
      x_init = constant_op.constant(1.0)

      with ops.device("/cpu:0"):
        z = while_loop_tf.while_loop(
            lambda i, _: i < 3, lambda i, x:
            (i + 1, tf_cond.cond(pred, lambda: x * 2.0, lambda: 10.0)),
            [0, x_init])

      with ops.device("/cpu:1"):
        grad = gradients_impl.gradients(z, x_init)[0]

      with ops.device("/cpu:0"):
        grad_grad = gradients_impl.gradients(grad, x_init)[0]

      self.assertEqual(sess.run(grad, {pred: True}), 8.0)
      self.assertEqual(sess.run(grad, {pred: False}), 0.0)

      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        return

      self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
      self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)

  # NOTE: It is ok to have parallel_iterations > 1
  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_deprecated_v1
  def testWhileUpdateVariable_1(self):
    with self.cached_session():
      select = variables.Variable([3.0, 4.0, 5.0])
      n = constant_op.constant(0)

      def loop_iterator(j):
        return math_ops.less(j, 3)

      def loop_body(j):
        ns = state_ops.scatter_update(select, j, 10.0)
        nj = math_ops.add(j, 1)
        op = control_flow_ops.group(ns)
        nj = control_flow_ops.with_dependencies([op], nj)
        return [nj]

      r = while_loop_tf.while_loop(
          loop_iterator, loop_body, [n], parallel_iterations=1)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(3, self.evaluate(r))
      result = self.evaluate(select)
      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileUpdateVariable_2(self):
    with self.cached_session():
      select1 = variables.Variable([3.0, 4.0, 5.0])
      select2 = variables.Variable([3.0, 4.0, 5.0])
      n = constant_op.constant(0)

      def loop_iterator(j):
        return math_ops.less(j, 3)

      def loop_body(j):
        ns1 = state_ops.scatter_update(select1, j, 10.0)
        ns2 = state_ops.scatter_update(select2, j, 10.0)
        nj = math_ops.add(j, 1)
        op = control_flow_ops.group(ns1, ns2)
        nj = control_flow_ops.with_dependencies([op], nj)
        return [nj]

      r = while_loop_tf.while_loop(
          loop_iterator, loop_body, [n], parallel_iterations=1)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(3, self.evaluate(r))
      result1 = self.evaluate(select1)
      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
      result2 = self.evaluate(select2)
      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileUpdateVariable_3(self):
    with self.cached_session():
      select = variables.Variable([3.0, 4.0, 5.0])
      n = constant_op.constant(0)

      def loop_iterator(j, _):
        return math_ops.less(j, 3)

      def loop_body(j, _):
        ns = state_ops.scatter_update(select, j, 10.0)
        nj = math_ops.add(j, 1)
        return [nj, ns]

      r = while_loop_tf.while_loop(
          loop_iterator,
          loop_body, [n, array_ops.identity(select)],
          parallel_iterations=1)
      self.evaluate(variables.global_variables_initializer())
      result = r[1]
    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileUpdateVariable_4(self):
    with self.cached_session():
      var_a = variables.Variable(0, name="a")
      var_b = variables.Variable(0, name="b")
      self.evaluate(variables.global_variables_initializer())

      c = constant_op.constant(0, name="c")
      asn1 = state_ops.assign_add(var_a, 1, name="a_add")

      # Loop condition
      def pred(i):
        return math_ops.less(i, 10)

      # Loop body
      def loop_body(i):
        asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
        with ops.control_dependencies([asn2]):
          ni = math_ops.add(i, 1, name="i_add")
        return ni

      lpa = while_loop_tf.while_loop(
          pred, loop_body, [c], parallel_iterations=1)

      self.assertEqual(0, self.evaluate(var_b))
      self.evaluate(lpa)  # Run the loop
      self.assertEqual(10, self.evaluate(var_b))

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileUpdateVariable_5(self):
    with self.cached_session():
      # Create some variables.
      var_a = variables.Variable(0, name="a")
      var_b = variables.Variable(0, name="b")
      self.evaluate(variables.global_variables_initializer())

      # Change condition to check var_b
      def pred(_):
        return math_ops.less(var_b, 10)

      # Change body to increment var_b
      def loop_body(i):
        asn1 = state_ops.assign_add(
            var_a, constant_op.constant(1), name="a_add")
        asn2 = state_ops.assign_add(
            var_b, constant_op.constant(1), name="b_add")
        with ops.control_dependencies([asn1, asn2]):
          inc_b = array_ops.identity(var_b)
        return inc_b

      lpa = while_loop_tf.while_loop(
          pred, loop_body, [var_b], parallel_iterations=1, name="loop")

      self.assertEqual(0, self.evaluate(var_b))
      self.evaluate(lpa)  # Run the loop
      self.assertEqual(10, self.evaluate(var_a))
      self.assertEqual(10, self.evaluate(var_b))

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileUpdateVariable_6(self):
    with self.cached_session():
      # Create some variables.
      var_a = variables.Variable(0, name="a")
      var_b = variables.Variable(0, name="b")
      c = constant_op.constant(0)
      self.evaluate(variables.global_variables_initializer())

      # Loop condition
      def pred(i):
        return math_ops.less(i, 10)

      # Loop body
      def loop_body(i):
        asn1 = state_ops.assign_add(var_a, 1, name="a_add")
        with ops.control_dependencies([asn1]):
          asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
        with ops.control_dependencies([asn2]):
          ni = math_ops.add(i, 1, name="i_add")
          return ni

      lpa = while_loop_tf.while_loop(
          pred, loop_body, [c], parallel_iterations=1, name="loop")

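      # var_a counts 1..10 and var_b accumulates the running count, so
      # var_b == 1 + 2 + ... + 10 == 55.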
      self.assertEqual(0, self.evaluate(var_b))
      self.evaluate(lpa)  # Run the loop
      self.assertEqual(55, self.evaluate(var_b))
      self.assertEqual(10, self.evaluate(var_a))

  @test_util.run_v1_only("b/120545219")
  def testWhileQueue_1(self):
    with self.cached_session():
      q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
      i = constant_op.constant(0)

      def c(i):
        return math_ops.less(i, 10)

      def b(i):
        ni = math_ops.add(i, 1)
        ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
        return ni

      r = while_loop_tf.while_loop(c, b, [i], parallel_iterations=1)
      self.assertEqual([10], self.evaluate(r))
      for i in range(10):
        self.assertEqual([i], self.evaluate(q.dequeue()))

  @test_util.run_v1_only("b/120545219")
  def testWhileTimeOut(self):
    run_options = config_pb2.RunOptions(timeout_in_ms=1)
    with self.cached_session() as sess:
      n = constant_op.constant(0)
      c = lambda x: True
      b = lambda x: math_ops.add(x, 1)
      r = while_loop_tf.while_loop(c, b, [n])
      with self.assertRaises(errors_impl.DeadlineExceededError):
        sess.run(r, options=run_options)

  @test_util.disable_control_flow_v2("b/117119329 (stack)")
  @test_util.run_v1_only("b/120545219")
  def testWhileStack_1(self):
    with self.cached_session():
      s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
      i = constant_op.constant(0)

      def c(i):
        return math_ops.less(i, 10)

      def b(i):
        ni = math_ops.add(i, 1)
        ni = control_flow_ops.with_dependencies(
            [gen_data_flow_ops.stack_push_v2(s, i)], ni)
        return ni

      r = while_loop_tf.while_loop(c, b, [i], parallel_iterations=1)

      x = constant_op.constant(0)

      def c1(i, _):
        return math_ops.greater(i, 0)

      def b1(i, x):
        ni = math_ops.subtract(i, 1)
        nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32)
        return [ni, nx]

      _, rx = while_loop_tf.while_loop(
          c1,
          b1, [r, x],
          [r.get_shape(), tensor_shape.unknown_shape()],
          parallel_iterations=1)
      self.assertEqual(45, self.evaluate(rx))

  def _testWhileGrad_ColocateGradients(self, colocate):
    gpu_dev_name = (test.gpu_device_name() if test.is_gpu_available()
                    else "/device:CPU:0")

    graph = ops.Graph()
    with graph.as_default():
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)

      def b(x):
        with ops.device(gpu_dev_name):
          return math_ops.square(x)

      loop = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)
      r = gradients_impl.gradients(
          loop, v, colocate_gradients_with_ops=colocate)[0]

    r_ops = graph.get_operations()
    r_devices = [(op.name, op.device) for op in r_ops]

    self.assertTrue(any("Square" in op.name for op in r_ops))

    for (name, dev) in r_devices:
      if not colocate and name.endswith("Square"):
        # Without colocation, only the forward Square op is on the gpu device.
        self.assertIn(gpu_dev_name, dev)
      elif colocate and "Square" in name:
        # With colocation, both Square and Square_grad are on the gpu device.
        self.assertIn(gpu_dev_name, dev)
      else:
        self.assertNotIn(gpu_dev_name, dev)

    with self.session(graph=graph) as sess:
      self.assertAllClose(1024.0, self.evaluate(r))

  @test_util.disable_control_flow_v2("b/116351701 (colocation)")
  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_ColocateGradients(self):
    self._testWhileGrad_ColocateGradients(colocate=False)
    self._testWhileGrad_ColocateGradients(colocate=True)

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_Square(self):
    with self.cached_session():
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = math_ops.square
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)
      r = tf_cond.cond(math_ops.less(1, 2), lambda: r, lambda: v)

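      # The loop squares v three times (2 -> 4 -> 16 -> 256), so r == v**8 and
      # dr/dv == 8 * v**7 == 1024.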
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllClose(1024.0, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_Shape(self):
    with self.cached_session():
      x = array_ops.placeholder(dtypes.float32, shape=[None])
      v = constant_op.constant([2.0], name="v")
      n = constant_op.constant(0, name="n")
      c = lambda i, v: math_ops.less(i, 5)
      b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
      r = while_loop_tf.while_loop(
          c,
          b, [n, v],
          [n.get_shape(), tensor_shape.unknown_shape()],
          parallel_iterations=1)

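      # r[1] == 2 * x**5 elementwise, so the gradient is 10 * x**4:
      # [810., 2560.] at x == [3., 4.].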
      r = gradients_impl.gradients(r[1], x)[0]
      self.assertEqual([None], r.get_shape().as_list())
      self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))

  @test_util.run_deprecated_v1
  def testWhileGrad_BaseShape(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32, [None])
      v0 = constant_op.constant([2.0, 2.0], name="v")
      c = lambda v: constant_op.constant(False)
      b = lambda v: math_ops.multiply(v, x)
      r = while_loop_tf.while_loop(c, b, [v0])
      y = math_ops.square(x)

      r = gradients_impl.gradients([r, y], x)[0]
      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))

  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testWhileGradAfterSessionRun(self):
    v0 = constant_op.constant(2.)
    r = while_loop_tf.while_loop(
        lambda _: True, lambda v: v * v, [v0], maximum_iterations=3)

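    # Three iterations of v -> v * v give r == v0**8 == 256, and
    # dr/dv0 == 8 * v0**7 == 1024.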
    self.assertAllEqual(r, 256.)
    grad = gradients_impl.gradients(r, v0)[0]
    self.assertAllClose(grad, 1024.)

  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testNestedWhileGradAfterSessionRun(self):
    v0 = constant_op.constant(2.)

    def body(v):
      inner_v0 = constant_op.constant(1.)
      return while_loop_tf.while_loop(
          lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2)

    r = while_loop_tf.while_loop(
        lambda _: True, body, [v0], maximum_iterations=3)

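    # The inner loop maps v -> v**2, so three outer steps give r == v0**8 and
    # dr/dv0 == 8 * v0**7 == 1024.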
    self.assertAllEqual(r, 256.)
    grad = gradients_impl.gradients(r, v0)[0]
    self.assertAllClose(grad, 1024.)

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_MultipleUses(self):
    with self.cached_session():
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = math_ops.square
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)
      r = math_ops.multiply(r, r)

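      # The loop yields v**8 and r == (v**8)**2 == v**16, so
      # dr/dv == 16 * v**15 == 524288.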
      r = gradients_impl.gradients(r, v)[0]
      self.assertEqual(524288.0, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_LoopAdd(self):
    with self.cached_session():
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = math_ops.square
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)
      r = math_ops.add(r, r)

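      # r == v**8 + v**8 == 2 * v**8, so dr/dv == 16 * v**7 == 2048.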
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllClose(2048.0, self.evaluate(r))

  def _testWhileGrad_Mul(self, use_gpu, p_iters):
    with self.cached_session(use_gpu=use_gpu) as sess:
      a = constant_op.constant(3.0, name="a")
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = lambda v: math_ops.multiply(v, a)
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=p_iters)

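      # r == v * a**4 == 162, so dr/da == 4 * v * a**3 == 216 and
      # dr/dv == a**4 == 81.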
      grad_a, grad_v = gradients_impl.gradients(r, [a, v])
      grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
      self.assertAllClose(216.0, grad_a_val)
      self.assertAllClose(81.0, grad_v_val)

  @test_util.run_deprecated_v1
  def testWhileGrad_Mul(self):
    self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
    self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
    self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
    self._testWhileGrad_Mul(use_gpu=True, p_iters=10)

  def testWhileGradInControlDeps(self):

    @eager_def_function.function
    def f():
      x_init = constant_op.constant(2.)
      loop_cond = lambda i, x: math_ops.less(i, 2)
      loop_body = lambda i, x: [i + 1, x**2]
      _, x = while_loop_tf.while_loop(loop_cond, loop_body, [0, x_init])
      with ops.control_dependencies([x]):
        (grad,) = gradients_impl.gradients(x, x_init)
        return grad

    self.assertAllEqual(f(), 4. * 2.**3)  # 4 * x_init ^ 3

  @test_util.run_deprecated_v1
  def testTfFunctionInV1WhileLoop(self):

    # This test specifically checks that creating a Const node inside a
    # tf.function inside a v1 while_loop works when function inlining is
    # turned on.
    config = opt_cfg()
    assert config.graph_options.optimizer_options.do_function_inlining
    with session.Session(config=config):

      @eager_def_function.function
      def loop_body(i):
        # Here we create the const.
        return i + 1.

      loop_cond = lambda i: True
      x = while_loop_tf.while_loop(
          loop_cond, loop_body, [0.], maximum_iterations=5)
      self.assertAllEqual(x, 5.)

  def _testNestedWhileCondWhileGrad(self, use_gpu):

    with self.cached_session(use_gpu=use_gpu):
      v = constant_op.constant(1.0)

      def inner_loop(s):
        z = constant_op.constant(0)
        c = lambda i, x: math_ops.less(i, 4)
        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
        return while_loop_tf.while_loop(c, b, [z, s])

      c = lambda x: math_ops.less(x, 128.0)

      def b(x):
        return tf_cond.cond(
            constant_op.constant(True),
            lambda: math_ops.square(inner_loop(x)[1]),
            lambda: math_ops.multiply(x, 2.0))

      r = while_loop_tf.while_loop(c, b, [v])
      r = gradients_impl.gradients(r, v)[0]
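      # The inner loop scales x by 2**4 and the cond squares the result, so
      # one outer iteration gives r = (16 * v)**2 = 256 * v**2 and
      # dr/dv = 512 * v = 512.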
      self.assertAllClose(512.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testNestedWhileCondWhileGrad(self):
    self._testNestedWhileCondWhileGrad(use_gpu=False)

  @test_util.run_deprecated_v1
  def testNestedWhileCondWhileGradGpu(self):
    self._testNestedWhileCondWhileGrad(use_gpu=True)

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_Variable(self):
    with self.cached_session():
      a = variables.Variable(3.0)
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = lambda v: math_ops.multiply(v, a)
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)

      r = gradients_impl.gradients(r, a)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(216.0, r[0])

  @test_util.run_deprecated_v1
  def testWhileGrad_ResourceVariable(self):
    with self.cached_session():
      a = resource_variable_ops.ResourceVariable(3.0)
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = lambda v: math_ops.multiply(v, a)
      r = while_loop_tf.while_loop(c, b, [v], parallel_iterations=1)

      g = gradients_impl.gradients(r, a)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(216.0, g[0])

  def testWhileGrad_EagerResourceVariable(self):
    with context.eager_mode():
      a = resource_variable_ops.ResourceVariable(
          np.ones([2, 2], dtype=np.float32))
      v = constant_op.constant(1.0)

      @eager_def_function.function
      def fn():
        r = while_loop_tf.while_loop(
            lambda i, _: i < 2, lambda i, x:
            (i + 1, x * math_ops.reduce_sum(a) * v), [0, 1.0])[1]
        return gradients_impl.gradients(r, [v])[0]

      self.assertEqual(self.evaluate(fn()), 32.)

  def testWhileGrad_ResourceVarInFunctionCall(self):

    @eager_def_function.function
    def foo(x, var):
      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))

    @eager_def_function.function
    def bar(var):
      r = while_loop_tf.while_loop(
          lambda i, _: i < 2, lambda i, x: (i + 1, foo(x, var)), [0, 0.0])[1]
      return gradients_impl.gradients(r, var)[0]

    var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
    self.evaluate(variables.global_variables_initializer())
    grad = self.evaluate(bar(var))
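    # Each of the two loop iterations reads var[1] and var[3] once, so the
    # sparse gradient scatters 2.0 into those two slots.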
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])

  def testWhileGrad_ResourceVarInNestedFunctionCall(self):

    @eager_def_function.function
    def foo(x, var):
      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))

    @eager_def_function.function
    def foo2(x, var):
      return foo(x, var)

    @eager_def_function.function
    def bar(var):
      r = while_loop_tf.while_loop(
          lambda i, _: i < 2, lambda i, x: (i + 1, foo2(x, var)), [0, 0.0])[1]
      return gradients_impl.gradients(r, var)[0]

    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
    self.evaluate(variables.global_variables_initializer())
    grad = self.evaluate(bar(var))
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])

  def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
    if test.is_gpu_available():
      self.skipTest("b/128635252")

    @eager_def_function.function
    def foo(x, var):
      return while_loop_tf.while_loop(
          lambda j, _: j < 3,
          lambda j, y: (j + 1,
                        y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
          [0, x])[1]

    @eager_def_function.function
    def bar(var):
      r = while_loop_tf.while_loop(
          lambda i, _: i < 2,
          lambda i, x: (i + 1, foo(x, var)),
          [0, 0.0])[1]
      return gradients_impl.gradients(r, var)[0]

    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
    self.evaluate(variables.global_variables_initializer())
    grad = self.evaluate(bar(var))
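    # var[1] and var[2] are each read 3 (inner) * 2 (outer) = 6 times, so the
    # sparse gradient scatters 6.0 into those two slots.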
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])

  def testWhileCondGrad_ResourceVarInFunctionCall(self):

    @eager_def_function.function
    def foo(x, var):
      return x + var.sparse_read([1])[0]

    def body(i, x):
      return (i + 1, tf_cond.cond(
          math_ops.equal(i % 2, 0),
          lambda: foo(x, var1),
          lambda: foo(x, var2)))

    @eager_def_function.function
    def bar(var1, var2):
      r = while_loop_tf.while_loop(lambda i, _: i < 4, body, [0, 0.0])
      return gradients_impl.gradients(r, [var1, var2])

    var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
    var2 = resource_variable_ops.ResourceVariable([4., 5.])
    self.evaluate(variables.global_variables_initializer())
    grads = self.evaluate(bar(var1, var2))
    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])

  @test_util.run_deprecated_v1
  def testWhileGrad_ResourceVarSparseRead(self):
    # NOTE(skyewm): this test is interesting because the gradient is the
    # aggregation result of IndexedSlices and Tensors.
    var = resource_variable_ops.ResourceVariable(np.ones(5),
                                                 dtype=dtypes.float32)
    r = while_loop_tf.while_loop(
        lambda i, _: i < 3, lambda i, x:
        (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
        [0, constant_op.constant(1.0)])[1]
    grad = gradients_impl.gradients(r, var)[0]

    self.evaluate(variables.global_variables_initializer())
    grad_val = self.evaluate(grad)
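    # With s = var[1] + var[3] = 2, the loop computes r = s**3, so
    # dr/dvar = 3 * s**2 = 12 at indices 1 and 3.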
    arr = gradient_checker_v2._to_numpy(grad_val)
    self.assertAllEqual(arr, [0., 12., 0., 12., 0.])

  @test_util.run_deprecated_v1
  def testWhileGrad_MultiResourceVarSparseRead(self):
    # NOTE(skyewm): this test is interesting because the gradient is the
    # aggregation result of IndexedSlices and Tensors.
    var1 = resource_variable_ops.ResourceVariable(np.ones(5),
                                                  dtype=dtypes.float32)
    var2 = resource_variable_ops.ResourceVariable(np.ones(3),
                                                  dtype=dtypes.float32)
    x1_init = constant_op.constant([0., 0.])
    x2_init = constant_op.constant(1.)
    x3_init = constant_op.constant(1.)

    def body(i, unused_x1, x2, x3):
      y1 = var1.sparse_read([1, 3])
      y2 = x2 * 2
      y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
      return i + 1, y1, y2, y3

    r = while_loop_tf.while_loop(
        lambda i, x1, x2, x3: i < 3, body,
        [0, x1_init, x2_init, x3_init])[1:]
    var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])

    self.evaluate(variables.global_variables_initializer())
    var1_grad_val = self.evaluate(var1_grad)
    var2_grad_val = self.evaluate(var2_grad)
    self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
                        [0., 1., 0., 1., 0.])
    self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
                        [3., 0., 0.])

  def testWhileGrad_Gather(self):
    # NOTE(skyewm): this test is interesting because the gather gradient
    # function returns an IndexedSlices.
    @tf_function_in_tf2
    def fn():
      x = constant_op.constant([1., 1., 1., 1., 1.])
      y = while_loop_tf.while_loop(
          lambda i, _: i < 3, lambda i, x:
          (i + 1, x + array_ops.gather(x, [0])), [0, x[:1]])[1]
      z = y * 3.0
      grad = gradients_impl.gradients(z, x)[0]
      return y, grad
    y, grad = fn()
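    # The loop doubles x[:1] three times, so y = 8 * x[0] and z = 24 * x[0],
    # which gives a gradient of 24 at index 0 and zero elsewhere.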
    self.assertEqual(self.evaluate(y), 8.)
    self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])

  def testWhileGrad_GatherNoFanOut(self):
    # NOTE(skyewm): this test is interesting because the gather gradient
    # function returns an IndexedSlices.
    @tf_function_in_tf2
    def fn():
      x = constant_op.constant([1., 1., 1., 1., 1.])
      y = while_loop_tf.while_loop(
          lambda i, _: i < 3,
          lambda i, x: (i + 1, array_ops.gather(x, [0])),
          [0, x[:1]])[1]
      z = y * 3.0
      grad = gradients_impl.gradients(z, x)[0]
      return y, grad
    y, grad = fn()
    self.assertEqual(self.evaluate(y), 1.)
    self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])

  @test_util.run_v1_only("b/120545219")
  def testWhileGradInCond(self):

    with self.cached_session():
      n = ops.convert_to_tensor(1.0, name="n")
      x = array_ops.placeholder(dtypes.float32, shape=None)
      c = lambda n: math_ops.less(n, 10.0)
      b = lambda n: math_ops.add(n, x)

      def fn1():
        r = while_loop_tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()])
        return gradients_impl.gradients(r, x)[0]

      r = tf_cond.cond(math_ops.less(1, 2), fn1, lambda: x)
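      # The loop adds x to n nine times (1 -> 10 with x = 1), so dr/dx = 9.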
      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))

  @test_util.disable_control_flow_v2("b/116340060")
  @test_util.run_v1_only("b/120545219")
  def testGradInWhileWrtInitialLoopVal(self):
    with self.cached_session():
      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
      y = x + 1

      def body(i, v):
        z = v * 2
        return i + 1, gradients_impl.gradients(z, x)[0]

      with self.assertRaisesRegex(
          ValueError,
          "Cannot compute gradient inside while loop with respect to op 'x'. "
          "We do not support taking the gradient wrt or through the initial "
          "value of a loop variable. Gradients can be computed through "
          "loop invariants or wrt the input parameters to the loop body."):
        while_loop_tf.while_loop(lambda i, x: i < 3, body, [0, y])

  @test_util.run_v1_only("b/120545219")
  def testWhileGradInWhile(self):
    with self.cached_session():
      n = ops.convert_to_tensor(1.0, name="n")
      x = array_ops.placeholder(dtypes.float32, shape=None)
      c = lambda n: math_ops.less(n, 10.0)
      b = lambda n: math_ops.add(n, x)

      def b1(n):
        r = while_loop_tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()])
        return gradients_impl.gradients(r, x)

      r = while_loop_tf.while_loop(
          lambda n: n < 6.0, b1, [n], [tensor_shape.unknown_shape()])
      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))

  @test_util.run_v1_only("b/120545219")
  def testCondGradInNestedWhiles(self):

    def outer_body(i, x):
      _, x = while_loop_tf.while_loop(lambda j, x: j < 3, inner_body, [0, 0.0])
      return i + 1, x

    def inner_body(j, x):
      y = tf_cond.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x)
      return j + 1, gradients_impl.gradients(y, x)[0]

    i, x = while_loop_tf.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])

    with self.cached_session() as sess:
      i_val, x_val = self.evaluate([i, x])
      self.assertEqual(i_val, 3)
      self.assertAllClose(x_val, 1.0)

  @test_util.run_gpu_only
  def testGpuResourceAccess(self):
    with ops.device(test.gpu_device_name()):
      var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))

    @eager_def_function.function
    def foo():
      return while_loop_tf.while_loop(
          lambda i, _: i < 3, lambda i, x:
          (i + 1,
           tf_cond.cond(
               constant_op.constant(True), lambda: x + var, lambda: x)),
          [0, 0.0])[1]

    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(self.evaluate(foo()), 9.0)

  def testNestedResourceAccess(self):
    var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))

    @eager_def_function.function
    def test_fn():
      x = constant_op.constant(0.0)
      r = while_loop_tf.while_loop(
          # Outer loop condition
          lambda i, y: i < 2,
          # Outer loop body
          lambda i, y: (
              i + 1,
              y + tf_cond.cond(
                  constant_op.constant(True),
                  # True branch
                  lambda: while_loop_tf.while_loop(
                      # Inner loop condition
                      lambda j, z: j < 3,
                      # Inner loop body
                      lambda j, z: (j + 1, z + math_ops.square(var)),
                      # Inner initial loop value
                      [0, y])[1],
                  # False branch
                  lambda: (0.0))),
          # Outer initial loop value
          [0, x])[1]

      grad = gradients_impl.gradients(r, x)[0]
      return r, grad

    self.evaluate(variables.global_variables_initializer())
    r, grad = self.evaluate(test_fn())
    # outer_loop(0) = 4 * 0 + 81 = 81 (see the gradient derivation below).
    self.assertEqual(r, 81.0)
    # v1 control flow gets the wrong answer!!!
    # Gradient computation:
    #   f(x) = x + 3^2
    #   inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27
    #   g(x) = x + inner_loop(x) = 2x + 27
    #   outer_loop(x) = g(g(x)) = 4x + 81
    #   outer_loop'(x) = 4
    # Note that v1 control flow gets 4.0 as well if the cond is removed.
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      self.assertEqual(grad, 4.0)

  def testWhile_NestedInput(self):
    with self.cached_session() as sess:
      named = collections.namedtuple("named", ("a", "b"))
      loop_vars = [
          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
          (constant_op.constant(2.0), constant_op.constant(3.0)),
          constant_op.constant(4.0)
      ]
      c = lambda lv0, _1, _2: lv0.a < 100.0

      def b(lv0, lv1, lv2):
        lv0 = named(a=lv0.a + 1, b=lv0.b)
        lv1 = (lv1[0] + 1, lv1[1])
        lv2 += 2
        return [lv0, lv1, lv2]

      r = while_loop_tf.while_loop(c, b, loop_vars)

      self.assertIsInstance(r, list)
      self.assertIsInstance(r[0], named)
      self.assertIsInstance(r[1], tuple)
      self.assertIsInstance(r[2], tensor_lib.Tensor)

      r_flattened = nest.flatten(r)
      self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
                       self.evaluate(r_flattened))

  @test_util.run_v1_only("b/120545219")
  def testWhile_NestedBadArityFails(self):
    with self.cached_session():
      named = collections.namedtuple("named", ("a", "b"))
      loop_vars = [
          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
          (constant_op.constant(2.0), constant_op.constant(3.0)),
          constant_op.constant(4.0)
      ]
      c = lambda lv0, _1, _2: lv0.a < 100.0

      def b(lv0, lv1, _):
        return [lv0, lv1]

      with self.assertRaisesRegex(ValueError, "the same number of elements"):
        while_loop_tf.while_loop(c, b, loop_vars)

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_ys_xs(self):
    with self.cached_session():
      x = constant_op.constant(3.0, name="x")
      y = constant_op.constant(2.0, name="y")

      c = lambda x, y: math_ops.less(x, 100.0)

      def b(x, y):
        y1 = math_ops.add(x, y)
        x1 = math_ops.multiply(x, y1)
        return x1, y1

      rx, ry = while_loop_tf.while_loop(c, b, [x, y], parallel_iterations=1)
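      # Two iterations give rx = x * (x + y)**2 * (x + 1) = 300 and
      # ry = (x + y) * (x + 1) = 20; the asserted values below are the
      # partial derivatives of these expressions at x = 3, y = 2.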

      r = gradients_impl.gradients([rx, ry], x)
      self.assertAllClose(304.0, r[0])
      r = gradients_impl.gradients([rx, ry], y)
      self.assertAllClose(124.0, r[0])
      r = gradients_impl.gradients([rx], x)
      self.assertAllClose(295.0, r[0])
      r = gradients_impl.gradients([rx], y)
      self.assertAllClose(120.0, r[0])

  @test_util.run_deprecated_v1
  def testWhileGrad_Dependency(self):
    with self.cached_session():
      i = constant_op.constant(0, name="i")
      x = constant_op.constant(2.0, name="x")

      c = lambda i, x: math_ops.less(i, 10)

      def b(i, x):
        x = math_ops.multiply(x, 2.0)
        i = math_ops.add(i, 1)
        return i, x

      ri, rx = while_loop_tf.while_loop(c, b, [i, x], parallel_iterations=1)
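      # Ten iterations double x each time, so rx = x * 2**10 and
      # drx/dx = 1024; ri does not depend on x.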

      r = gradients_impl.gradients([ri, rx], x)
      self.assertAllClose(1024.0, r[0])
      r = gradients_impl.gradients([rx], x)
      self.assertAllClose(1024.0, r[0])

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_NoGradient(self):
    with self.cached_session():
      v = constant_op.constant(2.0, name="v")
      c = lambda v: math_ops.less(v, 100.0)
      b = math_ops.square
      r = while_loop_tf.while_loop(c, b, [v], back_prop=False)
      r = math_ops.add(r, v)
      r = gradients_impl.gradients(r, v)
      self.assertAllClose(1.0, r[0])

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_NoDependency(self):
    with self.cached_session() as sess:
      variable = variables.Variable(array_ops.ones([2, 3]))
      duration = array_ops.zeros([], dtype=dtypes.int32)

      def cond(duration, tensor, _):
        del tensor
        return duration < 10

      def body(duration, tensor, _):
        return (duration + 1, tensor, tensor)

      loop_vars = [duration, variable, variable]
      tensors = while_loop_tf.while_loop(
          cond=cond, body=body, loop_vars=loop_vars)
      cost = math_ops.reduce_sum(tensors[2])
      grad = gradients_impl.gradients(cost, [variable])
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))

  @test_util.run_deprecated_v1
  def testWhileGrad_Const(self):
    with self.cached_session() as sess:
      c0 = constant_op.constant(0.0, name="c0")
      c1 = constant_op.constant(1.0, name="c1")
      duration = constant_op.constant(0, name="t")

      def cond(duration, _):
        return duration < 1

      def body(duration, _):
        return duration + 1, c1

      loop_vars = [duration, c0]
      tensors = while_loop_tf.while_loop(
          cond=cond, body=body, loop_vars=loop_vars)
      cost = math_ops.reduce_sum(tensors[1])
      grad = gradients_impl.gradients(cost, [c0])
      self.assertAllClose(0.0, sess.run(grad[0]))

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_SerialTwoLoops(self):
    with self.cached_session():
      i = constant_op.constant(0, name="i")
      x = constant_op.constant(2.0, name="x")

      c = lambda i, x: math_ops.less(i, 5)

      def b(i, x):
        x = math_ops.multiply(x, 2.0)
        i = math_ops.add(i, 1)
        return i, x

      _, rx = while_loop_tf.while_loop(c, b, [i, x], parallel_iterations=1)
      _, rx = while_loop_tf.while_loop(c, b, [i, rx], parallel_iterations=1)

      r = gradients_impl.gradients([rx], x)
      self.assertAllClose(1024.0, r[0])

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_ParallelTwoLoops(self):
    with self.cached_session():
      i = constant_op.constant(0, name="i")
      x = constant_op.constant(2.0, name="x")

      c = lambda i, x: math_ops.less(i, 5)

      def b(i, x):
        x = math_ops.multiply(x, 2.0)
        i = math_ops.add(i, 1)
        return i, x

      _, r1 = while_loop_tf.while_loop(c, b, [i, x], parallel_iterations=1)
      _, r2 = while_loop_tf.while_loop(c, b, [i, x], parallel_iterations=1)
      rx = math_ops.add(r1, r2)
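      # Each loop computes x * 2**5, so rx = 2 * 32 * x and drx/dx = 64.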

      r = gradients_impl.gradients([rx], x)
      self.assertAllClose(64.0, r[0])

  @test_util.run_v1_only("b/120545219")
  def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
    with self.cached_session():
      i = constant_op.constant(0, name="i")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(1.0, name="y")
      c = lambda i, *_: math_ops.less(i, 1, name="cond_less")

      def b(i, xi, yi):
        # return (i + 1, xi, xi + yi)
        return (math_ops.add(i, 1, name="inc"), array_ops.identity(
            xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))

      _, x_f, y_f = while_loop_tf.while_loop(c, b, [i, x, y])
      with ops.control_dependencies([x_f]):
        y_f_d = array_ops.identity(y_f, name="y_f_d")

      self.assertAllClose(2.0, self.evaluate(y_f_d))  # y_f_d = 1.0 + 1.0
      g = gradients_impl.gradients([y_f_d], [x])[0]
      self.assertIsNotNone(g)
      self.assertAllClose(1.0,
                          self.evaluate(g))  # y_f_d = x + 1.0, dy_f_d/dx = 1.0

  def _testNestedWhileGrad_Simple(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      v = constant_op.constant(1.0)

      def inner_loop(s):
        c = lambda x: math_ops.less(x, 4.0)
        b = lambda x: math_ops.multiply(x, 2.0)
        return while_loop_tf.while_loop(c, b, [s])

      c = lambda x: math_ops.less(x, 2.0)
      b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
      r = while_loop_tf.while_loop(c, b, [v])

      r = gradients_impl.gradients(r, v)[0]
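      # inner_loop doubles its argument twice (1 -> 2 -> 4) and b doubles it
      # once more, so one outer iteration gives r = 8 * v and dr/dv = 8.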
      self.assertAllClose(8.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testNestedWhileGrad_Simple(self):
    self._testNestedWhileGrad_Simple(use_gpu=False)
    self._testNestedWhileGrad_Simple(use_gpu=True)

  @test_util.run_v1_only("b/120545219")
  def testNestedWhileGrad_SerialInner(self):
    with self.cached_session():
      v = constant_op.constant(1.0)

      def inner_loop1(s):
        z = constant_op.constant(0)
        c = lambda i, x: math_ops.less(i, 4)
        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
        return while_loop_tf.while_loop(c, b, [z, s])

      def inner_loop2(s):
        z = constant_op.constant(0)
        c = lambda i, x: math_ops.less(i, 4)
        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
        return while_loop_tf.while_loop(c, b, [z, s])

      c = lambda x: math_ops.less(x, 128.0)
      b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
      r = while_loop_tf.while_loop(c, b, [v])

      r = gradients_impl.gradients(r, v)[0]
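      # Each inner loop runs four times and multiplies by 2**4, so one outer
      # iteration gives r = 16 * 16 * v = 256 * v and dr/dv = 256.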
      self.assertAllClose(256.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testNestedWhileGrad_ParallelInner(self):
    with self.cached_session():
      v = constant_op.constant(1.0)

      def inner_loop1(s):
        z = constant_op.constant(0)
        c = lambda i, x: math_ops.less(i, 4)
        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
        return while_loop_tf.while_loop(c, b, [z, s])

      def inner_loop2(s):
        z = constant_op.constant(0)
        c = lambda i, x: math_ops.less(i, 4)
        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
        return while_loop_tf.while_loop(c, b, [z, s])

      c = lambda x: math_ops.less(x, 128.0)
      b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
      r = while_loop_tf.while_loop(c, b, [v])

      r = gradients_impl.gradients(r, v)[0]
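      # b computes (16 * x) * (16 * x) = 256 * x**2, so one outer iteration
      # gives r = 256 * v**2 and dr/dv = 512 * v = 512.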
      self.assertAllClose(512.0, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testNestedWhileGrad_ParallelIterations(self):
    # Make sure the stack pushes and pops of an inner loop are executed in
    # the sequential order of the iterations of its outer loop.
    with self.cached_session() as sess:

      def inner_loop(t):
        fn = lambda n: n + math_ops.square(var)
        return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10)

      def outer_loop(inp):
        return map_fn.map_fn(
            fn=inner_loop, elems=inp, parallel_iterations=10)

      var = variables.Variable(constant_op.constant(3.0))
      inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
      res = outer_loop(inp)
      optimizer = adam.AdamOptimizer(learning_rate=0.001)
      train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
      self.assertAllClose(2.999, var.read_value())

  def _testWhileCondGrad_Simple(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      v = ops.convert_to_tensor(2.0, name="v")
      n = ops.convert_to_tensor(100.0, name="n")
      one = ops.convert_to_tensor(1.0, name="one")
      c = lambda x: math_ops.less(x, n)
      # pylint: disable=undefined-variable
      # for OSS build
      b = lambda x: tf_cond.cond(constant_op.constant(True),
                                 lambda: math_ops.square(x),
                                 lambda: math_ops.subtract(x, one))
      # pylint: enable=undefined-variable
      r = while_loop_tf.while_loop(c, b, [v])
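      # The cond always takes the square branch, so the loop squares v until
      # v**8 = 256 >= 100, giving dr/dv = 8 * v**7 = 1024.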
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllClose(1024.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testWhileCondGrad_Simple(self):
    self._testWhileCondGrad_Simple(use_gpu=False)
    self._testWhileCondGrad_Simple(use_gpu=True)

  @test_util.run_deprecated_v1
  def testWhileCondGrad_UnknownShape(self):
    with self.cached_session() as sess:
      v = array_ops.placeholder(dtypes.float32)
      n = ops.convert_to_tensor(100.0, name="n")
      one = ops.convert_to_tensor(1.0, name="one")
      c = lambda x: math_ops.less(x, n)
      # pylint: disable=undefined-variable
      # for OSS build
      b = lambda x: tf_cond.cond(constant_op.constant(True),
                                 lambda: math_ops.square(x),
                                 lambda: math_ops.subtract(x, one))
      # pylint: enable=undefined-variable
      r = while_loop_tf.while_loop(c, b, [v])
      r = gradients_impl.gradients(r, v)[0]
      r = sess.run(r, feed_dict={v: 2.0})
      self.assertAllClose(1024.0, r)

  @test_util.run_deprecated_v1
  def testWhileGrad_Concat(self):
    with self.cached_session() as sess:
      x = variable_scope.get_variable("x", initializer=[[1., 2.]])
      i0 = constant_op.constant(0)
      h0 = array_ops.zeros([0, 2])

      def condition(i, _):
        return i < 2

      def body(i, h):
        return i + 1, array_ops.concat([h, x], 0)

      _, h = while_loop_tf.while_loop(
          condition, body, [i0, h0],
          [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
      s = math_ops.reduce_sum(h)

      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
      op = optimizer.minimize(s)

      self.evaluate(variables.global_variables_initializer())
      self.evaluate(op)
      self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileWithRefsWithGradients_1(self):
    with self.cached_session() as sess:
      x = variable_v1.VariableV1(0.)._ref()  # pylint: disable=protected-access
      i = constant_op.constant(0)
      c = lambda i, x: math_ops.less(i, 10)

      self.assertEqual(x.dtype, dtypes.float32_ref)

      def body(i, x):
        self.assertEqual(x.dtype, dtypes.float32_ref)
        return [i + 1, gen_array_ops.ref_identity(x)]

      r = while_loop_tf.while_loop(c, body, [i, x], parallel_iterations=5)

      grad_ys = [variable_v1.VariableV1(73)._ref()]  # pylint: disable=protected-access
      grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(r[0].dtype, dtypes.int32)
      self.assertEqual(r[1].dtype, dtypes.float32_ref)

      value_i, value_x, value_x_grad = sess.run(r + grad)

    self.assertEqual(10, value_i)
    self.assertEqual(0, value_x)
    self.assertEqual(73, value_x_grad)

  @test_util.deprecated_graph_mode_only
  def testWhileGrad_IndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([2.0, 4.0], name="values")
      indices = constant_op.constant([0, 3], name="indices")
      shape = constant_op.constant([10], name="dense_shape")
      i = constant_op.constant(0)
      x = indexed_slices.IndexedSlices(values, indices, dense_shape=shape)

      def c(i, _):
        return i < 10

      def b(i, x):
        return [
            i + 1,
            indexed_slices.IndexedSlices(x.values * 2.0, x.indices,
                                         x.dense_shape)
        ]

      _, r = while_loop_tf.while_loop(c, b, [i, x])
      r = gradients_impl.gradients(r.values, values)[0]
      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))

  @test_util.deprecated_graph_mode_only
  def testWhileGrad_SparseTensor(self):
    with self.cached_session():
      values = constant_op.constant([2.0, 4.0], name="values")
      indices = constant_op.constant(
          [[0], [3]], dtype=dtypes.int64, name="indices")
      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
      i = constant_op.constant(0)
      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)

      def c(i, _):
        return i < 10

      def b(i, x):
        return [
            i + 1,
            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
        ]

      _, r = while_loop_tf.while_loop(c, b, [i, x])
      r = gradients_impl.gradients(r.values, values)[0]
      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))

  @test_util.deprecated_graph_mode_only
  def testCallGradInLoop(self):
    with self.cached_session() as sess:
      i0 = constant_op.constant(0)
      params = constant_op.constant(5.0)
      params_1 = math_ops.square(params)

      def c(i, _):
        return i < 10

      def b(i, x):
        data = constant_op.constant([1.0, 2.0, 3.0])
        data = math_ops.multiply(data, params_1)
        x1 = x + gradients_impl.gradients(data, params)[0]
        return i + 1, x1

      output_grad = while_loop_tf.while_loop(
          c, b, [i0, constant_op.constant(0.0)])
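      # Each iteration adds d(sum([1, 2, 3] * params**2))/dparams
      # = 6 * 2 * params = 60, so ten iterations accumulate 600.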
      self.assertAllClose(600.0, self.evaluate(output_grad)[1])

  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    with self.cached_session() as sess:
      param = constant_op.constant(2.0)
      n0 = constant_op.constant(0)
      y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")

      def c(i, _):
        return i < 10

      def b(i, y):
        return [
            i + 1,
            map_fn.map_fn(lambda x: math_ops.multiply(x, param), y)
        ]

      r = while_loop_tf.while_loop(c, b, [n0, y0], parallel_iterations=1)
      r = gradients_impl.gradients(r, param)[0]
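      # Ten iterations scale y elementwise by param, so the loop output is
      # y0 * param**10 and d(sum(y))/dparam = sum(y0) * 10 * param**9
      # = 21 * 10 * 512 = 107520.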
      self.assertAllClose(107520.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testNestedWhileAndTensorArray(self):
    n = constant_op.constant(3.0)

    def Body(row, ta):

      def InnerBody(row, col, ta):
        # Note: row and col are 1-based.
        ta = ta.write(
            math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
        return row, col + 1., ta

      ta = while_loop_tf.while_loop(
          lambda _, col, _1: col <= n,
          InnerBody, [row, constant_op.constant(1.), ta],
          return_same_structure=False)[2]
      return row + 1., ta

    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
    ta = while_loop_tf.while_loop(
        lambda row, _: row <= n,
        Body, [constant_op.constant(1.), ta],
        return_same_structure=False)[1]

    output = array_ops.reshape(ta.stack(), [3, 3])
    self.assertAllEqual(
        self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]])
    # TODO(b/117675481): This does not work with current TA. Enable with new TA.
    # grad = gradients_impl.gradients(output, [n])
    # self.assertEqual(self.evaluate(grad), 3.5)

  @test_util.run_deprecated_v1
  def testWhileGrad_StopGrad(self):
    with self.cached_session():
      x = constant_op.constant(3.0, name="x")
      y = constant_op.constant(2.0, name="y")

      c = lambda x, y: math_ops.less(x, 100.0)

      def b(x, y):
        y1 = math_ops.square(y)
        x1 = math_ops.add(math_ops.square(x), y1)
        return x1, y1

      rx, ry = while_loop_tf.while_loop(c, b, [x, y])
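      # Two iterations: rx = (x**2 + y**2)**2 + y**4 = 185 and ry = y**4 = 16,
      # so drx/dy = 4*y*(x**2 + y**2) + 4*y**3 = 136 and dry/dy = 4*y**3 = 32.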

      r = gradients_impl.gradients(rx, y)[0]
      self.assertEqual(136.0, self.evaluate(r))
      r = gradients_impl.gradients(ry, y)[0]
      self.assertEqual(32.0, self.evaluate(r))

      r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
      self.assertIsNone(r)
      r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
      self.assertIsNone(r)

      r = gradients_impl.gradients(
          array_ops.stop_gradient(math_ops.square(rx)), y)[0]
      self.assertIsNone(r)
      r = gradients_impl.gradients(
          array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
      self.assertIsNone(r)
      r = gradients_impl.gradients(
          array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
      self.assertIsNone(r)

      r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
      self.assertEqual(168.0, self.evaluate(r))
      r = gradients_impl.gradients(
          math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
      self.assertEqual(136.0, self.evaluate(r))
      r = gradients_impl.gradients(
          math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
      self.assertEqual(32.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testWhileGrad_StopGradInside(self):
    with self.cached_session():
      x = constant_op.constant(3.0, name="x")
      y = constant_op.constant(2.0, name="y")

      c = lambda x, y: math_ops.less(x, 100.0)

      def b(x, y):
        y1 = array_ops.stop_gradient(math_ops.square(y))
        x1 = math_ops.add(math_ops.square(x), y1)
        return x1, y1

      rx, _ = while_loop_tf.while_loop(c, b, [x, y])

      r = gradients_impl.gradients(rx, y)[0]
      self.assertAllClose(0.0, self.evaluate(r))
      r = gradients_impl.gradients(rx, x)[0]
      self.assertAllClose(156.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testWhileGrad_StopGradInsideNoShape(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)

      c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)

      def b(x, y):
        y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
        x1 = math_ops.add(math_ops.square(x), y1)
        return x1, y1

      rx, _ = while_loop_tf.while_loop(c, b, [x, y])

      grad_y = gradients_impl.gradients(rx, y)[0]
      grad_x = gradients_impl.gradients(rx, x)[0]
      feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
      self.assertAllClose([0.0, 0.0], sess.run(grad_y, feed_dict=feed_dict))
      self.assertAllClose([156.0, 400.0], sess.run(grad_x, feed_dict=feed_dict))
      name = "gradients/while/stopped_grad"
      all_ops = x.graph.get_operations()
      self.assertFalse(any(name in op.name for op in all_ops))

  @test_util.run_deprecated_v1
  def testWhileGradGradFail(self):
    theta = variables.Variable(initial_value=1.)

    def fn(prev, x):
      return prev + x * theta

    result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
    grad_theta = gradients_impl.gradients(result, theta)
    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
      with self.assertRaisesRegex(TypeError, "Second-order gradient"):
        gradients_impl.gradients(grad_theta, theta)
    grad_theta_stopped = array_ops.stop_gradient(grad_theta)
    gradients_impl.gradients(grad_theta_stopped, theta)

  @test_util.run_deprecated_v1
  def testStopGradOnWhileGrad(self):
    with self.cached_session():
      x = constant_op.constant(2.0, name="x")
      y = constant_op.constant(2.0, name="y")

      c = lambda x: math_ops.less(x, 100.0)
      b = lambda x: math_ops.multiply(x, y)
      rx = while_loop_tf.while_loop(c, b, [x])

      rg = gradients_impl.gradients(rx, y)[0]
      rg = array_ops.stop_gradient(rg)
      r = math_ops.add(math_ops.square(y), rx)
      r = math_ops.add(r, rg)
      r = gradients_impl.gradients(r, y)[0]
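      # Six iterations give rx = x * y**6 = 128; rg = drx/dy = 6*x*y**5 = 384
      # is stopped, so dr/dy = d(y**2)/dy + drx/dy = 4 + 384 = 388.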
      self.assertEqual(388.0, self.evaluate(r))

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_deprecated_v1
  def testWhileGradientWithNontrainablePath1(self):
    q = variables.Variable([7., 8.])

    def cond(_, y):
      del y
      return False

    def body(x, _):
      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)

    _, y = while_loop_tf.while_loop(cond, body, (math_ops.argmin(q), 0.))
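    # cond is always False, so the body never runs and y keeps its initial
    # value 0.; the gradient therefore evaluates to zeros.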
    dy_dq, = gradients_impl.gradients(y, q)
    self.assertIsNotNone(dy_dq)
    with self.cached_session() as sess:
      self.evaluate(q.initializer)
      self.assertAllClose([0., 0.], self.evaluate(dy_dq))

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileGradientWithNontrainablePath2(self):
    q = variables.Variable([7., 8.])

    def cond(_, y):
      return math_ops.equal(y, 0.)

    def body(x, _):
      zero = constant_op.constant(0, dtype=dtypes.int64)
      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)

    _, y = while_loop_tf.while_loop(cond, body, (math_ops.argmin(q), 0.))
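    # The body runs once (y becomes cast(argmin(q)) + reduce_sum(q)), so
    # dy/dq = [1., 1.] through the reduce_sum; the argmin path is
    # nontrainable.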
    dy_dq, = gradients_impl.gradients(y, q)
    self.assertIsNotNone(dy_dq)
    with self.cached_session() as sess:
      self.evaluate(q.initializer)
      self.assertAllClose([1., 1.], self.evaluate(dy_dq))

  @test_util.run_v1_only("b/120545219")
  def testIssue16504(self):
    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
    w = variables.Variable(
        initial_value=np.ones(100), dtype=dtypes.float32) / 100
    k = variables.Variable(0, dtype=dtypes.int32)
    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)

    def cond(k, _, chg_w):
      return math_ops.logical_and(k < 10, chg_w > 1e-3)

    def body(k, w, chg_w):
      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
      w_n = w * math_ops.exp(-0.1 * grad)
      w_n /= math_ops.reduce_sum(w_n)
      chg_w = (
          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
              math_ops.abs(w)))
      return k + 1, w_n, chg_w

    _, w, _ = while_loop_tf.while_loop(cond, body, [k, w, chg_w])
    grad, = gradients_impl.gradients(w, c)
    self.assertIsNotNone(grad)

  @test_util.run_v1_only("b/120545219")
  def testStopGradMultiFlows(self):
    with self.cached_session():

      def body(i, y, r):
        x = variable_scope.get_variable(
            "x",
            shape=(),
            dtype=dtypes.float32,
            initializer=init_ops.ones_initializer())
        y *= x
        return [i + 1, y, r + math_ops.reduce_sum(y)]

      i0 = constant_op.constant(0)
      y0 = array_ops.ones(5)
      r0 = constant_op.constant(0.0)
      cond = lambda i, y, r: i < 1
      _, _, r = while_loop_tf.while_loop(
          cond, body, [i0, y0, r0], back_prop=True)

      vars_ = variables.global_variables()
      grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
      z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
      result = gradients_impl.gradients(z, vars_)[0]
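      # One iteration gives r = sum(ones(5) * x) = 5 * x, and the stopped
      # gradient-norm term contributes nothing, so dz/dx = 5.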
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(result))

  @test_util.run_v1_only("b/120545219")
  def testOneValueCond(self):

    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      one = ops.convert_to_tensor(1, name="one")
      two = ops.convert_to_tensor(2, name="two")
      p = math_ops.greater_equal(c, 1)
      i = tf_cond.cond(p, lambda: one, lambda: two)
      self.assertIsInstance(i, tensor_lib.Tensor)

      # True case: c = 2 is >= 1
      self.assertEqual([1], i.eval(feed_dict={c: 2}))

      # False case: c = 0 is not >= 1
      self.assertEqual([2], i.eval(feed_dict={c: 0}))

  @test_util.run_deprecated_v1
  def testExampleCond(self):

    with self.cached_session():
      x = ops.convert_to_tensor([-2.0, 2.0], name="x")
      d = array_ops.placeholder(dtypes.int32, shape=[])

      def l2():
        return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))

      def l1():
        return math_ops.reduce_sum(math_ops.abs(x))

      i = tf_cond.cond(math_ops.equal(d, 2), l2, l1)
      self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
      self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with self.cached_session():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_case.case({
          x < y: f1,
          x > z: f2
      },
                                  default=f3,
                                  exclusive=True)
      self.assertAllEqual(r1, 17)

      r2 = control_flow_case.case([(y > z, f1), (y > x, f2)], default=f3)
      self.assertAllEqual(r2, 23)

      # Duplicate events can happen; the first one is selected
      r3 = control_flow_case.case([(x < y, f1), (x < y, f2)], default=f3)
      self.assertAllEqual(r3, 17)

      # Duplicate events cause an error if exclusive = True
      r4 = control_flow_case.case([(x < y, f1), (x < y, f2)],
                                  default=f3,
                                  exclusive=True)
      with self.assertRaisesOpError("Input error:"):
        self.evaluate(r4)

      # Check that the default is called if none of the others are
      r5 = control_flow_case.case({x > y: f1}, default=f3)
      self.assertAllEqual(r5, -1)

      ran_once = [False, False, False]

      def break_run_twice(ix):

        def _break():
          ran_once[ix] = True
          return constant_op.constant(ix)

        return _break

      # Should not fail - each conditional gets called exactly once
      # except default.  Default gets called twice: once to create an
      # empty output and once for the actual cond switch.
      r6 = control_flow_case.case([(x < y, break_run_twice(0)),
                                   (x > y, break_run_twice(1))],
                                  default=lambda: constant_op.constant(2))

      self.assertAllEqual(r6, 0)

  @test_util.run_v1_only("b/120545219")
  def testCaseSideEffects(self):
    with self.cached_session() as sess:
      v0 = variables.Variable(-1)
      v1 = variables.Variable(-1)
      v2 = variables.Variable(-1)

      a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0)
      b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1)
      c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2)

      x = constant_op.constant(1)
      y = constant_op.constant(2)

      r0 = control_flow_case.case(((x < y, a), (x > y, b)),
                                  default=c,
                                  exclusive=True)
      r1 = control_flow_case.case(((x > y, a), (x < y, b)),
                                  default=c,
                                  exclusive=True)
      r2 = control_flow_case.case(((x > y, a), (x > y, b)),
                                  default=c,
                                  exclusive=True)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(2, self.evaluate(r2))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(1, self.evaluate(r1))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
      self.assertEqual(0, self.evaluate(r0))
      self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testOneOpCond(self):
    with self.cached_session():
      v = variables.Variable(0)
      c = ops.convert_to_tensor(0)
      one = ops.convert_to_tensor(1)
      two = ops.convert_to_tensor(2)
      p = math_ops.greater_equal(c, 1)

      def a():
        return state_ops.assign(v, one)

      def b():
        return state_ops.assign(v, two)

      i = tf_cond.cond(p, a, b)
      self.assertIsInstance(i, tensor_lib.Tensor)
      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(0, self.evaluate(v))

      # True case: c = 2 is >= 1, v is set to 1.
      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
      self.assertEqual(1, self.evaluate(v))

      # False case: c = 0 is not >= 1, v is set to 2.
      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
      self.assertEqual(2, self.evaluate(v))

  @test_util.run_v1_only("b/120545219")
  def testWithOpsDependencies(self):
    with self.cached_session() as sess:
      v = variable_v1.VariableV1(0.0)
      c = constant_op.constant(10)

      # Fetching v directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate([c, v])

      # Use a control dependency to ensure init_variable is run
      # while asking for c
      real_v = control_flow_ops.with_dependencies(
          name="real_tensor",
          output_tensor=v._ref(),  # pylint: disable=protected-access
          dependencies=[v.initializer])
      c_val, real_v_val = self.evaluate([c, real_v])

    # Ensure the value of 'c' is unchanged by the added dependency
    self.assertAllEqual(10, c_val)

    # Ensure that 'v' is initialized
    self.assertAllClose(0.0, real_v_val)

  @test_util.run_v1_only("b/120545219")
  def testWithTensorDependencies(self):
    with self.cached_session():
      v = variable_v1.VariableV1(0.0)
      c1 = constant_op.constant(10)
      c2 = constant_op.constant(20)

      # c1_with_init_v depends on the init op for v
      c1_with_init_v = control_flow_ops.with_dependencies(
          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
      # c2_with_c1 depends on the value of c1_with_init_v
      c2_with_c1_dep = control_flow_ops.with_dependencies(
          name="c2_with_c1_dep",
          output_tensor=c2,
          dependencies=[c1_with_init_v])

      # Fetching v directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v)

      # Get the value of 'c2_with_c1_dep', which should cause 'v'
      # to be initialized.
      self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))

      # Ensure that 'v' is initialized
      self.assertAllClose(0.0, self.evaluate(v))

  @test_util.run_v1_only("b/120545219")
  def testWithIndexedSlicesDependencies(self):
    with self.cached_session():
      v = variable_v1.VariableV1(
          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
      v_at_1 = indexed_slices.IndexedSlices(v, constant_op.constant([1]))
      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
                                                             v_at_1)
      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
                                                  v_at_1_after_init.indices)

      # Fetching gather_v_at_1 will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(gather_v_at_1)

      # Getting gather_v_at_1_after_init will work, and initialize v.
      self.assertAllEqual([[10.0, 11.0]],
                          self.evaluate(gather_v_at_1_after_init))

      # Double check that 'v' is initialized
      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                          self.evaluate(v))

  def testDependenciesDevice(self):
    with ops.Graph().as_default():
      # device set on tensor => same device on dep.
      with ops.device("/job:ps"):
        vd = variable_v1.VariableV1([0.0])
      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
      self.assertTrue("/job:ps" in with_vd_dep.device)

      # No device set on tensor => no device on dep.
      vnod = variable_v1.VariableV1([0.0])
      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
                                                         vnod)
      self.assertDeviceEqual(None, with_vnod_dep.device)

      # device set on tensor, default device on graph => default device on dep.
      vdef = variable_v1.VariableV1([0.0], name="vdef")
      with ops.device("/job:worker/device:GPU:1"):
        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
                                                           vdef)
        # The device is empty, but the colocation constraint is set.
        self.assertDeviceEqual("", with_vdef_dep.device)
        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())

  @test_util.run_v1_only("b/120545219")
  def testGroup(self):
    with self.cached_session() as sess:
      v1 = variable_v1.VariableV1([0.0])
      v2 = variable_v1.VariableV1([1.0])

      # Group init1 and init2 and run.
      init = control_flow_ops.group(v1.initializer, v2.initializer)
      # Fetching v1 directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v1)

      # Runs "init" before fetching v1 and v2.
      init.run()
      v1_val, v2_val = self.evaluate([v1, v2])

    # Ensure that v1 and v2 are initialized
    self.assertAllClose([0.0], v1_val)
    self.assertAllClose([1.0], v2_val)

  @test_util.run_v1_only("b/120545219")
  def testGroupEmpty(self):
    op = control_flow_ops.group()
    self.assertEqual(op.type, "NoOp")
    self.assertEqual(op.control_inputs, [])

  @test_util.run_deprecated_v1
  def testMergeShapes(self):
    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    m, index = control_flow_ops.merge([p1, p2, p3])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with different ranks.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with some dimensions different.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    # All inputs known with same dimensions.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([1, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testRefSelect(self):
    index = array_ops.placeholder(dtypes.int32)

    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    v1 = variable_v1.VariableV1(p1, validate_shape=False)
    v2 = variable_v1.VariableV1(p2, validate_shape=False)
    v3 = variable_v1.VariableV1(p3, validate_shape=False)
    self.assertIs(None, v1.get_shape().ndims)
    s = control_flow_ops.ref_select(index, [v1, v2, v3])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known but different.
    v1 = variable_v1.VariableV1([[1, 2]])
    v2 = variable_v1.VariableV1([[2], [1]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known and same.
    v1 = variable_v1.VariableV1([[1, 2]])
    v2 = variable_v1.VariableV1([[1, 2]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual([1, 2], s.get_shape())

    # Possibly the same but not guaranteed.
    v1 = variable_v1.VariableV1([[1., 2.]])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    v2 = variable_v1.VariableV1(p2, validate_shape=False)
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual(None, s.get_shape())

  @test_util.run_deprecated_v1
  def testRunLoopTensor(self):
    with self.cached_session() as sess:
      tensor_list = []

      def condition(t):
        return t < constant_op.constant(5)

      def body(_):
        tensor_list.append(constant_op.constant(5))
        return constant_op.constant(10)

      result = while_loop_tf.while_loop(
          condition, body, [constant_op.constant(4)])
      self.assertEqual(10, self.evaluate(result))

      # Ensure that we cannot run a tensor that escapes the loop body
      # accidentally.
      with self.assertRaises(ValueError):
        sess.run(tensor_list[0])

  @test_util.run_v1_only("b/120545219")
  def testWhilePyFuncBasic(self):

    def func(x):
      return np.square(x)

    with self.cached_session():
      r = while_loop_tf.while_loop(
          lambda i, v: i < 4, lambda i, v:
          [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
          [constant_op.constant(0),
           constant_op.constant(2.0, dtypes.float32)],
          [tensor_shape.unknown_shape(),
           tensor_shape.unknown_shape()])
      self.assertEqual(self.evaluate(r[1]), 65536.0)

  @test_util.run_v1_only("b/120545219")
  def testWhileFuncBasic(self):

    @function.Defun(dtypes.float32)
    def func(x):
      return math_ops.square(math_ops.square(x))

    with self.cached_session():
      x = constant_op.constant(2.0, dtypes.float32)
      r = while_loop_tf.while_loop(
          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
          [constant_op.constant(0), x],
          [tensor_shape.unknown_shape(),
           tensor_shape.unknown_shape()])
      grad = gradients_impl.gradients(r, x)[0]
      self.assertEqual(self.evaluate(r[1]), 65536.0)
      self.assertEqual(self.evaluate(grad), 524288.0)
      # while_v2 does not have stacks.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertEqual(
            len([op for op in x.graph.get_operations() if op.type == "StackV2"
                ]), 1)

  @test_util.run_v1_only("b/120545219")
  def testQIntSwitchMerge(self):
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_qint, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testQIntRefSwitchMerge(self):
    with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
      var_qint = gen_state_ops.variable(
          shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
      assign_op = state_ops.assign(
          var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
      self.evaluate(assign_op)

      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
      result = control_flow_ops.ref_merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testUInt64SwitchMerge(self):
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  def testSwitchEagerMode(self):
    if not context.executing_eagerly():
      return
    input_data = [1, 2, 3, 4]
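    # With a False predicate, switch routes the input to the false output;
    # the untaken (true) output is empty in eager mode.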
    vf, vt = control_flow_ops.switch(input_data, False)
    self.assertAllEqual(vf, input_data)
    self.assertAllEqual(vt, [])

  @test_util.run_deprecated_v1
  def testQIntArgAndRet(self):

    @function.Defun(dtypes.qint8)
    def func(x):
      return x

    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      qint = constant_op.constant(np.array([42]), dtypes.qint8)
      result = func(qint)
      self.evaluate(result)

  def testSparseIdentity(self):
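    # _Identity should pass all SparseTensor components (indices, values,
    # dense_shape) through unchanged.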
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Identity(st1)
    self.assertAllEqual(st1.indices, st2.indices)
    self.assertAllEqual(st1.values, st2.values)
    self.assertAllEqual(st1.dense_shape, st2.dense_shape)

  def testSparseEnterExit(self):
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Enter(st1, "foo_1")
    st3 = control_flow_ops.exit(st2)
    self.assertAllEqual(st1.indices, st3.indices)
    self.assertAllEqual(st1.values, st3.values)
    self.assertAllEqual(st1.dense_shape, st3.dense_shape)

  def _buildWhileWithShapeInvariants(self, shape_invariants):
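    """Builds a trivial while loop whose output shape is set by `shape_invariants`."""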
    r = constant_op.constant([1, 2])

    def cond(_):
      return False

    def body(_):
      return constant_op.constant([1])

    return while_loop_tf.while_loop(
        cond, body, [r], shape_invariants=shape_invariants)

  def testWhileOutputShapeWithShapeInvariantsUnknownRank(self):
    @eager_def_function.function
    def runTest():
      while_output = self._buildWhileWithShapeInvariants(
          [tensor_shape.TensorShape(None)])
      self.assertIsNone(while_output.shape.rank)
    runTest()

  def testWhileOutputShapeWithShapeInvariantsPartialShape(self):
    @eager_def_function.function
    def runTest():
      while_output = self._buildWhileWithShapeInvariants(
          [tensor_shape.TensorShape([None])])
      self.assertAllEqual(while_output.shape.as_list(), [None])
    runTest()

  def testFunctionInWhile(self):

    @eager_def_function.function
    def body(x):
      return x + 1

    r = while_loop_tf.while_loop(lambda x: x < 5, body, [0])
    self.assertAllEqual(r, 5.)


class ControlFlowContextCheckTest(test.TestCase):

  def _getWhileTensor(self):
    """Creates and returns a tensor from a while context."""
    tensor = []
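    # The loop body appends the constant it creates to this list so the test
    # can reference a tensor that lives inside the while context.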

    def body(i):
      if not tensor:
        tensor.append(constant_op.constant(1))
      return i + tensor[0]

    while_loop_tf.while_loop(lambda i: i < 10, body, [0])
    return tensor[0]

  def _getCondTensor(self):
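    """Creates and returns a tensor from a cond context."""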
    cond_tensor = []

    def true_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    tf_cond.cond(
        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
    return cond_tensor[0]

  @test_util.run_v1_only("b/120545219")
  def testInvalidContext(self):
    # Accessing a while loop tensor outside of control flow is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
        "is in a while loop. See info log for more details."):
      math_ops.add(1, while_tensor)

  @test_util.run_v1_only("b/120545219")
  def testInvalidContextInCond(self):
    # Accessing a while loop tensor in cond is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
        "'while/Const_1' is in a while loop. See info log for more details."):
      # TODO(skyewm): this passes if we return while_tensor directly instead
      # of using it as input to another op.
      tf_cond.cond(
          math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
          lambda: constant_op.constant(0))

  @test_util.run_v1_only("b/120545219")
  def testInvalidContextInWhile(self):
    # Accessing a while loop tensor in a different while loop is illegal.
    while_tensor = self._getWhileTensor()
    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are "
        "in different while loops. See info log for more details."):
      while_loop_tf.while_loop(
          lambda i: i < 10, lambda x: math_ops.add(1, while_tensor), [0])

    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' "
        "because they are in different while loops. See info log for more "
        "details."):
      while_loop_tf.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])

  def testValidCondContext(self):
    # Accessing a tensor from a cond context is OK (although dangerous).
    cond_tensor = self._getCondTensor()
    math_ops.add(1, cond_tensor)

  def testValidCondContextBranches(self):
    # Accessing a tensor from a cond context from the other branch's cond
    # context is OK (although dangerous).
    cond_tensor = []

    def branch_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    tf_cond.cond(math_ops.less(1, 2), branch_fn, branch_fn)

  @test_util.run_v1_only("b/120545219")
  def testValidWhileContext(self):
    # Accessing a tensor in a nested while is OK.
    def body(_):
      c = constant_op.constant(1)
      return while_loop_tf.while_loop(lambda i: i < 3, lambda i: i + c, [0])

    while_loop_tf.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testValidNestedContexts(self):
    # Accessing a tensor from a cond context in a while context, all inside an
    # outer while context, is OK.
    def body(_):
      cond_tensor = self._getCondTensor()
      # Create another cond containing the while loop for good measure.
      return tf_cond.cond(
          math_ops.less(1, 2),
          lambda: while_loop_tf.while_loop(lambda i: i < 3,
                                           lambda i: i + cond_tensor, [0]),
          lambda: constant_op.constant(0))

    while_loop_tf.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testInvalidNestedContexts(self):
    # Accessing a tensor from a while context in a different while context, all
    # inside a cond context, is illegal.
    def true_fn():
      while_tensor = self._getWhileTensor()
      return while_loop_tf.while_loop(
          lambda i: i < 3, lambda i: i + while_tensor, [0])

    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
        " they are in different while loops. See info log for more details."):
      tf_cond.cond(
          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))


class TupleTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testTensors(self):
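    # Fetching any output of control_flow_ops.tuple forces all of its inputs
    # (and their control dependencies, here the variable initializers) to run
    # first.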
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variable_v1.VariableV1([1.0])
        add1 = math_ops.add(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            2.0)
        v2 = variable_v1.VariableV1([10.0])
        add2 = math_ops.add(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            20.0)
        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting t1 initializes v2.
          self.assertAllClose([3.0], self.evaluate(t1))
          self.assertAllClose([10.0], self.evaluate(v2))
        else:
          # Getting t2 initializes v1.
          self.assertAllClose([30.0], self.evaluate(t2))
          self.assertAllClose([1.0], self.evaluate(v1))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlices(self):
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variable_v1.VariableV1(
            np.array([[0.0, 1.0], [10.0, 11.0], [20.0,
                                                 21.0]]).astype(np.float32))
        v1_at_1 = indexed_slices.IndexedSlices(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        v2 = variable_v1.VariableV1(
            np.array([[0.1, 1.1], [10.1, 11.1], [20.1,
                                                 21.1]]).astype(np.float32))
        v2_at_1 = indexed_slices.IndexedSlices(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
        g1 = array_ops.gather(st1.values, st1.indices)
        g2 = array_ops.gather(st2.values, st2.indices)

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting g1 initializes v2.
          self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
                              self.evaluate(v2))
        else:
          # Getting g2 initializes v1.
          self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                              self.evaluate(v1))

  def testAcceptTensorsAsControlInputs(self):
    with self.cached_session():
      var = variable_v1.VariableV1(0)
      assign = state_ops.assign(var, 1)
      t, = control_flow_ops.tuple(
          [constant_op.constant(0)], control_inputs=[assign])

      # Should trigger the assign.
      self.evaluate(t)

      self.assertEqual(1, self.evaluate(var))


class AssertTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testGuardedAssertDoesNotCopyWhenTrue(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646478 fails in opensource")

    with self.session() as sess:
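      # When the condition is True, the guarded Assert should not copy `value`
      # from GPU to host, whereas the raw (unguarded) assert op always does.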
      with ops.device(test.gpu_device_name()):
        value = constant_op.constant(1.0)
      with ops.device("/cpu:0"):
        true = constant_op.constant(True)
        guarded_assert = control_flow_assert.Assert(
            true, [value], name="guarded")
        unguarded_assert = gen_logging_ops._assert(
            true, [value], name="unguarded")
      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      guarded_metadata = config_pb2.RunMetadata()
      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
      unguarded_metadata = config_pb2.RunMetadata()
      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
      guarded_nodestat_names = [
          n.node_name
          for d in guarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      unguarded_nodestat_names = [
          n.node_name
          for d in unguarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      guarded_memcpy_nodestat_names = [
          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      unguarded_memcpy_nodestat_names = [
          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
        # A copy was performed for the unguarded assert.
        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
                        str(unguarded_nodestat_names))
      # No copy was performed for the guarded assert.
      self.assertEqual([], guarded_memcpy_nodestat_names)


class WhileOpBenchmark(test.Benchmark):
  """Evaluate the performance of while_loop op."""

  def _getInitVariables(self):
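    """Creates the loop counter and the image/kernel variables for the benchmark."""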
    batch_size = 10
    image_size = 256
    kernel_size = 3
    depth = 16

    init_step = constant_op.constant(-1)
    image = variable_scope.get_variable(
        "image",
        initializer=random_ops.random_normal(
            [batch_size, image_size, image_size, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    kernel = variable_scope.get_variable(
        "weights",
        initializer=random_ops.truncated_normal(
            [kernel_size, kernel_size, depth, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    return init_step, image, kernel

  def _runOneBenchmark(self,
                       default_device,
                       num_iters=10,
                       static_unroll=False,
                       steps=10):
    """Evaluate the while loop performance.

    Args:
      default_device: The default device to run all ops except the loop_body.
        loop_body is always run on GPU.
      num_iters: Number of iterations to run.
      static_unroll: If true, run unrolled version; otherwise, run while_loop.
      steps: Total number of loop iterations to run.

    Returns:
      The duration of the run in seconds.
    """

    def loop_body(i, x):
      with ops.device("/gpu:0"):
        # Always put loop body on GPU.
        nx = nn_ops.conv2d(
            input=x,
            filter=kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            name="conv2d")
        ni = math_ops.add(i, 1)
        return ni, nx

    ops.reset_default_graph()
    with session.Session() as sess, ops.device(default_device):
      # Get the initial step counter i, the input image x, and the conv kernel.
      i, x, kernel = self._getInitVariables()
      self.evaluate(variables.global_variables_initializer())

      if static_unroll:
        for _ in range(steps):
          i, x = loop_body(i, x)
      else:
        i, x = while_loop_tf.while_loop(
            lambda i, _: i < steps,
            loop_body, [i, x],
            parallel_iterations=steps,
            swap_memory=True)

      r = math_ops.reduce_sum(x)
      dx, dk = gradients_impl.gradients(r, [x, kernel])
      # Use group to avoid fetching back results.
      r = control_flow_ops.group(dx, dk)

      # Warm up: run a few iterations first so startup cost is excluded from
      # the timing below.
      for _ in range(3):
        self.evaluate(r)

      start_time = time.time()
      for _ in range(num_iters):
        self.evaluate(r)
      return (time.time() - start_time) / num_iters

  def benchmarkWhileOpCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_cross_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpSameDevicePlacement(self):
    iters = 10
    # Run all ops on the same GPU device.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_same_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollSameDevicePlacement(self):
    iters = 10
    # Run all ops on GPU.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_same_device", iters=iters, wall_time=duration)


@test_util.with_control_flow_v2
class EagerTest(test.TestCase):

  def testCond(self):
    with context.eager_mode():
      pred = math_ops.less(1, 2)
      fn1 = lambda: [constant_op.constant(10)]
      fn2 = lambda: [constant_op.constant(20)]
      r = tf_cond.cond(pred, fn1, fn2)

      self.assertAllEqual(r.numpy(), 10)
      self.assertNotIsInstance(r, list)

  # TODO(b/117279927): Re-enable once msan failure is fixed.
  def DISABLED_testCondInDefun(self):
    with context.eager_mode():

      @eager_def_function.function
      def foo(pred):
        # TODO(b/111124878): this only needs to output one element.
        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
        return tf_cond.cond(constant_op.constant(pred), fn1, fn2)

      r = foo(True)
      self.assertAllEqual(r[0].numpy(), 10)
      self.assertNotIsInstance(r, list)

      r = foo(False)
      self.assertAllEqual(r[0].numpy(), 20)
      self.assertNotIsInstance(r, list)

  def testWhileLoop(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
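      # isum (defined earlier in this file) adds the loop counter values 0..9
      # to every element, a total of +45.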
      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])

  def testWhileLoopWithMaxIterations(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(
          isum(tensor, maximum_iterations=3).numpy(),
          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
    with context.eager_mode():
      tensor = constant_op.constant(0)
      r = while_loop_tf.while_loop(
          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
      self.assertEqual(1, r.numpy())

  def testWithDependencies(self):
    with context.eager_mode():
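      # In eager mode the dependencies have already executed, so
      # with_dependencies simply returns the output tensor.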
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      t3 = control_flow_ops.with_dependencies([t1], t2)
      self.assertAllEqual(t2.numpy(), t3.numpy())

  def testTuple(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      tup1, tup2 = control_flow_ops.tuple([t1, t2])
      self.assertAllEqual(t1.numpy(), tup1.numpy())
      self.assertAllEqual(t2.numpy(), tup2.numpy())

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with context.eager_mode():
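      # With exclusive=True, case evaluates every predicate and requires at
      # most one to be True; only x < y holds below, so f1's branch is taken.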
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_case.case([(x < y, f1), (x > z, f2)],
                                  default=f3,
                                  exclusive=True)
      self.assertAllEqual(r1.numpy(), 17)


if __name__ == "__main__":
  test.main()