# tensorflow/tensorflow — tensorflow/python/framework/ops_test.py
# (Scraped page chrome retained as a comment so the file parses as Python:
#  "View on GitHub"; Code Climate summary — Maintainability: F, est. 1 wk;
#  Test Coverage header.)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.ops."""

import gc
import os
import threading
import weakref

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import full_type_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
import tensorflow.python.ops.gradients  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat


class ResourceTest(test_util.TensorFlowTestCase):
  """Tests for graph-mode shared-resource registration helpers."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    """A resource handle and its create op can be built and run."""
    with self.cached_session():
      pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      test_ops.resource_create_op(pt).run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    """A registered resource reports uninitialized until initialized."""
    with self.cached_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))
      # Before running the initializer, exactly one shared resource is
      # reported as uninitialized.  assertLen/assertEmpty (used elsewhere in
      # this file) give clearer failure messages than assertEqual(len(...)).
      self.assertLen(
          resources.report_uninitialized_resources(
              resources.shared_resources()).eval(), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      # After initialization, no uninitialized resources remain.
      self.assertEmpty(
          resources.report_uninitialized_resources(
              resources.shared_resources()).eval())


class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests Tensor shape handling, Python protocol hooks, refs, and the
  bitwise operator overloads (&, |, ^, ~)."""

  def testShape(self):
    """A fresh op output has unknown shape until set_shape pins it."""
    op = ops.Operation.from_node_def(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]
    )
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testNdim(self):

    @def_function.function
    def f(a):
      # ndim reports the rank of the traced tensor argument.
      self.assertEqual(a.ndim, 2)
      return 0

    x = array_ops.zeros((3, 4))
    f(x)

  def testIterable(self):
    """Iterating a symbolic (non-eager) tensor raises in eager context."""
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    op = ops.Operation.from_node_def(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]
    )
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(t)

  def testIterableGraph(self):
    """In graph mode, iteration errors mention the AutoGraph status."""
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")

    op = ops.Operation.from_node_def(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]
    )
    t = op.outputs[0]
    # The error message varies with the AutoGraph control status.
    with self.assertRaisesRegex(
        TypeError, "Iterating.*not allowed.*Graph mode"):
      next(iter(t))
    with self.assertRaisesRegex(
        TypeError, "Iterating.*AutoGraph.*unsupported feature"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(t))
    with self.assertRaisesRegex(
        TypeError, "Iterating.*AutoGraph.*not be visible"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(t))

  def testImplicitBool(self):
    """bool() on a symbolic tensor raises, with AutoGraph-aware messages."""
    # NOTE(review): the NodeDef op name is "FloatOutput" but the declared
    # output dtype is bool — the dtype list, not the op name, determines the
    # tensor's dtype here.
    op = ops.Operation.from_node_def(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool]
    )
    t = op.outputs[0]
    with self.assertRaisesRegex(
        TypeError, "Using.*as a.*bool.*not allowed.*Graph mode"):
      bool(t)
    with self.assertRaisesRegex(
        TypeError, "Using.*as a.*bool.*AutoGraph.*unsupported feature"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(t)
    with self.assertRaisesRegex(
        TypeError, "Using.*as a.*bool.*AutoGraph.*not be visible"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(t)

  def testAddShape(self):
    # Broadcasting [1, 3] against [2, 3] yields [2, 3].
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    # An unknown middle dimension propagates through addition.
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    # Adding a fully-unknown shape yields an unknown result shape.
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.TensorShape([]), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    """Incompatible shapes produce a descriptive ValueError at graph time."""
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
          r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."):
        _ = a + b

  def testNumpyArray(self):
    """Symbolic tensors reject numpy protocols; EagerTensors support them."""
    with ops.Graph().as_default():
      x = array_ops.ones((3, 4), name="test_ones")

    with self.assertRaisesRegex(NotImplementedError,
                                r"Cannot convert a symbolic.+test_ones"):
      np.array(x)

    with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"):
      len(x)

    # EagerTensors should still behave as numpy arrays.
    with context.eager_mode():
      x = array_ops.ones((3, 4))

    self.assertAllEqual(x, np.ones((3, 4)))
    self.assertAllEqual(np.array(x), np.ones((3, 4)))
    self.assertLen(x, 3)

  def testConstructor(self):
    """numpy-style attribute lookups get a helpful AttributeError."""
    a = array_ops.ones([])
    for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size",
                 "tolist", "data"]:
      with self.assertRaisesRegex(
          AttributeError, r"If you are looking for numpy-related methods"):
        getattr(a, name)
    with self.assertRaisesRegex(
        AttributeError, r"object has no attribute"):
      a.foo_bar()

  def testRef(self):
    """ref() is equal for the same tensor object, unequal otherwise."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x1.ref())
    self.assertEqual(x2.ref(), x2.ref())
    self.assertEqual(x1.ref(), x2.ref())
    self.assertEqual(y.ref(), y.ref())
    self.assertEqual(z.ref(), z.ref())
    self.assertEqual(w.ref(), w.ref())

    # Note: x1 and y hold the same value (3) but are distinct objects, so
    # their refs differ — ref() is identity-based, not value-based.
    self.assertNotEqual(x1.ref(), y.ref())
    self.assertNotEqual(x1.ref(), z.ref())
    self.assertNotEqual(x1.ref(), w.ref())
    self.assertNotEqual(y.ref(), z.ref())
    self.assertNotEqual(y.ref(), w.ref())
    self.assertNotEqual(z.ref(), w.ref())

  def testRefDeref(self):
    """deref() returns the exact original object."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertIs(x1, x1.ref().deref())
    self.assertIs(x2, x2.ref().deref())
    self.assertIs(x1, x2.ref().deref())
    self.assertIs(x2, x1.ref().deref())
    self.assertIs(y, y.ref().deref())
    self.assertIs(z, z.ref().deref())

    self.assertIsNot(x1, y.ref().deref())
    self.assertIsNot(x1, z.ref().deref())
    self.assertIsNot(x1, w.ref().deref())
    self.assertIsNot(y, z.ref().deref())
    self.assertIsNot(y, w.ref().deref())
    self.assertIsNot(z, w.ref().deref())

  def testRefInSet(self):
    """refs are hashable; aliases collapse to a single set entry."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_set = {
        x1.ref(),
        x2.ref(),
        y.ref(),
        z.ref(),
        w.ref(),
    }

    # x1 and x2 alias the same tensor, so only 4 distinct entries remain.
    self.assertLen(tensor_set, 4)
    self.assertIn(x1.ref(), tensor_set)
    self.assertIn(x2.ref(), tensor_set)
    self.assertIn(y.ref(), tensor_set)
    self.assertIn(z.ref(), tensor_set)
    self.assertIn(w.ref(), tensor_set)

  def testRefInDict(self):
    """refs work as dict keys; aliased tensors share one slot."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_dict = {
        x1.ref(): "x1",
        y.ref(): "y",
        z.ref(): "z",
        w.ref(): "w",
    }

    self.assertLen(tensor_dict, 4)

    # Overwriting x1
    tensor_dict[x2.ref()] = "x2"
    self.assertLen(tensor_dict, 4)

    self.assertEqual(tensor_dict[x1.ref()], "x2")
    self.assertEqual(tensor_dict[x2.ref()], "x2")
    self.assertEqual(tensor_dict[y.ref()], "y")
    self.assertEqual(tensor_dict[z.ref()], "z")
    self.assertEqual(tensor_dict[w.ref()], "w")

  def testTensorRefStrong(self):
    # ref() holds a strong reference: the tensor survives `del x`.
    x = constant_op.constant(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  def testVariableRefStrong(self):
    # Same strong-reference guarantee for Variables.
    x = variables.Variable(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x & y

    self.assertAllEqual(z, [0, 1, 1])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x & y

    self.assertAllEqual(z, [False, False, False, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndErrors(self):
    """Mixed-dtype & raises — a different error type per execution mode."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    # Eager raises at kernel dispatch (InvalidArgumentError); graph mode
    # rejects the dtype mismatch earlier with TypeError.
    if context.executing_eagerly():  # :(
      expected_errtype = errors.InvalidArgumentError
    else:
      expected_errtype = TypeError

    with self.assertRaises(expected_errtype):
      _ = x_int & x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool & x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") & constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrNumeric(self):
    x = constant_op.constant([0, 1, 2])
    y = constant_op.constant([1, 1, 1])

    z = x | y

    self.assertAllEqual(z, [1, 1, 3])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x | y

    self.assertAllEqual(z, [False, True, True, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrErrors(self):
    """Mixed-dtype | raises, mirroring testBitwiseAndErrors."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    if context.executing_eagerly():  # :(
      expected_errtype = errors.InvalidArgumentError
    else:
      expected_errtype = TypeError

    with self.assertRaises(expected_errtype):
      _ = x_int | x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool | x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") | constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x ^ y

    self.assertAllEqual(z, [1, 0, 2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x ^ y

    self.assertAllEqual(z, [False, True, True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorErrors(self):
    """Mixed-dtype ^ raises, mirroring testBitwiseAndErrors."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    if context.executing_eagerly():  # :(
      expected_errtype = errors.InvalidArgumentError
    else:
      expected_errtype = TypeError

    with self.assertRaises(expected_errtype):
      _ = x_int ^ x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool ^ x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") ^ constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotNumeric(self):
    # ~x on int32 is bitwise complement; note int32.min maps to int32.max.
    x = constant_op.constant([0, dtypes.int32.min, 1])

    # pylint: disable=invalid-unary-operand-type
    y = ~x

    self.assertAllEqual(y, [-1, dtypes.int32.max, -2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotBool(self):
    x = constant_op.constant([False, True])

    # pylint: disable=invalid-unary-operand-type
    y = ~x

    self.assertAllEqual(y, [True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotErrors(self):
    """~ on a string tensor raises; error type depends on execution mode."""
    if context.executing_eagerly():  # :(
      expected_errtype = errors.InvalidArgumentError
    else:
      expected_errtype = TypeError

    # pylint: disable=invalid-unary-operand-type
    with self.assertRaises(expected_errtype):
      _ = ~constant_op.constant("a")


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for IndexedSlices conversion and arithmetic behavior."""

  def testToTensor(self):
    """Dense conversion requires a dense_shape; with one it densifies."""
    vals = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    idx = constant_op.constant([0, 2])
    no_shape_slices = indexed_slices.IndexedSlices(vals, idx)
    # Without a dense_shape, conversion to a dense tensor must fail.
    with self.assertRaises(ValueError):
      tensor = ops.convert_to_tensor(no_shape_slices, name="tensor")
    self.assertEqual(tensor_shape.TensorShape(None), no_shape_slices.shape)

    shape = constant_op.constant([3, 2])
    shaped_slices = indexed_slices.IndexedSlices(vals, idx, shape)
    dense = ops.convert_to_tensor(shaped_slices, name="tensor")
    self.assertAllEqual(dense.shape, shaped_slices.shape)
    # Row 1 was never written, so it densifies to zeros.
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_gpu_only
  def testEagerCopy(self):
    """Gradients of chained gathers keep the variable's leading dim."""
    with context.eager_mode():
      var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor")
      with backprop.GradientTape() as tape:
        left = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1])
        right = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1])
        dots = special_math_ops.einsum("ij,ij->i", left, right)
      grad = tape.gradient(dots, [var])[0]
      # The gradient may come back sparse (IndexedSlices) or dense.
      grad_values = (
          grad.values
          if isinstance(grad, indexed_slices.IndexedSlices) else grad)
      self.assertAllEqual(grad_values.get_shape(), [4, 1])

  def testNegation(self):
    """Unary minus negates the values and leaves indices untouched."""
    vals = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    idx = constant_op.constant([0, 2])
    negated = -indexed_slices.IndexedSlices(vals, idx)
    self.assertAllEqual(negated.values, [[-2, -3], [-5, -7]])
    self.assertAllEqual(negated.indices, [0, 2])

  def testScalarMul(self):
    """scalar_mul scales the values and leaves indices untouched."""
    vals = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    idx = constant_op.constant([0, 2])
    scaled = math_ops.scalar_mul(
        -2, indexed_slices.IndexedSlices(vals, idx))
    self.assertAllEqual(scaled.values, [[-4, -6], [-10, -14]])
    self.assertAllEqual(scaled.indices, [0, 2])


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for IndexedSlicesSpec construction, (de)serialization, and
  component round-tripping."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts two equal-length sequences are element-wise all-equal."""
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    """Default and fully-specified specs expose the expected fields."""
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertIsNone(spec1._dense_shape_dtype)
    self.assertEqual(spec1._indices_shape.as_list(), [None])

    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, indexed_slices.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here.  But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_lib.TensorSpec(None, dtypes.string),
          tensor_lib.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_lib.TensorSpec(None, dtypes.string),
              tensor_lib.TensorSpec([None], dtypes.int64),
              tensor_lib.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_lib.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_lib.TensorSpec([None], dtypes.int64),
              tensor_lib.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_lib.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_lib.TensorSpec([20], dtypes.int64),
              tensor_lib.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    """_component_specs is (values, indices[, dense_shape]) TensorSpecs."""
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    # NOTE(review): the parameter names `indices` and `values` are swapped
    # relative to IndexedSlices' (values, indices, ...) signature below, but
    # the parameterized dicts bind by keyword and the assertions use the same
    # swapped order, so the test is self-consistent — confusing but correct.
    x = indexed_slices.IndexedSlices(indices, values, dense_shape)
    actual_components = spec._to_components(x)
    if dense_shape is None:
      self.assertAllTensorsEqual(actual_components, [indices, values])
    else:
      self.assertAllTensorsEqual(actual_components,
                                 [indices, values, dense_shape])
    st_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.indices, st_reconstructed.indices)
    self.assertAllEqual(x.values, st_reconstructed.values)
    if dense_shape is None:
      self.assertIsNone(st_reconstructed.dense_shape)
    else:
      self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    """_from_components on numpy arrays yields an IndexedSlicesValue."""
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])

    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)

    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIsNone(st2.dense_shape)


