# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops which manipulate lists of tensors."""

# pylint: disable=g-bad-name
from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
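  """Covers TensorList push/pop, stack, gather, scatter, get/set, and grads."""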

  def _testPushPop(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[],
        max_num_elements=max_num_elements)
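    # Push one scalar, pop it back out, then stack the now-empty list.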
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    l, e = self.evaluate((l, e))
    self.assertAllEqual(l, [])
    self.assertAllEqual(e, 1.0)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  def testPushPop(self, max_num_elements):
    self._testPushPop(max_num_elements)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  @test_util.run_gpu_only
  def testPushPopGPU(self, max_num_elements):
    with context.device("gpu:0"):
      self._testPushPop(max_num_elements)

  @test_util.run_deprecated_v1
  def testPushInFullListFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[], max_num_elements=1)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Tried to push item into a full list"):
      l = list_ops.tensor_list_push_back(l, 2.)
      self.evaluate(l)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  @test_util.run_deprecated_v1
  def testPopFromEmptyTensorListFails(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[],
        max_num_elements=max_num_elements)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Trying to pop from an empty list"):
      l = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.evaluate(l)

  def testTensorListReserveWithNonScalarNumElements(self):
    # tensorflow/core/kernels/list_kernels.cc raises InvalidArgumentError, and
    # tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc raises UnknownError.
    with self.assertRaises((errors.InvalidArgumentError, errors.UnknownError)):
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32,
          element_shape=[2, 3],
          num_elements=constant_op.constant([1, 1]))
      self.evaluate(l)

  def testPopUninitializedTensorUseListElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
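    # Popping an element that was never written returns zeros with the
    # list's element_shape; stacking the reserved list yields zeros as well.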
    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    l = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    l, e = self.evaluate((l, e))
    self.assertAllEqual(e, np.zeros((2, 3)))
    self.assertAllEqual(l, np.zeros((3, 2, 3)))

  def testPopUninitializedTensorUseSpecifiedElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[None, 3], num_elements=3)
    _, e = gen_list_ops.tensor_list_pop_back(
        l, element_dtype=dtypes.float32, element_shape=[4, 3])
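    # The op-level element_shape ([4, 3]) refines the list's [None, 3] shape,
    # so the uninitialized pop returns zeros of shape [4, 3].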
    self.assertAllEqual(e, np.zeros((4, 3)))

  def testPopUninitializedTensorWithInvalidElementShapeFails(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to read an uninitialized tensor but "
        "element_shape is not fully defined"):
      _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.evaluate(e)

    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[None, 2], num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Incompatible shapes during merge: \[1,3\] vs. \[\?,2\]"):
      _, e = gen_list_ops.tensor_list_pop_back(
          l, element_dtype=dtypes.float32, element_shape=[1, 3])
      self.evaluate(e)

  def testPushGetGrad(self):
    with backprop.GradientTape() as tape:
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=None)
      c0 = constant_op.constant(5.0)
      c1 = constant_op.constant([10.0, 20.0])
      tape.watch(c0)
      tape.watch(c1)
      l = list_ops.tensor_list_push_back(l, c0)
      l = list_ops.tensor_list_push_back(l, c1)
      t1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(t1), [10.0, 20.0])
      # t1 == c1 so the gradient should be [0., [1., 1.]]
      # This tests that the gradient of push_back correctly converts DT_INVALID
      # tensors to zeros. The list returned by the gradient of GetItem will
      # only have the tensor at index 1 set; the others remain DT_INVALID.
      dt0, dt1 = tape.gradient(t1, [c0, c1])
      self.assertAllEqual(self.evaluate(dt1), [1.0, 1.0])
      self.assertEqual(self.evaluate(dt0), 0.0)

  def _testStack(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[],
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    if not context.executing_eagerly():
      self.assertAllEqual(t.shape.as_list(), [None])
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  def testStack(self, max_num_elements):
    self._testStack(max_num_elements)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  @test_util.run_gpu_only
  def testStackGPU(self, max_num_elements):
    with context.device("gpu:0"):
      self._testStack(max_num_elements)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 3))
  @test_util.run_deprecated_v1
  def testStackWithUnknownElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

    # Should raise an error when the element tensors do not all have the same
    # shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Incompatible ranks during merge: 0 vs. 1"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 3))
  @test_util.run_deprecated_v1
  def testStackWithPartiallyDefinedElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None],
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])

    # Should raise an error when the element tensors do not all have the same
    # shape.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Incompatible shapes during merge: \[1\] vs. \[2\]"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  @test_util.run_deprecated_v1
  def testStackEmptyList(self, max_num_elements):
    # Should be able to stack empty lists with fully defined element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

    # Should not be able to stack empty lists with partially defined
    # element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[None, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to stack empty lists with undefined element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=None,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def _testStackWithUninitializedTensors(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=3)
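    # None of the reserved elements were written, so all stack as zeros.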
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(t, [0., 0., 0.])

  def testStackWithUninitializedTensors(self):
    self._testStackWithUninitializedTensors()

  @test_util.run_gpu_only
  def testStackWithUninitializedTensorsGpu(self):
    with context.device("gpu:0"):
      self._testStackWithUninitializedTensors()

  def _testStackWithUninitializedTensorsInferShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    l = list_ops.tensor_list_set_item(l, 1, [1., 2.])
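    # The single initialized element fixes the element shape; the
    # uninitialized elements stack as zeros of that shape.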
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(t, [[0., 0.], [1., 2.], [0., 0.]])

  def testStackWithUninitializedTensorsInferShape(self):
    self._testStackWithUninitializedTensorsInferShape()

  @test_util.run_gpu_only
  def testStackWithUninitializedTensorsInferShapeGpu(self):
    with context.device("gpu:0"):
      self._testStackWithUninitializedTensorsInferShape()

  def testStackReservedListWithNoElementsAndPartialElementShapeFails(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError, "Tried to stack list which only contains "
        "uninitialized tensors and has a "
        "non-fully-defined element_shape: <unknown>"):
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testStackUsingSpecifiedElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
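    # The list's element_shape is unknown, so pass a fully defined
    # element_shape to the raw op to stack the uninitialized elements as zeros.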
    t = gen_list_ops.tensor_list_stack(
        l, element_dtype=dtypes.float32, element_shape=[])
    if context.executing_eagerly():
      self.assertEqual(t.shape.as_list(), [3])
    else:
      self.assertEqual(t.shape.as_list(), [None])
    self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 2))
  def testGatherGrad(self, max_num_elements):
    with backprop.GradientTape() as tape:
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[],
          max_num_elements=max_num_elements)
      c0 = constant_op.constant(1.0)
      tape.watch(c0)
      l = list_ops.tensor_list_push_back(l, c0)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
      t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
      s = (t[0] + t[1]) * (t[0] + t[1])
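      # ds/dc0 = 2 * (t[0] + t[1]) = 2 * 3.0 = 6.0, flowing through t[1] == c0.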
    dt = tape.gradient(s, c0)
    self.assertAllEqual(self.evaluate(dt), 6.0)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 3))
  @test_util.run_deprecated_v1
  def testGatherWithUnknownElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

    t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

    t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Incompatible ranks during merge: 0 vs. 1"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 3))
  @test_util.run_deprecated_v1
  def testGatherWithPartiallyDefinedElementShape(self, max_num_elements):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None],
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([4.0, 5.0]))

    t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0]])

    t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[2.0, 3.0], [4.0, 5.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Incompatible shapes during merge: \[1\] vs. \[2\]"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)

  @parameterized.named_parameters(("NoMaxNumElements", None),
                                  ("WithMaxNumElements", 3))
  @test_util.run_deprecated_v1
  def testGatherEmptyList(self, max_num_elements):
    # Should be able to gather from empty lists with fully defined
    # element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

    # Should not be able to gather from empty lists with partially defined
    # element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[None, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to gather from empty lists with undefined
    # element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=None,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)

  def testGatherGradWithNonContiguousIndices(self):
    with backprop.GradientTape(persistent=True) as tape:
      t = constant_op.constant([1.0, 2.0, 3.0])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      c = constant_op.constant(5.0)
      tape.watch(c)
      l = list_ops.tensor_list_set_item(l, 1, c)
      t = list_ops.tensor_list_gather(l, [1], element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(t), [5.0])
      s = t[0] * t[0]
    dt = tape.gradient(s, c)
    self.assertAllEqual(self.evaluate(dt), 10.0)
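    # The gradient w.r.t. the list is itself a list of the same length (3),
    # even though only index 1 was gathered.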
    dl = tape.gradient(t, l)
    dl_length = list_ops.tensor_list_length(dl)
    self.assertAllEqual(self.evaluate(dl_length), 3)

  def _testGatherWithUninitializedTensors(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=3)
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [0., 0.])

  def testGatherWithUninitializedTensors(self):
    self._testGatherWithUninitializedTensors()

  @test_util.run_gpu_only
  def testGatherWithUninitializedTensorsGpu(self):
    with context.device("gpu:0"):
      self._testGatherWithUninitializedTensors()

  def _testGatherWithUninitializedTensorsInferShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    l = list_ops.tensor_list_set_item(l, 1, [1., 2.])
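    # Element 1 provides the inferred shape; element 2 gathers as zeros.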
    t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1., 2.], [0., 0.]])

  def testGatherWithUninitializedTensorsInferShape(self):
    self._testGatherWithUninitializedTensorsInferShape()

  @test_util.run_gpu_only
  def testGatherWithUninitializedTensorsInferShapeGpu(self):
    with context.device("gpu:0"):
      self._testGatherWithUninitializedTensorsInferShape()

  def testGatherReservedListWithNoElementsAndPartialElementShapeFails(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Tried to gather uninitialized tensors from a"
        " list with non-fully-defined element_shape"):
      t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
      self.evaluate(t)

  def testGatherUsingSpecifiedElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    t = gen_list_ops.tensor_list_gather(
        l, [0, 1, 2], element_dtype=dtypes.float32, element_shape=[])
    self.assertEqual(t.shape.as_list(), [3])
    self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

  def testGatherWithInvalidIndicesFails(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3
    )

    # Should raise an error when the input index is negative.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to gather element -1 in a list with 3 elements.",
    ):
      t = list_ops.tensor_list_gather(l, [-1], element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should raise an error when the input index is larger than the number of
    # elements in the list.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to gather element 3 in a list with 3 elements.",
    ):
      t = list_ops.tensor_list_gather(l, [3], element_dtype=dtypes.float32)
      self.evaluate(t)

  def testScatterOutputListSize(self):
    c0 = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_scatter(c0, [1, 3], [])
    # TensorListScatter should return a list with size largest index + 1.
    self.assertAllEqual(list_ops.tensor_list_length(l), 4)

  def testScatterOutputListSizeWithNumElementsSpecified(self):
    c0 = constant_op.constant([1.0, 2.0])
    l = gen_list_ops.tensor_list_scatter_v2(
        c0, [1, 3], list_ops._build_element_shape([]), num_elements=5)
    # TensorListScatter should return a list with size num_elements.
    self.assertAllEqual(list_ops.tensor_list_length(l), 5)

  def testScatterFailsWhenElementShapeIsNotVector(self):
    c0 = constant_op.constant([1.0, 2.0])
    # In Eager mode, InvalidArgumentError is generated by the Compute function.
    # In graph mode, ValueError is generated by the shape function.
    with self.assertRaisesRegex(
        (errors.InvalidArgumentError, ValueError),
        "must be at most rank 1"):
      l = gen_list_ops.tensor_list_scatter(
          # Wrong element_shape. Should be at most rank 1.
          c0, [1, 3], element_shape=[[1]])
      self.evaluate(l)

  def testScatterV2FailsWhenElementShapeIsNotVector(self):
    c0 = constant_op.constant([1.0, 2.0])
    # In Eager mode, InvalidArgumentError is generated by the Compute function.
    # In graph mode, ValueError is generated by the shape function.
    with self.assertRaisesRegex(
        (errors.InvalidArgumentError, ValueError),
        "must be at most rank 1"):
      l = gen_list_ops.tensor_list_scatter_v2(
          # Wrong element_shape. Should be at most rank 1.
          c0, [1, 3], element_shape=[[1]], num_elements=2)
      self.evaluate(l)

  def testScatterFailsWhenIndexLargerThanNumElements(self):
    c0 = constant_op.constant([1.0, 2.0])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "TensorListScatter: Trying to scatter at index 3 in list with size 3"):
      l = gen_list_ops.tensor_list_scatter_v2(
          c0, [1, 3], list_ops._build_element_shape([]), num_elements=3)
      self.evaluate(l)

  def testScatterFailsWithInvalidNumElements(self):
    c0 = constant_op.constant([1.0, 2.0])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "TensorListScatter expects num_elements >= -1, found: -2"):
      l = gen_list_ops.tensor_list_scatter_v2(
          c0, [1, 3], list_ops._build_element_shape([]), num_elements=-2)
      self.evaluate(l)

  def testScatterWithInvalidRowsInInputTensorFails(self):
    c0 = constant_op.constant([1.0, 2.0])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Invalid number of rows in input tensor. Expected: 3 Actual: 2"):
      l = list_ops.tensor_list_scatter(c0, [1, 0, 2], [])
      self.evaluate(l)

  def testScatterWithNegativeIndicesFails(self):
    c0 = constant_op.constant([1.0, 2.0])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Indices in TensorListScatter must all be non-negative."):
      l = list_ops.tensor_list_scatter(c0, [-1, -2], element_shape=[])
      self.evaluate(l)

  @test_util.run_in_graph_and_eager_modes
  def testScatterWithNonScalarFails(self):
    c = constant_op.constant(value=[2])
    num_elements = np.array([[], [], []], dtype=np.float32)
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                r"Shape must be rank 0 but is rank \d+|"
                                r"\w+ must be a scalar"):
      self.evaluate(
          gen_list_ops.TensorListScatterV2(
              tensor=c, indices=c, element_shape=c, num_elements=num_elements))

  def testScatterIntoExistingList(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=3)
    l = list_ops.tensor_list_scatter(tensor=[1.], indices=[0], element_shape=[])
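    # Passing input_handle below scatters into the existing list rather
    # than allocating a new one.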
    l = list_ops.tensor_list_scatter(
        tensor=[2., 3.], indices=[1, 2], element_shape=[], input_handle=l)
    self.assertAllEqual(
        list_ops.tensor_list_stack(l, element_dtype=dtypes.float32),
        [1., 2., 3.])

  def testScatterGrad(self):
    with backprop.GradientTape() as tape:
      c0 = constant_op.constant([1.0, 2.0])
      tape.watch(c0)
      l = list_ops.tensor_list_scatter(c0, [1, 0], element_shape=[])
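      # indices=[1, 0] places c0[0] at list index 1 and c0[1] at index 0.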
      t0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      t1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(t0), 2.0)
      self.assertAllEqual(self.evaluate(t1), 1.0)
      loss = t0 * t0 + t1 * t1
    dt = tape.gradient(loss, c0)
    self.assertAllEqual(self.evaluate(dt), [2., 4.])

  def testScatterWithPartialReadGrad(self):
    with backprop.GradientTape() as tape:
      c0 = constant_op.constant([1.0, 2.0])
      tape.watch(c0)
      l = list_ops.tensor_list_scatter(c0, [1, 0], element_shape=[])
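      # Only list item 0 (which holds c0[1] == 2.0) is read, so the
      # gradient for the unread c0[0] is 0.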
      t0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(t0), 2.0)
      loss = t0 * t0
    dt = tape.gradient(loss, c0)
    self.assertAllEqual(self.evaluate(dt), [0., 4.])

  def testTensorListFromTensor(self):
    t = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(e, 1.0)
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(e, 2.0)
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(e, 1.0)
    self.assertAllEqual(list_ops.tensor_list_length(l), 0)

  def testTensorListFromTensorFailsWhenElementShapeIsNotVector(self):
    t = constant_op.constant([1.0, 2.0])
    # In Eager mode, InvalidArgumentError is generated by the Compute function.
    # In graph mode, ValueError is generated by the shape function.
    with self.assertRaisesRegex(
        (errors.InvalidArgumentError, ValueError),
        "must be at most rank 1"):
      # Wrong element_shape. Should be at most rank 1.
      l = list_ops.tensor_list_from_tensor(t, element_shape=[[1]])
      self.evaluate(l)

  @test_util.run_gpu_only
  def testFromTensorGPU(self):
    with context.device("gpu:0"):
      self.testTensorListFromTensor()

  def testGetSetBool(self):
    t = constant_op.constant([True, False])
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.bool)
    self.assertAllEqual(self.evaluate(e0), True)
    l = list_ops.tensor_list_set_item(l, 0, False)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.bool)
    self.assertAllEqual(self.evaluate(t), [False, False])

  @test_util.run_gpu_only
  def testGetSetBoolGPU(self):
    with context.device("gpu:0"):
      self.testGetSetBool()

  def _testGetSetNumeric(self, dtype):
    t = constant_op.constant([1.0, 2.0], dtype=dtype)
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtype)
    self.assertAllEqual(self.evaluate(e0), 1.0)
    l = list_ops.tensor_list_set_item(
        l, 0, constant_op.constant(3.0, dtype=dtype))
    t = list_ops.tensor_list_stack(l, element_dtype=dtype)
    self.assertAllEqual(self.evaluate(t), [3.0, 2.0])

  @parameterized.parameters([dtypes.float32, dtypes.float64,
                             dtypes.complex64, dtypes.complex128])
  def testGetSetNumeric(self, dtype):
    self._testGetSetNumeric(dtype)

  @parameterized.parameters([dtypes.float32, dtypes.float64,
                             dtypes.complex64, dtypes.complex128])
  @test_util.run_gpu_only
  def testGetSetNumericGPU(self, dtype):
    with context.device("gpu:0"):
      self._testGetSetNumeric(dtype)

  def testGetSetReserved(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=2)
    e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(e0, 0.0)
    l = list_ops.tensor_list_set_item(l, 0, 3.0)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(t, [3.0, 0.0])

  @test_util.run_gpu_only
  def testGetSetReservedGPU(self):
    with context.device("gpu:0"):
      self.testGetSetReserved()

  def testSetGetGrad(self):
    with backprop.GradientTape() as tape:
      t = constant_op.constant(5.)
      tape.watch(t)
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=[], num_elements=3)
      l = list_ops.tensor_list_set_item(l, 1, 2. * t)
      e = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      self.assertAllEqual(self.evaluate(e), 10.0)
    self.assertAllEqual(self.evaluate(tape.gradient(e, t)), 2.0)

  def testGetUninitializedTensorUseListElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=3)
    l = list_ops.tensor_list_set_item(l, 0, 5.)
    e1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
    e2 = list_ops.tensor_list_get_item(l, 2, element_dtype=dtypes.float32)
    self.assertEqual(self.evaluate(e1), 0.)
    self.assertEqual(self.evaluate(e2), 0.)

  def testGetUninitializedTensorUseSpecifiedElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    e0 = gen_list_ops.tensor_list_get_item(
        l, 0, element_shape=[], element_dtype=dtypes.float32)
    e1 = gen_list_ops.tensor_list_get_item(
        l, 1, element_shape=[2, 3], element_dtype=dtypes.float32)
    self.assertEqual(e0.shape.as_list(), [])
    self.assertEqual(e1.shape.as_list(), [2, 3])
    self.assertEqual(self.evaluate(e0), 0.)
    self.assertAllEqual(self.evaluate(e1), np.zeros((2, 3)))

    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[None, 3], num_elements=3)
    e1 = gen_list_ops.tensor_list_get_item(
        l, 1, element_shape=[2, 3], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(e1), np.zeros((2, 3)))

  def testGetUninitializedTensorWithInvalidElementShapeFails(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to read an uninitialized tensor but "
        "element_shape is not fully defined"):
      e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.evaluate(e0)

    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[None, 2], num_elements=3)

    # In eager mode the shape mismatch is caught in the TensorListGetItem
    # kernel which raises an InvalidArgumentError.
    # In graph mode the shape mismatch is caught in the C++ shape inference
    # which raises a ValueError.
    if context.executing_eagerly():
      error_type = errors.InvalidArgumentError
    else:
      error_type = ValueError
    with self.assertRaisesRegex(error_type, r"shapes"):
      e0 = gen_list_ops.tensor_list_get_item(
          l, 0, element_dtype=dtypes.float32, element_shape=[1, 3])
      self.evaluate(e0)

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testSkipEagerSetItemIndexOutOfBounds(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[])
    e0 = constant_op.constant(5.)
    l = list_ops.tensor_list_set_item(
        l, 0, 2. * e0, resize_if_index_out_of_bounds=True)
    l = list_ops.tensor_list_set_item(
        l, 1, 1., resize_if_index_out_of_bounds=True)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    grad = gradients_impl.gradients(t, e0)[0]
    self.assertAllEqual(self.evaluate(grad), 2.)

  @test_util.run_deprecated_v1
  def testSetOnEmptyListWithMaxNumElementsFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[], max_num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to modify element 0 in a list with 0 elements."):
      l = list_ops.tensor_list_set_item(l, 0, 1.)
      self.evaluate(l)

  def testUnknownShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(e), [1.0, 2.0])
    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(e), 1.0)

  @test_util.run_gpu_only
  def testCPUGPUCopy(self):
    t = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
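    # Copying the variant tensor with identity moves the list across devices.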
    with context.device("gpu:0"):
      l_gpu = array_ops.identity(l)
      self.assertAllEqual(
          self.evaluate(
              list_ops.tensor_list_pop_back(
                  l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
    l_cpu = array_ops.identity(l_gpu)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_pop_back(
                l_cpu, element_dtype=dtypes.float32)[1]), 2.0)

  @test_util.run_gpu_only
  def testCPUGPUCopyNested(self):
    t = constant_op.constant([1.0, 2.0])
    child_l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    l = list_ops.empty_tensor_list(
        element_shape=constant_op.constant([], dtype=dtypes.int32),
        element_dtype=dtypes.variant)
    l = list_ops.tensor_list_push_back(l, child_l)
    with context.device("gpu:0"):
      l_gpu = array_ops.identity(l)
      _, child_l_gpu = list_ops.tensor_list_pop_back(
          l_gpu, element_dtype=dtypes.variant)
      self.assertAllEqual(
          self.evaluate(
              list_ops.tensor_list_pop_back(
                  child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
    l_cpu = array_ops.identity(l_gpu)
    _, child_l_cpu = list_ops.tensor_list_pop_back(
        l_cpu, element_dtype=dtypes.variant)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_pop_back(
                child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)

  def testGraphStack(self):
    with self.cached_session():
      tl = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([1], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      tl = list_ops.tensor_list_push_back(tl, [1])
      self.assertAllEqual(
          self.evaluate(
              list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
          [[1]])

  def testSkipEagerStackInLoop(self):
    with self.cached_session():
      t1 = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      i = constant_op.constant(0, dtype=dtypes.int32)

      def body(i, t1):
        t1 = list_ops.tensor_list_push_back(t1, i)
        i += 1
        return i, t1

      i, t1 = while_loop.while_loop(lambda i, t1: math_ops.less(i, 4), body,
                                    [i, t1])
      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
      self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])

  def testSkipEagerStackSwitchDtype(self):
    with self.cached_session():
      list_ = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)

      def body(list_, m):
        list_ = cond.cond(
            math_ops.equal(list_ops.tensor_list_length(list_), 0),
            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: list_)
        list_ = list_ops.tensor_list_push_back(list_, m)
        return list_, m

      for _ in range(2):
        list_, m = body(list_, m)

      s1 = list_ops.tensor_list_stack(list_, element_dtype=dtypes.float32)
      np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
      self.assertAllEqual(self.evaluate(s1), np_s1)

  def testSkipEagerStackInLoopSwitchDtype(self):
    with self.cached_session():
      t1 = list_ops.empty_tensor_list(
          element_shape=constant_op.constant([], dtype=dtypes.int32),
          element_dtype=dtypes.int32)
      i = constant_op.constant(0, dtype=dtypes.float32)
      m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)

      def body(i, m, t1):
        t1 = cond.cond(
            math_ops.equal(list_ops.tensor_list_length(t1), 0),
            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1)

        t1 = list_ops.tensor_list_push_back(t1, m * i)
        i += 1.0
        return i, m, t1

      i, m, t1 = while_loop.while_loop(lambda i, m, t1: math_ops.less(i, 4),
                                       body, [i, m, t1])
      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)
      np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
      self.assertAllEqual(self.evaluate(s1), np_s1)

  def testSerialize(self):
    worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
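    # Moving the list between /job:worker and /job:ps serializes the variant
    # tensor and deserializes it on the other side.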
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        t = constant_op.constant([[1.0], [2.0]])
        l = list_ops.tensor_list_from_tensor(t, element_shape=[1])
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        l_ps, e = list_ops.tensor_list_pop_back(
            l_ps, element_dtype=dtypes.float32)
      with ops.device("/job:worker"):
        worker_e = array_ops.identity(e)
      self.assertAllEqual(self.evaluate(worker_e), [2.0])

  def testSerializeListWithInvalidTensors(self):
    worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        l = list_ops.tensor_list_reserve(
            element_dtype=dtypes.float32, element_shape=[], num_elements=2)
        l = list_ops.tensor_list_set_item(l, 0, 1.)
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        l_ps = list_ops.tensor_list_set_item(l_ps, 1, 2.)
        t = list_ops.tensor_list_stack(l_ps, element_dtype=dtypes.float32)
      with ops.device("/job:worker"):
        worker_t = array_ops.identity(t)
      self.assertAllEqual(self.evaluate(worker_t), [1.0, 2.0])

  def testSerializeListWithUnknownRank(self):
    worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        t = constant_op.constant([[1.0], [2.0]])
        l = list_ops.tensor_list_from_tensor(t, element_shape=None)
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        element_shape = list_ops.tensor_list_element_shape(
            l_ps, shape_type=dtypes.int32)
      with ops.device("/job:worker"):
        element_shape = array_ops.identity(element_shape)
      self.assertEqual(self.evaluate(element_shape), -1)

  def testSerializeListWithMaxNumElements(self):
    worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0]
    with ops.Graph().as_default(), session.Session(target=worker.target):
      with ops.device("/job:worker"):
        l = list_ops.empty_tensor_list(
            element_shape=None,
            element_dtype=dtypes.float32,
            max_num_elements=2)
        l = list_ops.tensor_list_push_back(l, 1.)
      with ops.device("/job:ps"):
        l_ps = array_ops.identity(l)
        l_ps = list_ops.tensor_list_push_back(l_ps, 2.)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Tried to push item into a full list"):
        with ops.device("/job:worker"):
          l_worker = array_ops.identity(l_ps)
          l_worker = list_ops.tensor_list_push_back(l_worker, 3.0)
          self.evaluate(l_worker)

  def testPushPopGradients(self):
    with backprop.GradientTape() as tape:
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=[])
      c = constant_op.constant(1.0)
      tape.watch(c)
      l = list_ops.tensor_list_push_back(l, c)
      l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      e = 2 * e
    self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)

  def testStackFromTensorGradients(self):
    with backprop.GradientTape() as tape:
      c = constant_op.constant([1.0, 2.0])
      tape.watch(c)
      l = list_ops.tensor_list_from_tensor(c, element_shape=[])
      c2 = list_ops.tensor_list_stack(
          l, element_dtype=dtypes.float32, num_elements=2)
      result = c2 * 2.0
    grad = tape.gradient(result, [c])[0]
    self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])

  def testGetSetGradients(self):
    with backprop.GradientTape() as tape:
      c = constant_op.constant([1.0, 2.0])
      tape.watch(c)
      l = list_ops.tensor_list_from_tensor(c, element_shape=[])
      c2 = constant_op.constant(3.0)
      tape.watch(c2)
      l = list_ops.tensor_list_set_item(l, 0, c2)
      e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      y = e * e + ee * ee
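      # Item 0 was overwritten by c2, so no gradient flows to c[0];
      # dy/dc2 = 2 * 3.0 = 6.0 and dy/dc[1] = 2 * 2.0 = 4.0.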
    grad_c, grad_c2 = tape.gradient(y, [c, c2])
    self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
    self.assertAllEqual(self.evaluate(grad_c2), 6.0)

  @test_util.run_deprecated_v1
  def testSetOutOfBounds(self):
    c = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(c, element_shape=[])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0))

  @test_util.run_deprecated_v1
  def testSkipEagerSetItemWithMismatchedShapeFails(self):
    with self.cached_session() as sess:
      ph = array_ops.placeholder(dtypes.float32)
      c = constant_op.constant([1.0, 2.0])
      l = list_ops.tensor_list_from_tensor(c, element_shape=[])
      # Set a placeholder with unknown shape to satisfy the shape inference
      # at graph building time.
      l = list_ops.tensor_list_set_item(l, 0, ph)
      l_0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "incompatible shape"):
        sess.run(l_0, {ph: [3.0]})

  def testResourceVariableScatterGather(self):
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l = list_ops.tensor_list_from_tensor(c, element_shape=[])
    v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
    v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.float32)
    self.evaluate(v.initializer)
    self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_0_stacked))
    v_r_sparse_stacked = list_ops.tensor_list_stack(
        v.sparse_read(0), dtypes.float32)
    self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_sparse_stacked))
    l_new_0 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
    l_new_1 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
    updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
    updated_v_elems = array_ops_stack.unstack(updated_v)
    updated_v_stacked = [
        list_ops.tensor_list_stack(el, dtypes.float32) for el in updated_v_elems
    ]
    expected = ([[1.0, 2.0]] * 3 + [[3.0, 4.0], [1.0, 2.0], [5.0, 6.0]] +
                [[1.0, 2.0]] * 4)
    self.assertAllEqual(self.evaluate(updated_v_stacked), expected)

  def testResourceVariableScatterGatherInt64(self):
    c = constant_op.constant([1, 2], dtype=dtypes.int64)
    l = list_ops.tensor_list_from_tensor(c, element_shape=[])
    v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
    v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.int64)
    self.evaluate(v.initializer)
    self.assertAllEqual([1, 2], self.evaluate(v_r_0_stacked))
    v_r_sparse_stacked = list_ops.tensor_list_stack(
        v.sparse_read(0), dtypes.int64)
    self.assertAllEqual([1, 2], self.evaluate(v_r_sparse_stacked))
    c34 = constant_op.constant([3, 4], dtype=dtypes.int64)
    l_new_0 = list_ops.tensor_list_from_tensor(c34, element_shape=[])
    c56 = constant_op.constant([5, 6], dtype=dtypes.int64)
    l_new_1 = list_ops.tensor_list_from_tensor(c56, element_shape=[])
    updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
    updated_v_elems = array_ops_stack.unstack(updated_v)
    updated_v_stacked = [
        list_ops.tensor_list_stack(el, dtypes.int64) for el in updated_v_elems
    ]
    expected = ([[1, 2]] * 3 + [[3, 4], [1, 2], [5, 6]] +
                [[1, 2]] * 4)
    self.assertAllEqual(self.evaluate(updated_v_stacked), expected)

  @test_util.run_deprecated_v1
  def testConcat(self):
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=[])
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=[])
    l_batch_0 = array_ops_stack.stack([l0, l1])
    l_batch_1 = array_ops_stack.stack([l1, l0])

    l_concat_01 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_1, element_dtype=dtypes.float32)
    l_concat_10 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_0, element_dtype=dtypes.float32)
    l_concat_00 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_0, element_dtype=dtypes.float32)
    l_concat_11 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_1, element_dtype=dtypes.float32)

    expected_0 = [[1.0, 2.0], [-1.0]]
    expected_1 = [[-1.0], [1.0, 2.0]]
    expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
    expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
    expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
    expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

    for i, (concat, expected) in enumerate(zip(
        [l_batch_0, l_batch_1,
         l_concat_00, l_concat_01, l_concat_10, l_concat_11],
        [expected_0, expected_1,
         expected_00, expected_01, expected_10, expected_11])):
      unstacked = array_ops_stack.unstack(concat)
      unstacked_ret = self.evaluate(
          (list_ops.tensor_list_stack(unstacked[0], dtypes.float32),
           list_ops.tensor_list_stack(unstacked[1], dtypes.float32)))
      print("Test concat %d: %s, %s, %s, %s"
            % (i, expected[0], unstacked_ret[0],
               expected[1], unstacked_ret[1]))
      self.assertAllClose(expected[0], unstacked_ret[0])
      self.assertAllClose(expected[1], unstacked_ret[1])

    # Concatenating mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(
          list_ops.tensor_list_concat_lists(
              l_batch_0,
              list_ops.empty_tensor_list([], dtypes.float32),
              element_dtype=dtypes.float32))

    if context.executing_eagerly():
      expected_error = (
          errors.InvalidArgumentError,
          "element shapes are not identical at index 0")
    else:
      expected_error = (ValueError, "Shapes must be equal rank")
    with self.assertRaisesRegex(*expected_error):
      l_batch_of_vec_tls = array_ops_stack.stack(
          [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                            element_dtype=dtypes.float32))

    if context.executing_eagerly():
      expected_error = (errors.InvalidArgumentError,
                        r"input_b\[0\].dtype != element_dtype.")
    else:
      expected_error = (ValueError, "input_b.type != element_dtype")
    with self.assertRaisesRegex(*expected_error):
      l_batch_of_int_tls = array_ops_stack.stack(
          [list_ops.tensor_list_from_tensor([1], element_shape=[])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                            element_dtype=dtypes.float32))

  @test_util.run_deprecated_v1
  def testPushBackBatch(self):
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=[])
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=[])
    l_batch = array_ops_stack.stack([l0, l1])
    l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
    l_unstack = array_ops_stack.unstack(l_push)
    l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
    l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
    self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
    self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))

    with ops.control_dependencies([l_push]):
      l_unstack_orig = array_ops_stack.unstack(l_batch)
      l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
                                               dtypes.float32)
      l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
                                               dtypes.float32)

    # Check that push_back_batch doesn't alias its input: the returned
    # lists contain the pushed values while the originals are unmodified.
    l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
        (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
    self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
    self.assertAllClose([-1.0, 4.0], l1_r_v)
    self.assertAllClose([1.0, 2.0], l0_orig_v)
    self.assertAllClose([-1.0], l1_orig_v)

    # Pushing back mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "incompatible shape to a list at index 0"):
      self.evaluate(
          list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))

    if context.executing_eagerly():
      expected_error = (errors.InvalidArgumentError, "Invalid data type")
    else:
      expected_error = (ValueError, "wrong element dtype")
    with self.assertRaisesRegex(*expected_error):
      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))

  def testZerosLike(self):
    for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                  dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
                  dtypes.float64, dtypes.complex64, dtypes.complex128,
                  dtypes.bool):
      l_empty = list_ops.empty_tensor_list(
          element_dtype=dtype, element_shape=[])
      l_empty_zeros = array_ops.zeros_like(l_empty)
      t_empty_zeros = list_ops.tensor_list_stack(
          l_empty_zeros, element_dtype=dtype)

      l_full = list_ops.tensor_list_push_back(l_empty,
                                              math_ops.cast(0, dtype=dtype))
      l_full = list_ops.tensor_list_push_back(l_full,
                                              math_ops.cast(1, dtype=dtype))
      l_full_zeros = array_ops.zeros_like(l_full)
      t_full_zeros = list_ops.tensor_list_stack(
          l_full_zeros, element_dtype=dtype)

      self.assertAllEqual(self.evaluate(t_empty_zeros), [])
      self.assertAllEqual(
          self.evaluate(t_full_zeros), np.zeros(
              (2,), dtype=dtype.as_numpy_dtype))

  def testZerosLikeNested(self):
    for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
                  dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
                  dtypes.float64, dtypes.complex64, dtypes.complex128,
                  dtypes.bool):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.variant, element_shape=[])

      sub_l = list_ops.empty_tensor_list(element_dtype=dtype, element_shape=[])
      l = list_ops.tensor_list_push_back(l, sub_l)
      sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
          1, dtype=dtype))
      l = list_ops.tensor_list_push_back(l, sub_l)
      sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
          2, dtype=dtype))
      l = list_ops.tensor_list_push_back(l, sub_l)

      # l : [[],
      #      [1],
      #      [1, 2]]
      #
      # l_zeros : [[],
      #            [0],
      #            [0, 0]]
      l_zeros = array_ops.zeros_like(l)

      outputs = []
      for _ in range(3):
        l_zeros, out = list_ops.tensor_list_pop_back(
            l_zeros, element_dtype=dtypes.variant)
        outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))

      # Note: `outputs` contains popped values so the order is reversed.
      self.assertAllEqual(self.evaluate(outputs[2]), [])
      self.assertAllEqual(
          self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
      self.assertAllEqual(
          self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))

  def testElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None)
    shape = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
    self.assertEqual(self.evaluate(shape), -1)

  def testZerosLikeUninitialized(self):
    l0 = list_ops.tensor_list_reserve([], 3, element_dtype=dtypes.float32)
    l1 = list_ops.tensor_list_set_item(l0, 0, 1.)  # [1., _, _]
    zeros_1 = array_ops.zeros_like(l1)  # [0., _, _]
    l2 = list_ops.tensor_list_set_item(l1, 2, 2.)  # [1., _, 2.]
    zeros_2 = array_ops.zeros_like(l2)  # [0., _, 0.]

    # Gather indices with zeros in `zeros_1`.
    res_1 = list_ops.tensor_list_gather(
        zeros_1, [0], element_dtype=dtypes.float32)
    # Gather indices with zeros in `zeros_2`.
    res_2 = list_ops.tensor_list_gather(
        zeros_2, [0, 2], element_dtype=dtypes.float32)

    self.assertAllEqual(self.evaluate(res_1), [0.])
    self.assertAllEqual(self.evaluate(res_2), [0., 0.])

  @test_util.run_deprecated_v1
  def testSkipEagerTensorListGetItemGradAggregation(self):
    l = list_ops.tensor_list_reserve(
        element_shape=[], num_elements=1, element_dtype=dtypes.float32)
    x = constant_op.constant(1.0)
    l = list_ops.tensor_list_set_item(l, 0, x)
    l_read1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    l_read2 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
    grad = gradients_impl.gradients([l_read1, l_read2], [x])
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [2.])

  @test_util.run_deprecated_v1
  def testSkipEagerBuildElementShape(self):
    fn = list_ops._build_element_shape
    # Unknown shape -> -1.
    self.assertEqual(fn(None), -1)
    self.assertEqual(fn(tensor_shape.unknown_shape()), -1)
    # Scalar shape -> [] with type int32.
    self.assertEqual(fn([]).dtype, dtypes.int32)
    self.assertEqual(fn(tensor_shape.TensorShape([])).dtype, dtypes.int32)
    self.assertAllEqual(self.evaluate(fn([])), np.array([], np.int32))
    self.assertAllEqual(
        self.evaluate(fn(tensor_shape.TensorShape([]))), np.array([], np.int32))
    # Tensor -> Tensor
    shape = constant_op.constant(1)
    self.assertIs(fn(shape), shape)
    # Shape with unknown dims -> shape list with -1's.
    shape = [None, 5]
    self.assertAllEqual(fn(shape), [-1, 5])
    self.assertAllEqual(fn(tensor_shape.TensorShape(shape)), [-1, 5])
    # Shape with unknown dims and tensor dims -> shape list with -1's and tensor
    # dims.
    t = array_ops.placeholder(dtypes.int32)
    shape = [None, 5, t]
    result = fn(shape)
    self.assertAllEqual(result[:2], [-1, 5])
    self.assertIs(result[2], t)

  def testAddN(self):
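    # add_n over tensor lists adds corresponding elements:
    # [1 + 3 + 5, 2 + 4 + 6] == [9, 12].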
    l1 = list_ops.tensor_list_from_tensor([1.0, 2.0], element_shape=[])
    l2 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
    l3 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
    result = math_ops.add_n((l1, l2, l3))
    result_t = list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(result_t), [9., 12.])

  def testAddNNestedList(self):
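    # add_n recurses into nested (variant) lists, adding the inner lists
    # element-wise.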
    l1 = list_ops.tensor_list_from_tensor([1.0, 2.0], element_shape=[])
    l2 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
    l3 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
    l4 = list_ops.tensor_list_from_tensor([7.0, 8.0], element_shape=[])
    a = list_ops.empty_tensor_list(
        element_dtype=dtypes.variant, element_shape=[])
    a = list_ops.tensor_list_push_back(a, l1)
    a = list_ops.tensor_list_push_back(a, l2)
    b = list_ops.empty_tensor_list(
        element_dtype=dtypes.variant, element_shape=[])
    b = list_ops.tensor_list_push_back(b, l3)
    b = list_ops.tensor_list_push_back(b, l4)
    result = math_ops.add_n((a, b))
    result_0 = list_ops.tensor_list_stack(
        list_ops.tensor_list_get_item(result, 0, element_dtype=dtypes.variant),
        element_dtype=dtypes.float32)
    result_1 = list_ops.tensor_list_stack(
        list_ops.tensor_list_get_item(result, 1, element_dtype=dtypes.variant),
        element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(result_0), [6., 8.])
    self.assertAllEqual(self.evaluate(result_1), [10., 12.])

  def testAddTensorListsFailsIfLeadingDimsMismatch(self):
    l1 = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=2)
    l2 = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Trying to add two lists of tensors with different lengths"):
      l = math_ops.add_n([l1, l2])
      self.evaluate(list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))

  @test_util.run_v1_only("Uses placeholders")
  def testSkipEagerAddTensorListsFailsIfElementShapesMismatch(self):
    with self.cached_session() as sess:
      # Use placeholders instead of constant values for shapes to prevent TF's
      # shape inference from catching this early.
      l1_element_shape = array_ops.placeholder(dtype=dtypes.int32)
      l2_element_shape = array_ops.placeholder(dtype=dtypes.int32)
      l1 = list_ops.tensor_list_reserve(
          element_shape=l1_element_shape,
          element_dtype=dtypes.float32,
          num_elements=3)
      l2 = list_ops.tensor_list_reserve(
          element_shape=l2_element_shape,
          element_dtype=dtypes.float32,
          num_elements=3)
      l = math_ops.add_n([l1, l2])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Trying to add two lists of tensors with incompatible element shapes"
      ):
        sess.run(
            list_ops.tensor_list_stack(l, element_dtype=dtypes.float32), {
                l1_element_shape: [],
                l2_element_shape: [2]
            })

  @test_util.run_deprecated_v1
  def testSkipEagerConcatShapeInference(self):
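    # The concatenated tensor's leading dimension is unknown (the element
    # count is not known statically); the remaining dimensions come from the
    # element shape.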

    def BuildTensor(element_shape):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=element_shape)
      return list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)

    self.assertIsNone(BuildTensor(None).shape.rank)
    self.assertAllEqual(BuildTensor([None, 2, 3]).shape.as_list(), [None, 2, 3])
    self.assertAllEqual(
        BuildTensor([None, 2, None]).shape.as_list(), [None, 2, None])
    self.assertAllEqual(BuildTensor([1, 2, 3]).shape.as_list(), [None, 2, 3])

  def testConcatWithFullyDefinedElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[2, 2])
    l = list_ops.tensor_list_push_back(l, [[0., 1.], [2., 3.]])
    l = list_ops.tensor_list_push_back(l, [[4., 5.], [6., 7.]])
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual(
        self.evaluate(t), [[0., 1.], [2., 3.], [4., 5.], [6., 7.]])

  def testConcatWithNonFullyDefinedElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[None, 2])
    l = list_ops.tensor_list_push_back(l, [[0., 1.]])
    l = list_ops.tensor_list_push_back(l, [[2., 3.], [4., 5.]])
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[0., 1.], [2., 3.], [4., 5.]])

  def testConcatWithMismatchingTensorShapesFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None)
    l = list_ops.tensor_list_push_back(l, [[0., 1.]])
    l = list_ops.tensor_list_push_back(l, [[2.], [4.]])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError, r"Incompatible shapes during merge: "
        r"\[2\] vs. \[1\]"):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatEmptyListWithFullyDefinedElementShape(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[5, 2])
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t).shape, (0, 2))
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[None, 2])
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t).shape, (0, 2))

  def testConcatEmptyListWithUnknownElementShapeFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "All except the first dimension must be fully"
        " defined when concating an empty tensor list"):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatEmptyListWithPartiallyDefinedElementShapeFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[2, None])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "All except the first dimension must be fully"
        " defined when concating an empty tensor list"):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatListWithScalarElementShapeFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=tensor_shape.TensorShape([]))
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Concat requires elements to be at least vectors, "
        "found scalars instead"):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatListWithScalarElementsFails(self):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None)
    l1 = list_ops.tensor_list_push_back(l, 1.)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError, "Concat saw a scalar shape at index 0"
        " but requires at least vectors"):
      t = list_ops.tensor_list_concat(l1, element_dtype=dtypes.float32)
      self.evaluate(t)
    l1 = list_ops.tensor_list_push_back(l, [1.])
    l1 = list_ops.tensor_list_push_back(l1, 2.)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError, "Concat saw a scalar shape at index 1"
        " but requires at least vectors"):
      t = list_ops.tensor_list_concat(l1, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatWithUninitializedTensorsUseListElementShape(self):
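    # Uninitialized entries are materialized as zeros using the list's own
    # element shape [2, 3].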
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual(np.zeros((6, 3)), t)

  def testConcatWithUninitializedTensorsUseProvidedElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    t = list_ops.tensor_list_concat(
        l, element_dtype=dtypes.float32, element_shape=(2, 3))
    self.assertAllEqual(np.zeros((6, 3)), t)

  def testConcatWithUninitializedTensorsUseProvidedElementShapeAndLengths(self):
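    # TensorListConcatV2's `leading_dims` supplies the (otherwise unknown)
    # leading dimension used to fill each uninitialized entry with zeros.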
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    t, _ = gen_list_ops.tensor_list_concat_v2(
        l,
        element_dtype=dtypes.float32,
        element_shape=list_ops._build_element_shape((None, 3)),
        leading_dims=[2, 3, 5])
    self.assertAllEqual(np.zeros((10, 3)), t)
    l = list_ops.tensor_list_set_item(l, 1, [[2., 3.], [4., 5.], [6., 7.]])
    t, _ = gen_list_ops.tensor_list_concat_v2(
        l,
        element_dtype=dtypes.float32,
        element_shape=list_ops._build_element_shape((None, 2)),
        leading_dims=[2, 3, 4])
    self.assertAllEqual([[0., 0.], [0., 0.], [2., 3.], [4., 5.], [6., 7.],
                         [0., 0.], [0., 0.], [0., 0.], [0., 0.]], t)

  def testConcatWithUninitializedTensorsInferShapeFromElements(self):
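    # With no element shape on the list, the shape of the single initialized
    # entry ([3, 2]) is used to fill the uninitialized entries with zeros.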
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    l = list_ops.tensor_list_set_item(l, 1, [[2., 3.], [4., 5.], [6., 7.]])
    t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
    self.assertAllEqual([[0., 0.], [0., 0.], [0., 0.], [2., 3.], [4., 5.],
                         [6., 7.], [0., 0.], [0., 0.], [0., 0.]], t)

  def testConcatWithUninitializedTensorsFailsIfNoElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Trying to concat list with only uninitialized tensors "
        r"but element_shape_except_first_dim is not fully defined"):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  def testConcatWithUninitializedTensorsFailsIfNoInputLengths(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[None, 3], num_elements=3)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"List contains uninitialized tensor at index 0"
        r" but leading_dims has only 0 elements."):
      t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
      self.evaluate(t)

  @test_util.run_in_graph_and_eager_modes
  def testConcatWithInvalidElementShape(self):
    l = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=0)
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                r"element_shape must not be empty"):
      self.evaluate(gen_list_ops.tensor_list_concat(
          input_handle=l, element_dtype=dtypes.float32, element_shape=[]))

  def testEmptyTensorListInvalidShape(self):
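    # element_shape must be a scalar or a vector; a rank-2 shape tensor is
    # rejected.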
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                r"Shape must be at most rank 1 but is rank 2"):
      t = gen_list_ops.EmptyTensorList(
          element_shape=array_ops.ones(dtype=dtypes.int32, shape=[1, 0]),
          max_num_elements=constant_op.constant(1),
          element_dtype=dtypes.int32)
      self.evaluate(t)

  def testEvenSplit(self):
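    # tensor_list_split slices the input along dimension 0 into pieces with
    # the given lengths.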

    def RunTest(input_tensor, lengths, expected_stacked_output):
      l = list_ops.tensor_list_split(
          input_tensor, element_shape=None, lengths=lengths)
      self.assertAllEqual(
          list_ops.tensor_list_stack(l, element_dtype=dtypes.float32),
          expected_stacked_output)

    RunTest([1., 2., 3.], [1, 1, 1], [[1.], [2.], [3.]])
    RunTest([1., 2., 3., 4.], [2, 2], [[1., 2.], [3., 4.]])
    RunTest([[1., 2.], [3., 4.]], [1, 1], [[[1., 2.]], [[3., 4.]]])

  def testUnevenSplit(self):
    l = list_ops.tensor_list_split([1., 2., 3., 4., 5.],
                                   element_shape=None,
                                   lengths=[3, 2])
    self.assertAllEqual(list_ops.tensor_list_length(l), 2)
    self.assertAllEqual(
        list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32),
        [1., 2., 3.])
    self.assertAllEqual(
        list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32),
        [4., 5.])

  @test_util.run_deprecated_v1
  def testSkipEagerSplitWithInvalidTensorShapeFails(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtype=dtypes.float32)
      l = list_ops.tensor_list_split(tensor, element_shape=None, lengths=[1])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"Tensor must be at least a vector, but saw shape: \[\]"):
        l.eval({tensor: 1})

  @test_util.run_deprecated_v1
  def testSkipEagerSplitWithInvalidLengthsShapeFails(self):
    with self.cached_session():
      lengths = array_ops.placeholder(dtype=dtypes.int64)
      l = list_ops.tensor_list_split([1., 2.],
                                     element_shape=None,
                                     lengths=lengths)
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"Expected lengths to be a vector, received shape: \[\]"):
        l.eval({lengths: 1})

  def testSplitWithInvalidLengthsFails(self):
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"Invalid value in lengths: -1"):
      l = list_ops.tensor_list_split([1., 2.],
                                     element_shape=None,
                                     lengths=[1, -1])
      self.evaluate(l)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Attempting to slice \[0, 3\] from tensor with length 2"):
      l = list_ops.tensor_list_split([1., 2.], element_shape=None, lengths=[3])
      self.evaluate(l)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Unused values in tensor. Length of tensor: 2 Values used: 1"):
      l = list_ops.tensor_list_split([1., 2.], element_shape=None, lengths=[1])
      self.evaluate(l)

  @test_util.run_deprecated_v1
  def testSkipEagerSplitWithScalarElementShapeFails(self):
    with self.assertRaisesRegex(ValueError,
                                r"Shapes must be equal rank, but are 1 and 0"):
      l = list_ops.tensor_list_split([1., 2.], element_shape=[], lengths=[1, 1])
    with self.cached_session():
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"TensorListSplit requires element_shape to be at least of rank 1, "
          r"but saw: \[\]"):
        element_shape = array_ops.placeholder(dtype=dtypes.int32)
        l = list_ops.tensor_list_split([1., 2.],
                                       element_shape=element_shape,
                                       lengths=[1, 1])
        l.eval({element_shape: []})

  def testEagerOnlySplitWithScalarElementShapeFails(self):
    if context.executing_eagerly():
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"TensorListSplit requires element_shape to be at least of rank 1, "
          r"but saw: \[\]"):
        list_ops.tensor_list_split([1., 2.], element_shape=[], lengths=[1, 1])

  @test_util.run_deprecated_v1
  def testSkipEagerSplitWithIncompatibleTensorShapeAndElementShapeFails(self):
    with self.assertRaisesRegex(ValueError,
                                r"Shapes must be equal rank, but are 2 and 1"):
      l = list_ops.tensor_list_split([[1.], [2.]],
                                     element_shape=[1],
                                     lengths=[1, 1])

    with self.cached_session():
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"tensor shape \[2,1\] is not compatible with element_shape \[1\]"):
        element_shape = array_ops.placeholder(dtype=dtypes.int32)
        l = list_ops.tensor_list_split([[1.], [2.]],
                                       element_shape=element_shape,
                                       lengths=[1, 1])
        l.eval({element_shape: [1]})

  def testEagerOnlySplitWithIncompatibleTensorShapeAndElementShapeFails(self):
    if context.executing_eagerly():
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"tensor shape \[2,1\] is not compatible with element_shape \[1\]"):
        list_ops.tensor_list_split([[1.], [2.]],
                                   element_shape=[1],
                                   lengths=[1, 1])

  def testResizeGrow(self):
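    # Growing a list preserves the existing items (the new slots are left
    # uninitialized).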
    l = list_ops.tensor_list_from_tensor([1., 2.], element_shape=[])
    l = list_ops.tensor_list_resize(l, 4)
    self.assertEqual(self.evaluate(list_ops.tensor_list_length(l)), 4)
    self.assertEqual(
        self.evaluate(
            list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)),
        1.)
    self.assertEqual(
        self.evaluate(
            list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)),
        2.)

  def testResizeShrink(self):
    l = list_ops.tensor_list_from_tensor([1., 2., 3.], element_shape=[])
    l = list_ops.tensor_list_resize(l, 2)
    self.assertEqual(self.evaluate(list_ops.tensor_list_length(l)), 2)
    self.assertAllEqual(
        self.evaluate(
            list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)),
        [1., 2.])

  def testResizeWithInvalidSizeFails(self):
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "TensorListSlice expects size to be non-negative"):
      l = list_ops.tensor_list_from_tensor([1., 2., 3.], element_shape=[])
      l = list_ops.tensor_list_resize(l, -1)
      self.evaluate(l)

  @test_util.run_in_graph_and_eager_modes
  def testResizeWithNonScalarFails(self):
    l = list_ops.tensor_list_from_tensor([3, 4, 5], element_shape=[])
    size = np.zeros([0, 2, 3, 3])
    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                r"Shape must be rank 0 but is rank \d+|"
                                r"\w+ must be a scalar"):
      self.evaluate(gen_list_ops.TensorListResize(input_handle=l, size=size))

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testSkipEagerResizeGrad(self):
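    # set_item with resize_if_index_out_of_bounds=True grows the list, and
    # the gradient flows back only to the original three elements.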
    t = constant_op.constant([1., 2., 3.])
    l = list_ops.tensor_list_from_tensor(t, element_shape=[])
    l = list_ops.tensor_list_set_item(
        l, 3, 4., resize_if_index_out_of_bounds=True)
    t1 = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    grad = gradients_impl.gradients(t1, t)[0]
    self.assertAllEqual(self.evaluate(grad), [1., 1., 1.])

  def testHandleDataAcrossFunctionCall(self):
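    # Handle data (element dtype and full type information) should survive
    # crossing a tf.function boundary.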

    @def_function.function
    def func():
      t = constant_op.constant([1., 2., 3.])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
      self.assertTrue(handle_data.is_set)
      self.assertEqual(handle_data.shape_and_type[0].type.type_id,
                       full_type_pb2.TFT_ARRAY)
      return l

    tensor_list = func()
    handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list)
    self.assertTrue(handle_data.is_set)
    self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype)
    self.assertEqual(handle_data.shape_and_type[0].type.type_id,
                     full_type_pb2.TFT_ARRAY)
    element = list_ops.tensor_list_get_item(
        tensor_list, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(element.shape.as_list(), [])

  @test_util.run_gpu_only
  def testNestedListDeviceToDeviceCopy(self):
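    # Round-trips a nested (variant-of-variant) list between two GPUs and
    # checks that the payload survives.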
    if context.num_gpus() < 2:
      self.skipTest("Need at least 2 GPUs for this test, found %d" %
                    context.num_gpus())
    with ops.device("gpu:0"):
      t = constant_op.constant([1.0, 2.0, 3.0])
      inner_l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      outer_l = list_ops.empty_tensor_list(
          element_dtype=dtypes.variant, element_shape=[])
      outer_l = list_ops.tensor_list_push_back(outer_l, inner_l)

    # Stress test.
    for _ in range(1024):
      with ops.device("gpu:1"):
        outer_l = array_ops.identity(outer_l)
      with ops.device("gpu:0"):
        outer_l = array_ops.identity(outer_l)

    with ops.device("gpu:1"):
      _, inner_l = list_ops.tensor_list_pop_back(
          outer_l, element_dtype=dtypes.variant)
      t = list_ops.tensor_list_stack(inner_l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [1.0, 2.0, 3.0])

  def testTensorListStrings(self):
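    # map_fn stages its loop state in tensor lists, so string elements must
    # round-trip through a variant list correctly.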
    @def_function.function
    def f():
      return map_fn.map_fn(string_ops.string_upper,
                           constant_op.constant(["a", "b", "c"]))

    self.assertAllEqual(f(), [b"A", b"B", b"C"])

  def testTensorListStringsNoInline(self):
    # The generator function's output type is a variant with a host-only
    # underlying data type. "ColocationGraph::AddHostOnlyDataTypesConstraints"
    # needs "deep op inspection" to correctly place the while loop generated
    # from map_fn.
    self.skipTest("b/150742232")

    @def_function.function(experimental_attributes={"_noinline": True})
    def generator(c):
      return list_ops.tensor_list_from_tensor(c, element_shape=[])

    @def_function.function
    def f(c):
      l = generator(c)

      def upper(i):
        e = list_ops.tensor_list_get_item(l, i, element_dtype=dtypes.string)
        return string_ops.string_upper(e)

      return map_fn.map_fn(
          upper, constant_op.constant([0, 1, 2]), dtype=dtypes.string)

    c = constant_op.constant(["a", "b", "c"])
    self.assertAllEqual(f(c), [b"A", b"B", b"C"])

  def testPopBackGrad(self):
    # https://github.com/tensorflow/tensorflow/issues/37230
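    # Verifies second-order gradients through the generated while loop:
    # g(x) = x**3, so the second derivative at x == 1 is 6.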

    @def_function.function
    def g(x):
      x_prod = constant_op.constant([1.])
      for unused_i in math_ops.range(3):
        x_prod = x_prod * x
      return x_prod

    x = constant_op.constant(1.)
    with backprop.GradientTape() as t:
      t.watch(x)
      with backprop.GradientTape() as tt:
        tt.watch(x)
        loss = g(x)
      jac = tt.gradient(loss, x)
    hess = t.gradient(jac, x)
    self.assertAllEqual(hess, 6.)

  def testTensorListElementShapeShapeInference(self):
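    # The element-shape tensor has unknown rank when the list's element shape
    # is unknown; at runtime it evaluates to the scalar -1.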

    @def_function.function
    def f():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=None)
      l_element_shape = list_ops.tensor_list_element_shape(l, dtypes.int32)
      self.assertIsNone(l_element_shape.shape.rank)
      shape_l = list_ops.empty_tensor_list(
          element_dtype=dtypes.int32, element_shape=l_element_shape.shape)
      shape_l = list_ops.tensor_list_push_back(shape_l, l_element_shape)
      return list_ops.tensor_list_pop_back(shape_l, dtypes.int32)[1]

    self.assertAllEqual(f(), -1)

  def testElementShapeArgOfTensorListFromTensor(self):
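    # element_shape=[-1] declares the single element dimension unknown, so
    # elements with different lengths ([3] and [4]) can coexist in one list.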

    @def_function.function
    def f():
      t = array_ops.ones([3, 3])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[-1])
      l = list_ops.tensor_list_push_back(l, array_ops.ones([4]))
      read_val = list_ops.tensor_list_get_item(
          l, 3, element_dtype=dtypes.float32)
      self.assertAllEqual(read_val.shape.as_list(), [None])
      return read_val

    self.assertAllEqual(f(), [1.0, 1.0, 1.0, 1.0])


if __name__ == "__main__":
  test.main()