class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the ops._NodeDef convenience constructor."""

  def testNoArgs(self):
    """Only the op type and name are populated when no attrs are given."""
    proto = ops._NodeDef("None", "bar")
    self.assertProtoEquals("op: 'None' name: 'bar'", proto)


def _apply_op(g, *args, **kwargs):
  op = g.create_op(*args, **kwargs)
  if len(op.outputs) == 1:
    return op.outputs[0]
  else:
    return op.outputs


class OperationTest(test_util.TensorFlowTestCase):

  def testTraceback(self):
    """Op construction records a Python traceback naming this test."""
    g = ops.Graph()
    op1 = ops.Operation.from_node_def(
        ops._NodeDef("None", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]
    )
    # The second-to-last frame is the caller of from_node_def, i.e. this
    # test method.
    self.assertIn("testTraceback", op1.traceback[-2])

  @test_util.run_deprecated_v1
  def testNoInputs(self):
    """An op with two outputs and no inputs exposes both output tensors."""
    op = test_ops.float_output_string_output(name="myop").a.op
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    # Output 0 is the float; it is addressable as plain "myop".
    float_t, label_str_t = op.values()
    self.assertEqual(dtypes.float32, float_t.dtype)
    self.assertEqual(op, float_t.op)
    self.assertEqual(0, float_t.value_index)
    self.assertEqual(0, len(float_t.consumers()))
    self.assertEqual("myop", float_t._as_node_def_input())

    # Output 1 is the string; its NodeDef input form carries ":1".
    self.assertEqual(dtypes.string, label_str_t.dtype)
    self.assertEqual(op, label_str_t.op)
    self.assertEqual(1, label_str_t.value_index)
    self.assertEqual(0, len(label_str_t.consumers()))
    self.assertEqual("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                           op.node_def)

  @test_util.run_deprecated_v1
  def testNoOutputs(self):
    """An input-only op records its producer as a consumer relationship."""
    op1 = test_ops.float_output(name="myop1").op
    float_t, = op1.values()
    op2 = test_ops.float_input(float_t, name="myop2")
    self.assertEqual(0, len(op2.values()))
    self.assertEqual(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    # op2 consumes op1's single output.
    self.assertEqual(1, len(float_t.consumers()))
    self.assertEqual(op2, float_t.consumers()[0])

    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                           op2.node_def)

  @test_util.run_deprecated_v1
  def testInputsAndOutputs(self):
    """Consumer lists track repeated consumption of the same output."""
    op1 = test_ops.float_output(name="myop1").op
    self.assertEqual(1, len(op1.values()))
    float1_t, = op1.values()

    op2 = test_ops.float_output_string_output(name="myop2").a.op
    self.assertEqual(2, len(op2.values()))
    float2_t, label2_str_t = op2.values()

    # Note that we consume label2_str_t twice here.
    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
    self.assertEqual(2, len(op3.values()))

    self.assertEqual(1, len(float1_t.consumers()))
    self.assertEqual(op3, float1_t.consumers()[0])

    # op2's float output is never consumed.
    self.assertEqual(0, len(float2_t.consumers()))

    # The twice-consumed string output lists op3 twice.
    self.assertEqual(2, len(label2_str_t.consumers()))
    self.assertEqual(op3, label2_str_t.consumers()[0])
    self.assertEqual(op3, label2_str_t.consumers()[1])

    self.assertProtoEquals("""
    op:'Foo2' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, op3.node_def)

  def testDeviceObject(self):
    """The device can be set from a string or from a DeviceSpec object."""
    # Set from a raw device string.
    string_op = ops.Operation.from_node_def(
        ops._NodeDef("None", "myop"), ops.Graph(), [], []
    )
    string_op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ",
        string_op.node_def)
    # Set from a structured DeviceSpec; it serializes to the same form.
    spec_op = ops.Operation.from_node_def(
        ops._NodeDef("None", "op2"), ops.Graph(), [], []
    )
    spec_op._set_device(
        pydev.DeviceSpec(job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'",
        spec_op.node_def)

  def testReferenceInput(self):
    """Ref-typed inputs survive only when input_types is given explicitly."""
    g = ops.Graph()
    op1 = ops.Operation.from_node_def(
        ops._NodeDef("RefOutputFloatOutput", "op1"),
        g,
        [],
        [dtypes.float32_ref, dtypes.float32],
    )
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    self.assertEqual([], list(op1.inputs))
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation.from_node_def(
        ops._NodeDef("RefInputFloatInput", "op2"),
        g,
        [ref_t, nonref_t],
        [],
        input_types=[dtypes.float32_ref, dtypes.float32],
    )
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
    # Without input_types, the ref output is accepted as a non-ref input.
    op3 = ops.Operation.from_node_def(
        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []
    )
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testInvalidNames(self):
    """Empty or malformed op names are rejected with ValueError."""
    g = ops.Graph()
    for bad_name in ("", "_invalid", "-invalid", "/invalid", "invalid:0"):
      with self.assertRaises(ValueError):
        ops.Operation.from_node_def(ops._NodeDef("op", bad_name), g)

  @test_util.run_deprecated_v1
  def testNoShapeFunction(self):
    """Ops with no registered shape function report an unknown shape."""
    op = test_ops.a()
    self.assertEqual(op.get_shape(), tensor_shape.unknown_shape())

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedArray(self):
    """Nested Python lists convert to a tensor of matching shape/values."""
    nested = [[2], [3], [5], [7]]
    converted = ops.convert_to_tensor(nested)
    self.assertAllEqual(converted.get_shape().as_list(), (4, 1))
    self.assertAllEqual(self.evaluate(converted), nested)

  def testShapeTuple(self):
    """A scalar constant has an empty shape tuple."""
    with self.cached_session():
      scalar = constant_op.constant(1)
      # pylint: disable=protected-access
      self.assertEqual((), scalar._shape_tuple())

  def testConvertToTensorEager(self):
    """convert_to_tensor yields EagerTensors in eager mode."""
    with context.eager_mode():
      t = constant_op.constant(1)
      # assertIsInstance gives a clearer failure message than
      # assertTrue(isinstance(...)).
      self.assertIsInstance(t, ops.EagerTensor)
      converted = ops.convert_to_tensor(t)
      self.assertIsInstance(converted, ops.EagerTensor)
      converted = ops.convert_to_tensor(1)
      self.assertIsInstance(converted, ops.EagerTensor)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTuple(self):
    """Nested tuples convert to a tensor of matching shape and values."""
    values = ((2,), (3,), (5,), (7,))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    # Evaluate the tensor built above rather than converting `values` a
    # second time, for consistency with the sibling conversion tests.
    self.assertAllEqual(values, self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTensors(self):
    """Lists of tensors (rows or individual elements) stack into one tensor."""
    values = ((2,), (3,), (5,), (7,))
    # A list of 1-D row tensors.
    tensor = ops.convert_to_tensor(
        [constant_op.constant(row) for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))
    # A nested list of scalar tensors.
    tensor = ops.convert_to_tensor(
        [[constant_op.constant(v) for v in row] for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedMix(self):
    """A mix of lists, tuples, and tensors converts to one tensor."""
    values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorPreferred(self):
    """preferred_dtype is honored when possible, ignored when incompatible."""
    values = [2, 3, 5, 7]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
    self.assertEqual(dtypes.float32, tensor.dtype)

    # Convert empty tensor to anything.
    values = []
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, tensor.dtype)

    # The preferred dtype (int64) cannot represent the float values, so the
    # conversion silently falls back to float32 instead of raising.
    values = [1.23]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.float32, tensor.dtype)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToInvalidTensorType(self):
    """Forcing float values into an int dtype raises TypeError."""
    # Unlike preferred_dtype, an explicit dtype= must be honored, so an
    # incompatible value fails instead of falling back.
    values = [1.23]
    with self.assertRaises(TypeError):
      ops.convert_to_tensor(values, dtype=dtypes.int64)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToLongLongTensorType(self):
    """A numpy NPY_LONGLONG scalar converts cleanly to an int64 tensor."""
    # np.prod over a shape tuple produces a numpy longlong scalar.
    longlong_scalar = np.prod(constant_op.constant([1])._shape_tuple())
    converted = ops.convert_to_tensor(longlong_scalar, dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, converted.dtype)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorFromValidTensor(self):
    """Converting a tensor with a compatible dtype hands it back unchanged."""
    original = constant_op.constant(413, dtype=dtypes.int64)
    converted = ops.convert_to_tensor(original, dtype=dtypes.int64)
    # If dtype is compatible, the returned tensor should be the same instance.
    self.assertEqual(original, converted)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorFromInvalidTensor(self):
    """Converting an existing tensor to an incompatible dtype fails."""
    float_tensor = constant_op.constant(42.0, dtype=dtypes.float32)
    # convert_to_tensor never silently casts an existing tensor.
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(float_tensor, dtype=dtypes.int32)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorProtocol(self):
    """Objects implementing __tf_tensor__ participate in conversion."""

    class TensorCompatible:

      def __tf_tensor__(self, dtype=None, name=None):
        return constant_op.constant((1, 2, 3), dtype=dtype, name=name)

    converted = ops.convert_to_tensor(TensorCompatible(), dtype=dtypes.int32)
    self.assertEqual(converted.dtype, dtypes.int32)
    self.assertAllEqual((1, 2, 3), self.evaluate(converted))

  @test_util.run_deprecated_v1
  def testNoConvert(self):
    """An Operation object cannot be converted to a Tensor."""
    no_op = gen_control_flow_ops.no_op()
    with self.assertRaisesRegex(TypeError,
                                "can't convert Operation '.+' to Tensor"):
      ops.convert_to_tensor(no_op)

  def testStr(self):
    """str() of an Operation matches str() of its NodeDef."""
    proto = ops._NodeDef("None", "op1")
    operation = ops.Operation.from_node_def(
        proto, ops.Graph(), [], [dtypes.float32]
    )
    self.assertEqual(str(proto), str(operation))

  def testRepr(self):
    """repr() of an Operation shows its name and type."""
    operation = ops.Operation.from_node_def(
        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]
    )
    self.assertEqual("<tf.Operation 'op1' type=None>", repr(operation))

  @test_util.run_deprecated_v1
  def testGetAttr(self):
    """get_attr returns typed Python values for every attr kind."""
    op = test_ops.default_attrs()
    self.assertEqual(op.get_attr("string_val"), b"abc")
    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
    self.assertEqual(op.get_attr("int_val"), 123)
    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
    self.assertEqual(op.get_attr("float_val"), 10.0)
    self.assertEqual(op.get_attr("float_list_val"), [10.0])
    self.assertEqual(op.get_attr("bool_val"), True)
    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
    # Shape attrs are returned as TensorShapeProto messages.
    self.assertEqual(op.get_attr("shape_val"),
                     tensor_shape.as_shape([2, 1]).as_proto())
    self.assertEqual(op.get_attr("shape_list_val"),
                     [tensor_shape.as_shape([]).as_proto(),
                      tensor_shape.as_shape([1]).as_proto()])
    # Tensor attrs are returned as TensorProto messages.
    self.assertEqual(op.get_attr("tensor_val"),
                     tensor_util.make_tensor_proto(1, dtypes.int32))
    self.assertEqual(op.get_attr("tensor_list_val"),
                     [tensor_util.make_tensor_proto(1, dtypes.int32)])

    type_val = op.get_attr("type_val")
    # First check that type_val is a DType, because the assertEqual will work
    # no matter what since DType overrides __eq__
    self.assertIsInstance(type_val, dtypes.DType)
    self.assertEqual(type_val, dtypes.int32)

    type_list_val = op.get_attr("type_list_val")
    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

    # Function-valued attrs come back as NameAttrList protos.
    @function.Defun(dtypes.float32, func_name="MyFunc")
    def func(x):
      return x

    op = test_ops.func_attr(func)
    self.assertEqual(op.get_attr("f"),
                     attr_value_pb2.NameAttrList(name="MyFunc"))

    # Try fetching missing attr
    with self.assertRaisesRegex(
        ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
      op.get_attr("FakeAttr")

  # TODO(b/65162920): remove this test when users who are directly mutating the
  # node_def have been updated to proper usage.
  @test_util.run_deprecated_v1
  def testSetAttr(self):
    """_set_attr makes the new value visible through get_attr."""
    operation = test_ops.int_attr().op
    operation._set_attr("foo", attr_value_pb2.AttrValue(i=2))
    # TODO(skyewm): add node_def check
    self.assertEqual(operation.get_attr("foo"), 2)

  @test_util.run_v2_only
  def testSetFullType(self):
    """experimental_set_type writes the NodeDef's experimental_type field."""

    @def_function.function
    def test_fn():
      variant = dataset_ops.Dataset.range(3)._variant_tensor
      variant.op.experimental_set_type(
          full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_PRODUCT))
      self.assertEqual(variant.op.node_def.experimental_type.type_id,
                       full_type_pb2.TFT_PRODUCT)

    test_fn()

  # TODO(nolivia): test all error cases
  def testAddControlInput(self):
    """Control inputs are deduplicated as they are added."""
    with ops.Graph().as_default():
      op_a = constant_op.constant(1).op
      op_b = constant_op.constant(2).op
      op_c = constant_op.constant(3).op
    # pylint: disable=protected-access
    op_c._add_control_input(op_a)
    self.assertEqual(op_c.control_inputs, [op_a])
    # Adding the same control input again is a no-op.
    op_c._add_control_input(op_a)
    self.assertEqual(op_c.control_inputs, [op_a])
    # Duplicates inside the list, and against existing inputs, are dropped.
    op_c._add_control_inputs([op_a, op_b, op_b])
    self.assertEqual(op_c.control_inputs, [op_a, op_b])
    self.assertEqual(op_a._control_outputs, [op_c])
    # pylint: enable=protected-access

  @test_util.run_deprecated_v1
  def testRemoveAllControlInputs(self):
    """_remove_all_control_inputs clears control deps but not data inputs."""
    a = constant_op.constant(1)
    with ops.control_dependencies([a]):
      b = constant_op.constant(2)
    c = constant_op.constant(3)
    d = constant_op.constant(4)
    e = constant_op.constant(5)
    with ops.control_dependencies([a, c]):
      f = d + e

    # Sanity-check the control edges set up above.
    self.assertEqual(a.op.control_inputs, [])
    self.assertEqual(b.op.control_inputs, [a.op])
    self.assertEqual(f.op.control_inputs, [a.op, c.op])

    # Removing from an op with no control inputs is a no-op.
    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(a.op.control_inputs, [])

    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(b.op.control_inputs, [])

    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(f.op.control_inputs, [])
    # Data inputs are untouched by the removal.
    self.assertEqual(list(f.op.inputs), [d, e])

  @test_util.run_deprecated_v1
  def testControlInputCycle(self):
    """A control-input cycle is reported when the graph is executed."""
    graph = ops.Graph()
    with graph.as_default():
      z = constant_op.constant(0)
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      # pylint: disable=protected-access
      y.op._add_control_input(z.op)
      y.op._add_control_input(x.op)  # x -> y
      x.op._add_control_input(y.op)  # y -> x closes the cycle
      # pylint: enable=protected-access
    with self.session(graph=graph):
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Graph is invalid, contains a cycle with 2 nodes"):
        self.evaluate(x)

  def testUpdateInput(self):
    """_update_input rewires edges and keeps consumer bookkeeping in sync."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = x + y

    # Point both add inputs at y: z = y + y = 4.
    z.op._update_input(0, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [y, y])
    self.assertEqual(x.consumers(), [])
    self.assertEqual(y.consumers(), [z.op, z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 4)

    # Restore the original wiring: z = x + y = 3.
    z.op._update_input(0, x)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 3)

    # Re-assigning an input to its current value changes nothing.
    z.op._update_input(1, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(self.evaluate(z), 3)

  def testUpdateInputGraphError(self):
    """_update_input rejects a tensor that lives in a different graph."""
    graph_a = ops.Graph()
    graph_b = ops.Graph()
    with graph_a.as_default():
      foreign = constant_op.constant(1)
    with graph_b.as_default():
      y = constant_op.constant(2)
      z = y * 2
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        z.op._update_input(0, foreign)  # pylint: disable=protected-access

  def testUpdateInputTypeError(self):
    """A type-incompatible rewired input fails when the graph runs."""
    g = ops.Graph()
    # NOTE: creation order matters — the string constant must be the second
    # Const node so the error message references "Const_1:0".
    with g.as_default():
      int_in = constant_op.constant(0)
      str_in = constant_op.constant("")
      one = constant_op.constant(1)
      total = one + int_in
      total.op._update_input(0, str_in)  # pylint: disable=protected-access
    with session.Session(graph=g):
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Input 0 of node add was passed string from Const_1:0 incompatible "
          "with expected int32"):
        self.evaluate(total)

  def testUpdateInputShapeError(self):
    """A shape-incompatible rewired input is rejected at update time."""
    g = ops.Graph()
    with g.as_default():
      lhs = constant_op.constant(2, shape=[3, 1])
      rhs = constant_op.constant(0, shape=[3, 1])
      bad_shape = constant_op.constant(1, shape=[2, 2])
      total = lhs + rhs
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
      total.op._update_input(0, bad_shape)  # pylint: disable=protected-access

  def testUpdateInputOutOfRange(self):
    """Updating an input index beyond the op's arity raises OutOfRangeError."""
    g = ops.Graph()
    with g.as_default():
      const = constant_op.constant(1)
    # A Const op has zero inputs, so index 1 is out of range.
    with self.assertRaisesRegex(
        errors.OutOfRangeError,
        r"Cannot update edge. Input index \[1\] is greater than the number of "
        r"total inputs \[0\]."):
      const.op._update_input(1, const)  # pylint: disable=protected-access

  @test_util.enable_control_flow_v2
  @test_util.run_v1_only("b/120545219")
  def testAddWhileInput(self):
    """_add_while_inputs appends loop inputs, bounded by the "T" attr."""

    @def_function.function
    def test():
      output = while_loop.while_loop(lambda x: x < 3, lambda x: x + 1, [1])
      while_op = output.op
      self.assertEqual(while_op.type, "StatelessWhile")
      orig_num_inputs = len(while_op.inputs)

      # Make sure we can handle the while op having a control input.
      while_op._add_control_input(constant_op.constant(0).op)

      new_input1 = constant_op.constant(1.0)
      new_input2 = constant_op.constant(True)

      # Clear output shapes to bypass shape checking.
      while_op._set_shape_list_attr("output_shapes", [])
      while_op._set_type_list_attr("T", [t.dtype for t in while_op.inputs] +
                                   [new_input1.dtype, new_input2.dtype])

      while_op._add_while_inputs([new_input1, new_input2])
      # Can't add an edge beyond what's specified by "T"
      with self.assertRaises(errors.OutOfRangeError):
        while_op._add_while_inputs([new_input2])
      self.assertLen(while_op.inputs, orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert

    # Bug fix: this call was previously indented inside `test` itself, which
    # made it unreachable (the function was never invoked, so the assertions
    # above never ran). Invoke the traced function from the test method.
    test()

  @test_util.run_deprecated_v1
  def testOpDef(self):
    """Operations expose their registered OpDef."""
    x = constant_op.constant(0)
    y = constant_op.constant(1)
    total = x + y

    # Const takes no inputs and produces one output.
    self.assertEqual(x.op.op_def.name, "Const")
    self.assertLen(x.op.op_def.input_arg, 0)
    self.assertLen(x.op.op_def.output_arg, 1)

    # Addition may be registered as either "Add" or "AddV2".
    self.assertRegex(total.op.op_def.name, "Add(V2)?")
    self.assertLen(total.op.op_def.input_arg, 2)
    self.assertLen(total.op.op_def.output_arg, 1)

  def testInputFromDifferentGraphError(self):
    """Combining tensors from two different graphs raises ValueError."""
    graph_a = ops.Graph()
    graph_b = ops.Graph()
    with graph_a.as_default():
      foreign = constant_op.constant(1)
    with graph_b.as_default():
      local = constant_op.constant(2)
      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
        local * foreign  # pylint: disable=pointless-statement

  def testInputsAreImmutable(self):
    """Operation.inputs is a tuple, so it cannot be mutated in place."""
    g = ops.Graph()
    with g.as_default():
      source = test_ops.int_output()
      consumer = test_ops.int_input_int_output(source, name="myop").op
    with self.assertRaisesRegex(AttributeError,
                                "'tuple' object has no attribute 'append'"):
      consumer.inputs.append(None)


class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for Graph.create_op."""

  def testNodeDefArgs(self):
    """create_op records name, inputs, and device scope in the NodeDef."""
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    op3 = g.create_op(
        "Foo3",
        [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
        [dtypes.float32, dtypes.int32],
        None,
        name="myop3")
    self.assertDeviceEqual(None, op1.device)
    self.assertDeviceEqual("/device:GPU:0", op2.device)
    self.assertDeviceEqual(None, op3.device)
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    # Non-zero output indices are encoded as "op_name:index" in NodeDef input.
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    """Ref-typed inputs are preserved only when input_types is explicit."""
    g = ops.Graph()
    op1 = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    # With no input_types, passing a ref output to a float input still works.
    op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testFinalized(self):
    """A finalized graph rejects new ops until _unsafe_unfinalize is called."""
    g = ops.Graph()
    g.finalize()
    with self.assertRaises(RuntimeError):
      g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")

    # Test unfinalize.
    g._unsafe_unfinalize()
    g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")


# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
# method. Arguably we should only test the public APIs that depend on this
# method. However, this logic is complex and tricky, and it can be difficult to
# ascertain if we have adequate coverage (e.g. a graph may run successfully if
# the control flow context isn't set properly, but a more complicated use case
# that might not be obvious to test will fail). Thus we instead explicitly test
# the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Tests for Graph._create_op_from_tf_operation and _add_new_tf_operations."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Wrapping a raw C op exposes name, type, inputs, and graph lookups."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    self.assertEqual(x.consumers(), [op])
    # The wrapped op gets a Python traceback pointing at this test.
    self.assertIsNotNone(op.traceback)
    self.assertIn("testBasic", op.traceback[-1])
    # The op is registered in the graph's name and tensor lookup tables.
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """Output shapes are inferred when wrapping a raw C op."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertLen(op.outputs, 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    """Names taken by wrapped C ops participate in name uniquification."""
    g = ops.Graph()
    with g.as_default():
      c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = g._create_op_from_tf_operation(c_op)
      op2 = g._create_op_from_tf_operation(c_op2)

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """Ops added inside a cond branch get the cond's control flow context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        # Create a raw C op, then let the graph pick it up.
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return x

      cond.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The input is routed through the cond's Switch node.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """Ops added inside a while body get the loop's control flow context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertLen(new_ops, 1)
        return i

      while_loop.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input is routed into the loop through an Enter node.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """Control deps on ops inside the loop body are preserved."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      while_loop.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """Control deps on ops outside the loop are rewritten into the loop."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertLen(new_ops, 1)
        return i

      while_loop.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())


class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests op application via the _apply_op helper."""

  def testNodeDefArgs(self):
    """_apply_op returns a Tensor for one output, a list for several."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    self.assertTrue(isinstance(t1, tensor_lib.Tensor))
    self.assertTrue(isinstance(t2, list))
    self.assertTrue(isinstance(t3, list))
    self.assertTrue(isinstance(t3[0], tensor_lib.Tensor))
    # Non-zero output indices render as "op_name:index".
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """Ref-typed inputs are preserved only when input_types is explicit."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    # With no input_types, passing a ref output to a float input still works.
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)


class NameStackTest(test_util.TensorFlowTestCase):

  def testBasics(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz_1/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz_1/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("quux/foo", g.unique_name("foo"))
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEqual(
            "bar_1/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_4", g.unique_name("foo"))
    self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
    self.assertEqual("bar_2", g.unique_name("bar"))

  def testBackslashAndDashRegex(self):
    # GitHub issue 39019, all should pass
    g = ops.Graph()
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      pass
    with g.name_scope("foo"):
      with g.name_scope("n_CatCntc-campaign\\c_campaign"):
        pass
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    with self.cached_session() as sess:
      with sess.graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          with sess.graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            self.assertEqual(
                "l0/l1/l1/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
          with sess.graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            self.assertEqual(
                "l0/l1/l2/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))

  def testOutOfOrderUniqueName(self):
    g = ops.Graph()
    self.assertEqual("foo_2", g.unique_name("foo_2"))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_3", g.unique_name("foo"))

  def testUniqueNameCaseInsensitivity(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      with g.name_scope("_bar"):  # Should not raise
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("foo:0"):
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("_bar"):
        pass

  def testEmptyScopeEdgeCases(self):
    g = ops.Graph()
    self.assertEqual("", g.get_name_scope())
    with g.name_scope("") as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope(None) as scope:
      self.assertEqual("", scope)
      self.assertEqual("", g.get_name_scope())
    with g.name_scope("foo") as scope:
      self.assertEqual("foo/", scope)
      self.assertEqual("foo", g.get_name_scope())
      with g.name_scope("") as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())
      with g.name_scope(None) as scope:
        self.assertEqual("", scope)
        self.assertEqual("", g.get_name_scope())


class NameTest(test_util.TensorFlowTestCase):
  """Tests for op naming via create_op and name scopes."""

  def testGenerateName(self):
    """Ops without an explicit name are named after their type, uniquified."""
    g = ops.Graph()
    op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", op0.name)
    self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)

    op1 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput", op1.name)
    self.assertEqual("FloatOutput:0", op1.outputs[0].name)

    # A second op of the same type gets a "_1" suffix.
    op2 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput_1", op2.name)
    self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)

    op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
    self.assertEqual("my_op", op3.name)
    self.assertEqual("my_op:0", op3.outputs[0].name)

  def testNameScope(self):
    """Name scopes prefix op names; trailing "/" yields verbatim names."""
    g = ops.Graph()

    with g.name_scope("foo") as foo:
      self.assertEqual("foo/", foo)
      with g.name_scope("foo2") as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with g.name_scope(None) as empty1:
        self.assertEqual("", empty1)
        with g.name_scope("foo3") as foo3:
          self.assertEqual("foo3/", foo3)
      with g.name_scope("") as empty2:
        self.assertEqual("", empty2)

    self.assertEqual("FloatOutput",
                     g.create_op("FloatOutput", [], [dtypes.float32]).name)
    with g.name_scope("bar") as scope:
      self.assertEqual("bar/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      self.assertEqual("bar/FloatOutput_1",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from "with .. as", that value is used as-is.
      self.assertEqual(
          "bar", g.create_op(
              "FloatOutput", [], [dtypes.float32], name=scope).name)
    with g.name_scope("baz") as scope:
      with g.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from the enclosing "with .. as", nothing is pushed.
      with g.name_scope(scope):
        self.assertEqual("baz/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
        self.assertEqual(
            "baz", g.create_op(
                "FloatOutput", [], [dtypes.float32], name=scope).name)
        # A name ending in "/" is taken verbatim (minus the slash), unscoped.
        self.assertEqual(
            "trailing",
            g.create_op(
                "FloatOutput", [], [dtypes.float32], name="trailing/").name)
    with g.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
    # A scope name ending in "/" reuses the existing "bar" scope.
    with g.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)


class DeviceTest(test_util.TensorFlowTestCase):
  """Tests for device placement via `Graph.device` / `ops.device` scopes."""

  def testNoDevice(self):
    """Ops created outside any device scope get an empty device."""
    g = ops.Graph()
    op = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertDeviceEqual(None, op.device)
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput" }
    """, gd)

  def testEagerBackingDevice(self):
    """In eager mode both `device` and `backing_device` reflect the scope."""
    with context.eager_mode():
      with ops.device("/device:CPU:0"):
        t = constant_op.constant(1.0)
        self.assertRegex(t.device, "/device:CPU:0")
        self.assertRegex(t.backing_device, "/device:CPU:0")

  def testDevicePartialString(self):
    """A partial device string is recorded verbatim on the op."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testDeviceFull(self):
    """A fully-specified DeviceSpec is rendered as its canonical string."""
    g = ops.Graph()
    with g.device(
        pydev.DeviceSpec(
            job="worker", replica=2, task=0, device_type="CPU",
            device_index=3)):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/task:0/device:CPU:3" }
    """, gd)

  def testNesting(self):
    """An inner device scope applies only to ops created inside it."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingString(self):
    """Nested device scopes given as strings.

    NOTE(review): this body is byte-identical to testNesting above;
    presumably one of the two was meant to exercise a different input
    form -- confirm the intent.
    """
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingOverrideGpuCpu(self):
    """An inner scope with a different device type replaces the outer one."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:2/device:GPU:2"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1"  }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:2/device:GPU:2" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNestingWithMergeDeviceFunction(self):
    """`merge_device` merges new fields into the enclosing device spec."""
    g = ops.Graph()

    with g.device(pydev.merge_device("/device:GPU:0")):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:worker")):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device(pydev.merge_device("/device:CPU:0")):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device(pydev.merge_device("/job:ps")):
            g.create_op("FloatOutput", [], [dtypes.float32])
            # merge_device(None) adds nothing; the merged spec is kept.
            with g.device(pydev.merge_device(None)):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStrings(self):
    """Plain device strings merge field-by-field like merge_device."""
    g = ops.Graph()

    with g.device("/device:GPU:0"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker"):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device("/device:CPU:0"):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device("/job:ps"):
            g.create_op("FloatOutput", [], [dtypes.float32])
            # An empty device string adds nothing; the merged spec is kept.
            with g.device(""):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStringWildcard(self):
    """A wildcard index ("*") merges with a concrete index, keeping it."""
    g = ops.Graph()

    with g.device("/device:GPU:7"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:GPU:*"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    with g.device("/device:CPU:*"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:CPU:5"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/device:CPU:*" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/device:CPU:5" }
    """, gd)

  def testNestingErrorGraph(self):
    """Exiting a graph device scope out of nesting order raises."""
    g = ops.Graph()
    scope = g.device("/device:GPU:8")
    scope.__enter__()
    with g.device("/device:GPU:9"):
      with self.assertRaises(RuntimeError):
        scope.__exit__(None, None, None)

  def testNestingErrorEager(self):
    """Exiting an eager device scope out of nesting order raises."""
    with context.eager_mode():
      scope = ops.device("/device:CPU:0")
      scope.__enter__()
      with ops.device(None):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)

  def testNoneClearsDefault(self):
    """`device(None)` clears the enclosing device string."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNoneIgnoresOuterDeviceFunction(self):
    """`device(None)` also disables an enclosing device *function*."""
    g = ops.Graph()
    with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def _overwritingDeviceFunction(self, unused_op):
    """Device function that forces every op onto "/job:overwrite"."""
    # This device function unconditionally overwrites the device of ops.
    #
    # NOTE(mrry): Writing device functions like this is not
    # recommended. Instead, in most cases you should use
    # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
    # argument to `tf.device()` and the device component will be merged in.
    return "/job:overwrite"

  def testOverwritingBehavior(self):
    """An overwriting device function wins unless cleared via device(None)."""
    g = ops.Graph()
    with g.device(self._overwritingDeviceFunction):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:ps"):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device("/job:ps"):
          g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps" }
    """, gd)


class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Tests that per-thread graph state (device, colocation, control deps,
  name stack) is isolated across threads mutating the same graph."""

  class TestThread(threading.Thread):
    """Thread that mutates graph state, signals the caller, waits, then
    creates an op that is affected by that state."""

    def __init__(self, graph, replica_id):
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event when it mutated the graph.  The caller can
      # wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits for when it should continue.  The caller can set this
      # event.
      self.should_continue = threading.Event()

    def run(self):
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def testDeviceFunctionStack(self):
    """Each thread's device scope applies only to its own ops."""

    class DeviceSettingThread(self.TestThread):

      def run(self):
        with g.device("/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    threads = [DeviceSettingThread(g, i) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testColocateWith(self):
    """Each thread's colocation scope applies only to its own ops."""

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        with g.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
    ]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_0"}}}}
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_1"}}}}
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_2"}}}}
    """, gd)

  def testControlDependencies(self):
    """Each thread's control-dependency scope applies only to its own ops."""

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        with g.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    dependency_ops = []
    for i in range(3):
      dependency_ops.append(
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion(
        """
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             attr { key: "_has_manual_control_dependencies"
                    value { b: true } } }
      node { name: "FloatOutput_0" op: "FloatOutput"
             input: "^ColocateWithMe_0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             input: "^ColocateWithMe_1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             input: "^ColocateWithMe_2" }
    """, gd)

  def testNameStack(self):
    """Each thread entering name_scope("foo") gets a uniquified scope."""

    class NameSettingThread(self.TestThread):

      def run(self):
        with g.name_scope("foo"):
          op1 = g.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = g.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    g = ops.Graph()
    threads = [NameSettingThread(g, i) for i in range(3)]
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()

    for t in threads:
      t.should_continue.set()
      t.join()

    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)


class ObjectWithName(object):
  """Minimal value object exposing a read-only `name` property."""

  def __init__(self, name):
    # Kept private; callers read it back via the `name` property only.
    self._given_name = name

  @property
  def name(self):
    """Returns the name supplied at construction."""
    return self._given_name


class CollectionTest(test_util.TensorFlowTestCase):
  """Tests for graph collections: add_to_collection(s)/get_collection(_ref)."""

  def test_get_collections(self):
    """`collections` lists only keys that have at least one item."""
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    """Items accumulate in order; a scope arg filters by `name` regex."""
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    # Only blank1 matches the scope filters used below ("prefix", ".*x").
    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    # assertIs/assertIsNot report both operands on failure, unlike
    # assertTrue(x is y) which only prints "False is not true".
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    """Duplicate names in the collections list are added to only once."""
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    """A list of collection names is accepted."""
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    """A tuple of collection names is accepted."""
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    """A generator of collection names is accepted."""
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    """A set of collection names is accepted."""
    g = ops.Graph()
    g.add_to_collections(set(["abc", "123"]), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    """A bare string is treated as a single collection name, not iterated."""
    g = ops.Graph()
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))

  def test_default_graph(self):
    """Module-level add_to_collection targets the default graph, in order."""
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))


# Register the test op "FloatOutput" as having no gradient, so gradient
# lookups on it succeed without a gradient function.
ops.NotDifferentiable("FloatOutput")


@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Pass-through gradient for CopyOp: forwards the incoming gradient."""
  del op  # Unused.
  return x_grad


@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Pass-through gradient registered under the override name."""
  del op  # Unused.
  return x_grad


class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests for gradient-function registration and per-graph overrides."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    """The gradient registered for an op type is found by lookup."""
    out = test_ops.copy_op(test_ops.float_output())
    self.assertEqual(_CopyGrad, ops.get_gradient_function(out.op))

  def testOverrideGradients(self):
    """`gradient_override_map` substitutes the override registration."""
    graph = ops.Graph()
    with graph.as_default():
      inp = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        out = test_ops.copy_op(inp)
      self.assertEqual(_CopyOverrideGrad, ops.get_gradient_function(out.op))

  def testNonExistentOverride(self):
    """Mapping to an unregistered name raises LookupError at lookup time."""
    graph = ops.Graph()
    with graph.as_default():
      inp = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        out = test_ops.copy_op(inp)
      with self.assertRaisesRegex(LookupError, "unknown_override"):
        ops.get_gradient_function(out.op)


class ComparisonTest(test_util.TensorFlowTestCase):
  """Tests that Tensors support membership tests in Python containers."""

  def testMembershipAllowed(self):
    """Tensors can appear in `in` / `not in` list-membership tests."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    # assertIsInstance / assertIn give informative failure messages, unlike
    # assertTrue(isinstance(...)) which only prints "False is not true".
    self.assertIsInstance(t1, tensor_lib.Tensor)
    self.assertIsInstance(t2, tensor_lib.Tensor)
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])


class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.control_dependencies` scoping, nesting and pruning."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Deps attach to new ops, but not to ops already dominated by them."""
    g = ops.Graph()
    with g.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with g.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

    self.assertEqual(c.op.control_inputs, [a.op])
    self.assertEqual(d.op.control_inputs, [a.op])
    # e should be dominated by c.
    self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes
  def testEager(self):
    """In eager mode a callable dependency is invoked exactly once."""
    def future():
      future.calls += 1
      return constant_op.constant(2.0)
    future.calls = 0

    if context.executing_eagerly():
      a = constant_op.constant(1.0)
      b = future
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
      self.assertEqual(future.calls, 1)
    else:
      g = ops.Graph()
      with g.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with g.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
      self.assertEqual(c.op.control_inputs, [a.op, b.op])
      self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    """Objects exposing `_as_graph_element` are converted to their op."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    with g.control_dependencies([ConvertibleObj()]):
      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    """Nesting one scope per dep equals a single scope with all deps."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1, a_2, a_3, a_4]):
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies([a_3]):
          with g.control_dependencies([a_4]):
            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          b_1.op.control_inputs)
    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)

  def testClear(self):
    """`control_dependencies(None)` clears deps; exiting restores them."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to None
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are None again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    """Deps already implied by data inputs are pruned at each nesting level."""
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are constants defined at the outermost scope, and are used
    #   as control inputs for the ith nested scope.
    # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
    # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
    # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
    # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    """Deps on two outputs of the same op yield a single control input."""
    g = ops.Graph()
    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    a_0, a_1 = a.outputs
    with g.control_dependencies([a_0]):
      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_1]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [a])
    self.assertEqual(c.op.control_inputs, [a])

  def testNoControlDependencyWithDataDependency(self):
    """A control dep that is also a data input is pruned."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      b = _apply_op(g, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])

  def testMonitoringAttributeAddedWhenUsingManualControlDep(self):
    """The dependency (not the dependent) op gets the monitoring attr."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      c = _apply_op(g, "Identity", [b], [dtypes.float32])

    with g.control_dependencies([b]):
      d = _apply_op(g, "Identity", [b], [dtypes.float32])

    # Validate that the monitoring attribute is set to track usage of the
    # `control_dependencies(...)` API.
    self.assertEqual(c.op.control_inputs, [a.op])
    with self.assertRaises(ValueError):
      c.op.get_attr("_has_manual_control_dependencies")
    self.assertEqual(a.op.get_attr("_has_manual_control_dependencies"), True)

    # Validate that the monitoring attribute is set to track usage of the
    # `control_dependencies(...)` API even when the manual control deps actually
    # happened to be pruned at runtime.
    self.assertEqual(d.op.control_inputs, [])
    with self.assertRaises(ValueError):
      d.op.get_attr("_has_manual_control_dependencies")
    self.assertEqual(b.op.get_attr("_has_manual_control_dependencies"), True)


class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.name_scope` (v1) and `ops.name_scope_v2` semantics."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """Checks scope strings for nested, empty, absolute and slashed names."""
    with ops.name_scope("foo", skip_on_eager=False) as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2", skip_on_eager=False) as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with ops.name_scope(None, skip_on_eager=False) as empty1:
        # A None name resets to the root scope.
        self.assertEqual("", empty1)
        with ops.name_scope("foo3", skip_on_eager=False) as foo3:
          self.assertEqual("foo3/", foo3)
      with ops.name_scope("", skip_on_eager=False) as empty2:
        self.assertEqual("", empty2)
    with ops.name_scope("foo/", skip_on_eager=False) as outer_foo:
      # A name ending in "/" re-enters that scope verbatim, without
      # uniquification or prefixing.
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("", skip_on_eager=False) as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4", skip_on_eager=False) as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//", skip_on_eager=False) as foo5:
        # Trailing "/" applies even with doubled slashes: the enclosing
        # "foo/" prefix is not prepended.
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6", skip_on_eager=False) as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/", skip_on_eager=False) as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//", skip_on_eager=False) as foo8:
        self.assertEqual("//", foo8)
      with ops.name_scope("a//b/c", skip_on_eager=False) as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c", skip_on_eager=False) as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    """The `default_name` argument is used when `name` is None."""
    with ops.name_scope(None, "default", skip_on_eager=False) as scope:
      self.assertEqual(scope, "default/")
      with ops.name_scope(None, "default2", skip_on_eager=False) as scope2:
        self.assertEqual(scope2, "default/default2/")

  @test_util.run_in_graph_and_eager_modes
  def testNameScopeV2IsReEntrant(self):
    """A name_scope_v2 object can be re-entered and nested with others."""
    foo = ops.name_scope_v2("foo")
    bar = ops.name_scope_v2("bar")
    with foo as scope_name:
      self.assertEqual("foo/", scope_name)
      with foo as scope_name:
        self.assertEqual("foo/foo/", scope_name)
      with bar as scope_name:
        self.assertEqual("foo/bar/", scope_name)
        with foo as scope_name:
          self.assertEqual("foo/bar/foo/", scope_name)
    with bar as scope_name:
      self.assertEqual("bar/", scope_name)

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    """Omitting both name and default_name (with values) raises ValueError."""
    g0 = ops.Graph()
    values = [
        g0.create_op("A", [], [dtypes.float32]),
        g0.create_op("B", [], [dtypes.float32])
    ]
    with self.assertRaises(ValueError):
      with ops.name_scope(None, values=values):
        pass
    with self.assertRaises(ValueError):
      with ops.name_scope(None, None, values):
        pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    """An empty name yields the root scope but still installs the graph."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope("", "my_default_scope", [a, b]) as scope:
      # An explicit empty name wins over default_name.
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    """default_name is used only when name is None; wrong arg types raise."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % default_scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with self.assertRaises(TypeError):
      # Passing the values list where default_name is expected.
      with ops.name_scope(scope_name, [a, b]):
        pass

  def _testGraphElements(self, graph_elements):
    """Asserts name_scope adopts the elements' graph and rejects mixing.

    Args:
      graph_elements: graph elements (tensors, variables, ...) that all
        belong to the same graph.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    g1 = ops.Graph()
    a = g1.create_op("A", [], [dtypes.float32])
    with self.assertRaises(ValueError):
      # Values from two different graphs must be rejected.
      with ops.name_scope(scope_name, values=graph_elements + [a]):
        pass

  @test_util.run_in_graph_and_eager_modes
  def testGetCurrentNameScope(self):
    """get_current_name_scope tracks entry and exit of v2 scopes."""
    self.assertEqual(ops.get_current_name_scope(), "")
    with ops.name_scope_v2("aaa"):
      self.assertEqual(ops.get_current_name_scope(), "aaa")
      with ops.name_scope_v2("bbb"):
        self.assertEqual(ops.get_current_name_scope(), "aaa/bbb")
      self.assertEqual(ops.get_current_name_scope(), "aaa")
    self.assertEqual(ops.get_current_name_scope(), "")

  @test_util.run_deprecated_v1
  def testTensor(self):
    """Plain tensors are accepted as name_scope values."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    """SparseTensors are accepted as name_scope values."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    sparse = sparse_tensor.SparseTensor(
        _apply_op(g0, "Int64Output", [], [dtypes.int64]),
        _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
        _apply_op(g0, "Int64Output", [], [dtypes.int64]))
    self._testGraphElements([a, sparse, b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    """Variables are accepted as name_scope values."""
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, variable, b])


class InitScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.init_scope`, which lifts state out of function graphs."""

  def testClearsControlDependencies(self):
    """Entering init_scope clears control deps; exiting restores them."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps back to None
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are None again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testLiftsOpsFromFunctions(self):
    """Ops created in init_scope land in the outermost non-function graph."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with g2.as_default():
          with ops.init_scope():
            _ = constant_op.constant(1.0)

    self.assertLen(g2.get_operations(), 0)
    self.assertLen(g1.get_operations(), 0)
    self.assertLen(g0.get_operations(), 1)

  def testPreservesDevices(self):
    """init_scope preserves the device scope active in the exited graph."""
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          # No device was set in g1, so g0's CPU:0 scope applies.
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")

  def testComposes(self):
    """Nested init_scopes lift into the outermost non-function graph."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertLen(g2.get_operations(), 0)
          self.assertLen(g1.get_operations(), 0)
          self.assertLen(g0.get_operations(), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
            with g3.as_default():
              with ops.init_scope():
                # This op should be lifted into g3, because g3 is not building a
                # function.
                _ = constant_op.constant(1.0)
                self.assertIs(g3, ops.get_default_graph())

    self.assertLen(g3.get_operations(), 1)
    self.assertLen(g2.get_operations(), 0)
    self.assertLen(g1.get_operations(), 0)
    self.assertLen(g0.get_operations(), 2)

  def testEscapesToEagerContext(self):
    """From a function graph entered under eager, init_scope goes eager."""
    g = ops.Graph()
    g._building_function = True  # pylint: disable=protected-access
    with context.eager_mode():
      with context.graph_mode():
        with g.as_default():
          with ops.init_scope():
            # Because g is building a function, init_scope should
            # escape out to the eager context.
            self.assertTrue(context.executing_eagerly())
          # g should be reinstated as the default graph, and the
          # graph context should be re-entered.
          self.assertIs(g, ops.get_default_graph())
          self.assertFalse(context.executing_eagerly())

  def testStaysInEagerWhenOnlyEagerContextActive(self):
    """init_scope is a no-op when only the eager context is active."""
    with context.eager_mode():
      with ops.init_scope():
        # Bug fix: the original asserted on `context.eager_mode()`, which
        # returns an (always-truthy) context manager and made the check
        # vacuous; assert on executing_eagerly() as in the test above.
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())

  @test_util.run_v1_only("b/120545219")
  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
    """With only function graphs on the stack, lifting targets the global graph."""
    with context.graph_mode():
      ops.reset_default_graph()
      # This doesn't push anything onto the graph stack, but it does
      # set the stack's global graph.
      global_graph = ops.get_default_graph()
      fn_graph = ops.Graph()

      # pylint: disable=protected-access
      fn_graph._building_function = True
      self.assertLen(ops._default_graph_stack.stack, 0)
      with fn_graph.as_default():
        self.assertLen(ops._default_graph_stack.stack, 1)
        with ops.init_scope():
          self.assertGreater(len(ops._default_graph_stack.stack), 1)
          dummy = constant_op.constant(1.0)
        self.assertLen(ops._default_graph_stack.stack, 1)
      # Note that the global graph is _not_ on the graph stack.
      self.assertLen(ops._default_graph_stack.stack, 0)
      # Ensure that `dummy` was added to the global graph.
      self.assertEqual(global_graph, dummy.graph)
      # pylint: enable=protected-access

  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
    """In graph mode with an empty stack, init_scope pushes a default graph."""
    with context.graph_mode():
      # pylint: disable=protected-access
      self.assertLen(ops._default_graph_stack.stack, 0)
      with ops.init_scope():
        self.assertGreater(len(ops._default_graph_stack.stack), 0)
      self.assertLen(ops._default_graph_stack.stack, 0)
      # pylint: enable=protected-access

  def testPreservesNameScopeInGraphConstruction(self):
    """The active name scope survives across init_scope in graph mode."""
    with ops.Graph().as_default():
      function_graph = ops.Graph()
      with function_graph.as_default():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          self.assertEqual(ops.get_name_scope(), "inner")
      self.assertEqual(ops.get_name_scope(), "")

  def testEnteringGraphFromEagerIsSticky(self):
    """A non-function graph entered from eager stays default in init_scope."""
    with context.eager_mode():
      g = ops.Graph()
      with g.as_default():
        with ops.init_scope():
          self.assertFalse(context.executing_eagerly())
          self.assertEqual(g, ops.get_default_graph())

  def testMixGraphEager(self):
    """EagerTensors cannot be captured by graph ops, and vice versa."""
    with context.eager_mode():
      c = constant_op.constant(1.0)
      with ops.Graph().as_default():
        with self.assertRaisesRegex(RuntimeError,
                                    "Attempting to capture an EagerTensor"):
          math_ops.add(c, c)
        c2 = constant_op.constant(2.0)
      with self.assertRaises(TypeError):
        # c2 is a graph tensor; eager math_ops.add rejects it.
        math_ops.add(c2, c2)

  def testPreservesNameScopeInEagerExecution(self):
    """The active name scope survives init_scope under eager and tf.function."""
    with context.eager_mode():
      def foo():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          if context.executing_eagerly():
            # A trailing slash is always appended when eager execution is
            # enabled.
            self.assertEqual(context.context().scope_name, "inner/")
          else:
            self.assertEqual(ops.get_name_scope(), "inner")

      foo()
      self.assertEqual(ops.get_name_scope(), "")
      foo_compiled = def_function.function(foo)
      foo_compiled()
      self.assertEqual(ops.get_name_scope(), "")

  def testExecutingEagerlyOutsideFunctions(self):
    """executing_eagerly_outside_functions reflects the outer context."""

    @def_function.function
    def f():
      return ops.executing_eagerly_outside_functions()

    with context.graph_mode():
      self.assertFalse(ops.executing_eagerly_outside_functions())
      with session.Session():
        # Need self.evaluate for these as the return type of functions is
        # tensors.
        self.assertFalse(self.evaluate(f()))

    with context.eager_mode():
      self.assertTrue(ops.executing_eagerly_outside_functions())
      self.assertTrue(f())

      with ops.Graph().as_default():
        self.assertFalse(ops.executing_eagerly_outside_functions())
        with session.Session():
          self.assertFalse(self.evaluate(f()))


class GraphTest(test_util.TensorFlowTestCase):
  """Tests for basic ops.Graph lifecycle and default-graph management."""

  def setUp(self):
    # Each test starts from a fresh global default graph.
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    """Asserts that `expected` is the current default graph (by identity)."""
    self.assertIs(expected, ops.get_default_graph())

  def testResetDefaultGraphNesting(self):
    """Resetting the default graph inside a nested graph scope is an error."""
    g0 = ops.Graph()
    with self.assertRaises(AssertionError):
      with g0.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    """Entering a graph scope disables eager execution."""
    with context.eager_mode():
      with ops.Graph().as_default():
        self.assertFalse(context.executing_eagerly())

  def testGraphContextManager(self):
    """as_default() yields the graph itself."""
    g0 = ops.Graph()
    with g0.as_default() as g1:
      self.assertIs(g0, g1)

  def testDefaultGraph(self):
    """has_default_graph/get_default_graph track the context-manager stack."""
    orig = ops.get_default_graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    g0 = ops.Graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    # Creating the context manager alone must not install the graph.
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)
    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    """prevent_feeding marks a tensor as unfeedable."""
    g = ops.Graph()
    # NOTE(review): `a` is created in the current default graph, not in `g`;
    # prevent_feeding/is_feedable only track membership, so the API is still
    # exercised. Confirm whether creating `a` under `g` was intended.
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    """prevent_fetching marks an op's outputs as unfetchable."""
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_fetchable(a))
    g.prevent_fetching(a.op)
    self.assertFalse(g.is_fetchable(a))

  def testAsGraphElementConversions(self):
    """Objects with _as_graph_element convert; others raise TypeError."""

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
    with self.assertRaises(TypeError):
      g.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  # cycles that require calling a __del__ method, because the __del__ method can
  # theoretically increase the object's refcount to "save" it from gc, and any
  # already-deleted objects in the cycle would have be to restored.)
  def testGarbageCollected(self):
    # Create a graph we can delete and a weak reference to monitor if it's gc'd
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Create some ops
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Create a session we can delete
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Delete all references and trigger gc
    del g
    del a
    del b
    del c
    del sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    """A failed op creation must not corrupt the graph for later runs."""
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    """A failure under _kernel_label_map must not corrupt the graph."""
    g = ops.Graph()
    with g.as_default():
      with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)


class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests for Graph._attr_scope, which injects default attrs onto new ops."""

  def _get_test_attrs(self):
    """Creates a NoOp and returns its ("_A", "_B") attr values (None if absent)."""
    x = gen_control_flow_ops.no_op()
    try:
      a = compat.as_text(x.get_attr("_A"))
    except ValueError:
      # get_attr raises ValueError when the attr is not present.
      a = None
    try:
      b = compat.as_text(x.get_attr("_B"))
    except ValueError:
      b = None
    return (a, b)

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    """Without an attr scope, neither attr is set."""
    with self.cached_session():
      self.assertAllEqual((None, None), self._get_test_attrs())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    """Nested attr scopes override, clear (via None), and restore attrs."""
    with self.cached_session() as sess:
      a1 = self._get_test_attrs()
      with sess.graph._attr_scope({
          "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
      }):
        a2 = self._get_test_attrs()
        with sess.graph._attr_scope({
            "_A": None,
            "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
        }):
          a3 = self._get_test_attrs()
          with sess.graph._attr_scope({
              "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
          }):
            a4 = self._get_test_attrs()
          a5 = self._get_test_attrs()
        a6 = self._get_test_attrs()
      a7 = self._get_test_attrs()

      self.assertAllEqual((None, None), a1)
      self.assertAllEqual(("foo", None), a2)
      self.assertAllEqual((None, "bar"), a3)
      self.assertAllEqual(("baz", "bar"), a4)
      self.assertAllEqual((None, "bar"), a5)
      self.assertAllEqual(("foo", None), a6)
      self.assertAllEqual((None, None), a7)


class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests for Graph._kernel_label_map kernel-variant selection."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    """Without a label map the default kernel runs."""
    with self.cached_session():
      self.assertAllEqual(b"My label is: default",
                          test_ops.kernel_label().eval())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    """Nested label maps override and restore the selected kernel."""
    with self.cached_session() as sess:
      default_1 = test_ops.kernel_label()
      # pylint: disable=protected-access
      with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
        overload_1_1 = test_ops.kernel_label()
        with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
          overload_2 = test_ops.kernel_label()
          # An empty label restores the default kernel.
          with sess.graph._kernel_label_map({"KernelLabel": ""}):
            default_2 = test_ops.kernel_label()
        overload_1_2 = test_ops.kernel_label()
      # pylint: enable=protected-access
      default_3 = test_ops.kernel_label()

      self.assertAllEqual(b"My label is: default", self.evaluate(default_1))
      self.assertAllEqual(b"My label is: default", self.evaluate(default_2))
      self.assertAllEqual(b"My label is: default", self.evaluate(default_3))
      self.assertAllEqual(b"My label is: overload_1",
                          self.evaluate(overload_1_1))
      self.assertAllEqual(b"My label is: overload_1",
                          self.evaluate(overload_1_2))
      self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2))


class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for Graph.as_graph_def serialization options."""

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as g:
      version = g.graph_def_versions.producer
      with self.session(graph=g):
        v = test_ops.graph_def_version().eval()
        self.assertEqual(version, v)

  def testAddShapes(self):
    """add_shapes=True attaches an _output_shapes attr to every node."""
    with ops.Graph().as_default() as g:
      t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [],
                                     [dtypes.float32] * 5)
      # Cover all shape cases: unknown rank, scalar, unknown dim, fully
      # known, and partially known.
      t1.set_shape(None)
      t2.set_shape([])
      t3.set_shape([None])
      t4.set_shape([43, 37])
      t5.set_shape([43, None])

      b = constant_op.constant(1.0)  # pylint: disable=unused-variable

      gd = g.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
    node { name: "Const" op: "Const"
      attr {
        key: "_output_shapes"
        value {
          list {
            shape { }
          }
        }
      }
      attr {
        key: "dtype"
        value { type: DT_FLOAT }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape { }
         float_val: 1.0  } } } }
      """, gd)


@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Registers a fixed "flops" statistic of 20 for op type "a" (test-only)."""
  return ops.OpStats("flops", 20)


class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for ops.RegisterStatistics / ops.get_stats_for_node_def."""

  def testRegisteredNode(self):
    """A registered statistic is returned; unknown stat types are empty."""
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")
    # `_calc_a_forward_flops` above registers flops=20 for op type "a".
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # assertIsNone is the idiomatic form of assertEqual(None, ...).
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    """Op types with no registered statistics yield an empty OpStats."""
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    """`+=` on OpStats accumulates; an empty total adopts the addend's value."""
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)


class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests for the per-op device-assignment metadata (_device_assignments)."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):
    """Each ops.device() frame records its device spec and call site."""

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # No device scope -> no assignment metadata recorded.
    self.assertEmpty(const_zero.op._device_assignments)

    one_list = const_one.op._device_assignments
    self.assertLen(one_list, 1)
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    two_list = const_two.op._device_assignments
    self.assertLen(two_list, 2)
    devices = [t.obj for t in two_list]
    self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

    three_list = const_three.op._device_assignments
    self.assertLen(three_list, 1)
    # A device *function* is recorded as a "<name><file, line>" description.
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
    """ops.device and Graph.device record identical call-site metadata."""

    # NOTE: the two `with` statements below must stay exactly two source
    # lines apart; the final lineno assertion depends on that distance.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]

    # Verify both types of device assignment return the right stack info.
    # Bug fix: the original used assertRegex with its (text, regex) arguments
    # swapped, which only passed by coincidence; check basename equality, as
    # the test above does.
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)


class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests for ops.colocate_with and colocation group propagation."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Colocated ops share the anchor's group; others get no _class attr."""
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
    c = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], a.op.colocation_groups())
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    with self.assertRaises(ValueError):
      c.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    """_colocation_dict records the anchor name and the call site."""
    const_two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(const_two.op):
      const_three = constant_op.constant(3.0, name="three")
    locations_dict = const_three.op._colocation_dict
    self.assertIn("two", locations_dict)
    metadata = locations_dict["two"]
    self.assertIsNone(metadata.obj)
    # Check that this test's filename is recorded as the file containing the
    # colocation statement.
    self.assertEqual("ops_test.py", os.path.basename(metadata.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    """Colocation overrides the surrounding device scope."""
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(a.op):
        # 'b' is created in the scope of /cpu:0, but it is
        # colocated with 'a', which is on '/device:GPU:0'.  colocate_with
        # overrides devices because it is a stronger constraint.
        b = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual(a.op.device, b.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    """Device names are canonicalized before colocation comparison."""
    with ops.device("/device:GPU:0"):
      _ = constant_op.constant(2.0)
    with ops.device(lambda op: "/device:GPU:0"):
      b = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(b):
      with ops.device("/device:GPU:0"):
        c = constant_op.constant(4.0)

    # A's device will be /device:GPU:0
    # B's device will be /device:GPU:0
    # C's device will be /device:GPU:0 because it
    # inherits B's device name, after canonicalizing the names.
    self.assertEqual(b.op.device, c.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    """Colocation inside a device scope wins; device stack is restored after."""
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
        # Note that this colocation is "redundant", since we are
        # within the scope of "/device:GPU:0".  However, we would like to
        # preserve in the GraphDef that these two ops should be
        # colocated in a portable way.
        with ops.colocate_with(a.op):
          b = constant_op.constant(3.0)
        c = constant_op.constant(4.0)
      d = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual("/device:GPU:0", a.op.device)
    self.assertEqual(a.op.device, b.op.device)

    # Test that device function stack is restored.
    self.assertEqual("/device:GPU:0", c.op.device)
    self.assertEqual("/device:CPU:0", d.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    """Colocating with an already-colocated op reuses the root group."""
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0)
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@a"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    """Nested colocate_with scopes accumulate multiple groups."""
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      with ops.colocate_with(b.op):
        c = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    """ignore_existing=True replaces the accumulated colocation stack."""
    a = constant_op.constant([2.0], name="a")
    b = constant_op.constant(3.0, name="b")
    with ops.colocate_with(a.op):
      with ops.colocate_with(b.op, ignore_existing=True):
        c = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    """colocate_with(None, ignore_existing=True) clears colocation."""
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0, name="b")
      with ops.colocate_with(None, ignore_existing=True):
        c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    # With no active colocation, `c` forms its own group.
    self.assertEqual([b"loc:@c"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    """After a reset, a nested colocate_with starts a fresh group."""
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      with ops.colocate_with(None, ignore_existing=True):
        b = constant_op.constant(3.0, name="b")
        with ops.colocate_with(b.op):
          c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@b"], b.op.colocation_groups())
    self.assertEqual([b"loc:@b"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    """Variables can serve as colocation anchors for other variables."""
    a = variables.Variable([2.0], name="a")
    with ops.colocate_with(a.op):
      b = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    """Inside tf.function, colocating with a resource var adopts its device."""
    with ops.device("/device:CPU:0"):
      a = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      with ops.colocate_with(a):
        b = array_ops.ones([], name="output")
        # Assertion runs during tracing of the function body.
        self.assertEqual("/device:CPU:0", b.op.device)
    f()

  def testColocateWithVariableInFunction(self):
    """A graph colocated with a variable stays importable via graph_def."""
    v = variables.Variable(1.)

    @def_function.function
    def f():
      with ops.colocate_with(v):
        return array_ops.ones([], name="output")

    f()
    graph_def = f.get_concrete_function().graph.as_graph_def()
    wrap_function.function_from_graph_def(graph_def, [], ["output"])


class DeadlineTest(test_util.TensorFlowTestCase):
  """Session-run deadline handling for ops that inspect or exceed it."""

  def testNoDeadlineSet(self):
    # get_deadline needs a deadline in RunOptions; none is provided here,
    # so the run must fail with InvalidArgument.
    with ops.Graph().as_default() as g:
      deadline_op = test_ops.get_deadline()
      with self.session(graph=g) as sess:
        options = config_pb2.RunOptions()
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(deadline_op, options=options)

  def testDeadlineSetTimesOut(self):
    # A 10-second sleep against a 3-second timeout must raise
    # DeadlineExceeded.
    with ops.Graph().as_default() as g:
      sleeper = test_ops.sleep_op(10)
      with self.session(graph=g) as sess:
        options = config_pb2.RunOptions(timeout_in_ms=3000)
        with self.assertRaises(errors.DeadlineExceededError):
          sess.run(sleeper, options=options)


class DeprecatedTest(test_util.TensorFlowTestCase):
  """Behavior of an op that was removed at GraphDef version 8."""

  def testSuccess(self):
    # Pinning the producer version below the removal version lets Old run.
    with ops.Graph().as_default() as g:
      test_util.set_producer_version(g, 7)
      old = test_ops.old()
      with self.session(graph=g):
        old.run()

  def _error(self):
    # Expected message when building Old at the current GraphDef version.
    return ((r"Op Old is not available in GraphDef version %d\. "
             r"It has been removed in version 8\. For reasons\.") %
            versions.GRAPH_DEF_VERSION)

  def testGraphConstructionFail(self):
    # At the default (current) producer version, constructing Old fails.
    with ops.Graph().as_default():
      with self.assertRaisesRegex(NotImplementedError, self._error()):
        test_ops.old()


class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests for the name-scope helpers on ops and Graph."""

  def testStripAndPrependScope(self):
    # Each case: (input, after stripping "hidden1", after prepending
    # "hidden2" to the stripped form).
    cases = [
        ("hidden1/hidden1/weights",  # Same prefix. Should strip.
         "hidden1/weights", "hidden2/hidden1/weights"),
        ("hidden1///hidden1/weights",  # Extra "/". Should strip.
         "hidden1/weights", "hidden2/hidden1/weights"),
        ("^hidden1/hidden1/weights",  # Same prefix. Should strip.
         "^hidden1/weights", "^hidden2/hidden1/weights"),
        ("loc:@hidden1/hidden1/weights",  # Same prefix. Should strip.
         "loc:@hidden1/weights", "loc:@hidden2/hidden1/weights"),
        ("hhidden1/hidden1/weights",  # Different prefix. Should keep.
         "hhidden1/hidden1/weights", "hidden2/hhidden1/hidden1/weights"),
        ("hidden1",  # Not a prefix. Should keep.
         "hidden1", "hidden2/hidden1"),
    ]
    for original, expected_stripped, expected_prepended in cases:
      stripped = ops.strip_name_scope(original, "hidden1")
      self.assertEqual(expected_stripped, stripped)
      self.assertEqual(expected_prepended,
                       ops.prepend_name_scope(stripped, "hidden2"))

  def testGetNameScope(self):
    # get_name_scope reflects the scopes as they open and close.
    with ops.Graph().as_default() as g:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
          self.assertEqual("scope1/scope2", g.get_name_scope())
        self.assertEqual("scope1", g.get_name_scope())
      self.assertEqual("", g.get_name_scope())

  def testTwoGraphs(self):
    # An invalid scope name is rejected even with nested default graphs.

    def build():
      outer = ops.Graph()
      inner = ops.Graph()
      with outer.as_default():
        with inner.as_default():
          with ops.name_scope("_"):
            pass

    self.assertRaisesRegex(ValueError,
                           "'_' is not a valid (?:root )?scope name", build)


class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Argument validation in ops.enable_eager_execution."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    # config must be a ConfigProto, not a device-placement constant.
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
    # A ConfigProto is not a valid device policy...
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, proto)
    # ...nor a valid execution mode.
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, execution_mode=proto)


class _TupleTensor(composite_tensor.CompositeTensor):
  """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading."""

  def __init__(self, components):
    super(_TupleTensor, self).__init__()
    # Convert every component up front so the tuple holds real Tensors.
    self._components = tuple(map(ops.convert_to_tensor, components))

  @property
  def _type_spec(self):
    return _TupleTensorSpec(type_spec.from_value(c) for c in self._components)

  def __iter__(self):
    return iter(self._components)

  def __len__(self):
    return len(self._components)

  def __getitem__(self, key):
    return self._components[key]


class _TupleTensorSpec(type_spec.TypeSpec):
  """TypeSpec for `_TupleTensor`: one component spec per tuple element."""

  def __init__(self, specs):
    # Materialize eagerly: `_TupleTensor._type_spec` passes a generator
    # expression, which would otherwise be exhausted after the first
    # traversal of `_component_specs` (and serialized as a generator).
    self._specs = tuple(specs)

  value_type = property(lambda self: _TupleTensor)
  _component_specs = property(lambda self: self._specs)

  def _to_components(self, value):
    return value._components

  def _from_components(self, components):
    # `_TupleTensor.__init__` takes a single iterable of components, so the
    # sequence must be passed whole, not star-unpacked (star-unpacking raised
    # TypeError for any component count other than one).
    return _TupleTensor(components)

  def _serialize(self):
    return (self._specs,)


class _MyTuple(object):
  """Pretend user-side class for `ConvertToCompositeTensorTest ."""

  def __init__(self, components):
    super(_MyTuple, self).__init__()
    self._components = tuple(components)

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)


# Route `_MyTuple` instances through tensor conversion by wrapping them in a
# `_TupleTensor` composite; the extra conversion args (dtype, name, ...) are
# deliberately ignored.
tensor_conversion_registry.register_tensor_conversion_function(
    _MyTuple, conversion_func=lambda x, *_, **__: _TupleTensor(x))


class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
  """Exercises the user-registered CompositeTensor conversion path."""

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    source = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    converted = ops.convert_to_tensor_or_composite(source)
    # The result is the composite wrapper itself, not a plain Tensor.
    self.assertFalse(tensor_util.is_tf_type(converted))
    self.assertIsInstance(converted, _TupleTensor)
    self.assertLen(converted, len(source))
    # Every component, however, is a genuine Tensor holding the input value.
    for raw, component in zip(source, converted):
      self.assertIsInstance(component, tensor_lib.Tensor)
      self.assertTrue(tensor_util.is_tf_type(component))
      self.assertAllEqual(raw, tensor_util.constant_value(component))


@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):
  """Tests for ops.pack_eager_tensors across two virtual CPU devices."""

  def setUp(self):
    super(PackEagerTensorTest, self).setUp()
    # Reset the eager context so the virtual-device configuration below
    # takes effect for this test.
    context._reset_context()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testPack(self):
    with context.eager_mode():
      # Place tensors on the two virtual CPUs configured in setUp.
      with ops.device("CPU:0"):
        var0 = resource_variable_ops.ResourceVariable(1.0)
        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        var1 = resource_variable_ops.ResourceVariable(2.0)
        var2 = resource_variable_ops.ResourceVariable([3.0])
        c1 = constant_op.constant([9.0])

      # Packing two matching resource handles yields a packed tensor that
      # mirrors the first handle's dtype/shape/handle-data and reports a
      # COMPOSITE device.
      packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle])
      self.assertTrue(packed_var0.is_packed)
      self.assertEqual(packed_var0.dtype, var0.handle.dtype)
      self.assertEqual(packed_var0.shape, var0.handle.shape)
      self.assertEqual(packed_var0._handle_data, var0.handle._handle_data)
      self.assertIn("COMPOSITE:0", packed_var0.device)
      self.assertIn("COMPOSITE:0", packed_var0.backing_device)
      # A packed tensor has no single concrete value to materialize.
      with self.assertRaises(errors.InvalidArgumentError):
        packed_var0.numpy()

      # Different dtypes
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([var0.handle, c1])

      # Different shapes
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([c0, c1])

      # Different handle data
      with self.assertRaises(ValueError):
        ops.pack_eager_tensors([var0.handle, var2.handle])


class GraphDefInputShapesTest(test_util.TensorFlowTestCase):
  """Checks that `_input_shapes` in a serialized FunctionDef is the same
  whether the attr was attached to the ConcreteFunction up front or derived
  when the function is added to a graph."""

  def setUpInputShapes(self, pre_add_input_shapes):
    """Builds a concrete function and returns its `_input_shapes` attr.

    Args:
      pre_add_input_shapes: If True, rebuild the ConcreteFunction with an
        explicit `_input_shapes` attr before adding it to the test graph;
        otherwise let serialization produce the attr.

    Returns:
      The `_input_shapes` AttrValue from the serialized FunctionDef.
    """

    # Leading None dimension exercises the unknown-size (-1) encoding.
    test_tensor_shape = [None, 1, 1, 1]

    @def_function.function(input_signature=[
        tensor_lib.TensorSpec(shape=test_tensor_shape, dtype=dtypes.float32)
    ])
    def f(x):
      return array_ops.identity(x, name="output")

    # Trace the function once so a concrete function exists.
    x = array_ops.ones([2, 1, 1, 1], dtype=dtypes.float32)
    f(x)

    # Hand-build the shape proto that `_input_shapes` should contain;
    # unknown dims (None) are encoded as size -1.
    tensor_shape_proto = tensor_shape_pb2.TensorShapeProto(dim=[
        tensor_shape_pb2.TensorShapeProto.Dim(size=-1 if d is None else d)
        for d in test_tensor_shape
    ])
    list_proto = attr_value_pb2.AttrValue.ListValue(shape=[tensor_shape_proto])
    concrete_function = f.get_concrete_function()
    if pre_add_input_shapes:
      # Rebuild the concrete function with the attr pre-attached.
      attr_value = attr_value_pb2.AttrValue(list=list_proto)
      concrete_function = eager_function.ConcreteFunction.from_func_graph(
          concrete_function.graph,
          concrete_function.function_type,
          attrs={"_input_shapes": attr_value},
      )

    # Serialize through a fresh graph and pull the attr back out.
    test_graph = ops.Graph()
    with test_graph.as_default():
      concrete_function.add_to_graph(g=test_graph)
    graph_def = test_graph.as_graph_def(add_shapes=True)
    self.assertLen(graph_def.library.function, 1)
    function_def = graph_def.library.function[0]
    input_shapes = function_def.attr["_input_shapes"]
    return input_shapes

  def testGraphDefInputShapes(self):
    # Both paths must agree on the serialized `_input_shapes` attr.
    pre_added_input_shapes = self.setUpInputShapes(pre_add_input_shapes=True)
    post_added_input_shapes = self.setUpInputShapes(pre_add_input_shapes=False)
    self.assertProtoEquals(pre_added_input_shapes, post_added_input_shapes)


class TensorTest(test_util.TensorFlowTestCase):
  """np.array() conversion behavior for eager vs. symbolic tensors."""

  def testToArrayEagerMode(self):
    # Eager tensors convert to ndarrays, honoring an explicit dtype when
    # given and the tensor's dtype otherwise.
    with context.eager_mode():
      as_float = np.array(constant_op.constant(32), dtype=np.float32)
      as_int = np.array(constant_op.constant(32, dtype=dtypes.int64))
      self.assertEqual(as_float.dtype, np.dtype(np.float32))
      self.assertEqual(as_int.dtype, np.dtype(np.int64))

  def testToArrayFunctionMode(self):
    # Symbolic tensors inside tf.function cannot become ndarrays.

    @def_function.function
    def with_dtype():
      # Raises during trace compilation.
      return np.array(constant_op.constant(32), dtype=np.int32)

    @def_function.function
    def without_dtype():
      # Raises during trace compilation.
      return np.array(constant_op.constant(32))

    expected = "Cannot convert a symbolic tf.Tensor"
    with self.assertRaisesRegex(NotImplementedError, expected):
      with_dtype()
    with self.assertRaisesRegex(NotImplementedError, expected):
      without_dtype()


# Run the full test suite when executed as a script.
if __name__ == "__main__":
  googletest.main()