tensorflow/python/kernel_tests/data_structures/tensor_array_ops_test.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tensor_array_ops."""
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_grad
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
def _make_converter(tf_dtype):
  """Returns a function that turns Python/NumPy data into `tf_dtype` arrays.

  For string dtypes the data is cast to NumPy byte strings (`|S`). For
  complex dtypes every element `x` becomes `x - 1j * x`, so the imaginary
  component is non-zero.

  Args:
    tf_dtype: the TensorFlow dtype the produced arrays should match.

  Returns:
    A one-argument callable mapping array-like data to a NumPy array.
  """
  def _converter(x):
    if tf_dtype == dtypes.string:
      # In Python 3 np.str_ is unicode, but the tests always want bytes.
      return np.asarray(x).astype("|S")
    arr = np.asarray(x).astype(tf_dtype.as_numpy_dtype)
    if not tf_dtype.is_complex:
      return arr
    # Add a non-zero imaginary component to every element.
    arr -= 1j * arr
    return arr
  return _converter
def _make_ta(size, name, dtype=dtypes.float32, infer_shape=False):
  """Builds a `TensorArray` with `size` elements named `name` for the tests."""
  ta_kwargs = dict(
      dtype=dtype,
      tensor_array_name=name,
      size=size,
      infer_shape=infer_shape)
  return tensor_array_ops.TensorArray(**ta_kwargs)
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class TensorArrayTest(test.TestCase):
@classmethod
def setUpClass(cls):
  """Creates a local cluster (3 workers, no PS) shared by the class."""
  super().setUpClass()
  workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
  cls._workers = workers
@classmethod
def tearDownClass(cls):
  """Runs base teardown, then `Session.reset` on the first worker's target."""
  super().tearDownClass()
  first_worker_target = cls._workers[0].target
  session_lib.Session.reset(first_worker_target)
@test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteRead(self):
  """Writes three differently shaped elements and reads each one back."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3,
        infer_shape=False)
    # Each write returns a TensorArray that carries the updated flow.
    written = ta.write(0, [[4.0, 5.0]]).write(1, [[1.0]]).write(2, -3.0)
    first, second, third = self.evaluate(
        [written.read(index) for index in range(3)])
    self.assertAllEqual([[4.0, 5.0]], first)
    self.assertAllEqual([[1.0]], second)
    self.assertAllEqual(-3.0, third)
def _testTensorArrayWritePack(self, tf_dtype):
  """Writes three [1, 2] elements of `tf_dtype` and checks `stack()`."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=tf_dtype, tensor_array_name="foo", size=3)
    convert = _make_converter(tf_dtype)
    rows = [[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]
    for index, row in enumerate(rows):
      ta = ta.write(index, convert(row))
    stacked = self.evaluate(ta.stack())
    # Stacking adds a leading axis over the three written elements.
    self.assertAllEqual(convert(rows), stacked)
def _testTensorArrayWritePackMaybeLegacy(self):
  """Runs the write/stack check for each dtype in turn."""
  covered_dtypes = (dtypes.float32, dtypes.float64, dtypes.int32,
                    dtypes.int64, dtypes.complex64, dtypes.complex128,
                    dtypes.string)
  for tf_dtype in covered_dtypes:
    self._testTensorArrayWritePack(tf_dtype)
def testTensorArrayWritePack(self):
  """Checks `stack()` after `write()` for numeric, complex and string dtypes."""
  self._testTensorArrayWritePackMaybeLegacy()
def testEmptyTensorArrayPack(self):
  """Stacking three zero-sized (0, 1) elements yields shape [3, 0, 1]."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    empty_element = np.zeros((0, 1), dtype=np.float32)
    for index in range(3):
      ta = ta.write(index, empty_element)
    stacked = self.evaluate(ta.stack())
    self.assertAllEqual([3, 0, 1], stacked.shape)
def testTensorArrayWriteConcatInParallel(self):
  """Concats TensorArrays built inside `cond` branches of a `while_loop`.

  Iterations 0-2 take the first branch, which concats [1] with an empty
  vector. Iterations 3-5 take the second branch, which concats [8], an empty
  vector and [9]. Each iteration's result is written to an outer TensorArray,
  and that TensorArray is concatenated at the end.
  """
  with self.session():

    def _empty_int32_vector():
      return constant_op.constant([], shape=(0,), dtype=dtypes.int32)

    def _concat_one_and_empty():
      ta = tensor_array_ops.TensorArray(
          dtype=dtypes.int32, size=2, infer_shape=False)
      ta = ta.write(0, constant_op.constant([1]))
      ta = ta.write(1, _empty_int32_vector())
      return ta.concat()

    def _concat_eight_empty_nine():
      ta = tensor_array_ops.TensorArray(
          dtype=dtypes.int32, size=3, infer_shape=False)
      ta = ta.write(0, constant_op.constant([8]))
      ta = ta.write(1, _empty_int32_vector())
      ta = ta.write(2, constant_op.constant([9]))
      return ta.concat()

    def _body(step, accumulator):
      chunk = cond.cond(
          math_ops.less(step, 3), _concat_one_and_empty,
          _concat_eight_empty_nine)
      return (step + 1, accumulator.write(step, chunk))

    num_iterations = 6
    initial_accumulator = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=num_iterations, infer_shape=False)
    _, result_ta = while_loop.while_loop(
        lambda step, _: step < num_iterations, _body,
        (0, initial_accumulator))
    concatenated = self.evaluate(result_ta.concat())
    self.assertAllEqual([1, 1, 1, 8, 9, 8, 9, 8, 9], concatenated)
def _testTensorArrayWriteConcat(self, tf_dtype):
  """Writes blocks of 3, 2 and 1 rows of `tf_dtype` and checks `concat()`."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False)
    convert = _make_converter(tf_dtype)
    blocks = [
        [[4.0, 5.0], [104.0, 105.0], [204.0, 205.0]],
        [[6.0, 7.0], [106.0, 107.0]],
        [[8.0, 9.0]],
    ]
    for index, block in enumerate(blocks):
      ta = ta.write(index, convert(block))
    concatenated = self.evaluate(ta.concat())
    # concat() joins the elements along their first axis.
    expected_rows = [row for block in blocks for row in block]
    self.assertAllEqual(convert(expected_rows), concatenated)
@test_util.deprecated_graph_mode_only
def testTensorArrayWriteConcat(self):
  """Runs the write/concat check for numeric, complex and string dtypes."""
  covered_dtypes = (dtypes.float32, dtypes.float64, dtypes.int32,
                    dtypes.int64, dtypes.complex64, dtypes.complex128,
                    dtypes.string)
  for tf_dtype in covered_dtypes:
    self._testTensorArrayWriteConcat(tf_dtype)
def _testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
  """Unwritten slots of a TA with element_shape [1, 2] read back as zeros."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3,
        element_shape=tensor_shape.TensorShape([1, 2]))
    zero_element = [[0.0, 0.0]]
    # read() of a never-written slot.
    self.assertAllEqual(zero_element, self.evaluate(ta.read(0)))
    # stack() / concat() when only index 1 has been written.
    stacked = self.evaluate(ta.write(1, [[4.0, 5.0]]).stack())
    self.assertAllEqual([zero_element, [[4.0, 5.0]], zero_element], stacked)
    concatenated = self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())
    self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], concatenated)
@test_util.run_v1_only("b/122324791")
def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
  """Partially written TAs with a known element shape are zero-filled."""
  self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros()
def _testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
  """Unwritten slots read back as zeros shaped like the written element.

  Unlike `_testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros`, no
  `element_shape` is passed here. The expected [1, 2] zeros match the shape of
  the value written at index 1, which every read below goes through.
  """
  # Use an explicit cached session, as the sibling helper does, instead of
  # calling `self.evaluate` with no session context.
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3)
    self.assertAllEqual(
        [[0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).read(0)))
    self.assertAllEqual([[[0.0, 0.0]], [[4.0, 5.0]], [[0.0, 0.0]]],
                        self.evaluate(ta.write(1, [[4.0, 5.0]]).stack()))
    self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
                        self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
@test_util.run_v1_only("b/122324791")
def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
  """Partially written TAs without an element shape are zero-filled."""
  self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros()
@test_util.run_v1_only("Uses placeholders")
def testSkipEagerTensorArrayReadUninitializedInferShapeFillsZeros(self):
  """An unwritten slot reads as zeros shaped like a value fed at run time."""
  with self.cached_session() as sess:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3)
    fed = array_ops.placeholder(dtypes.float32)
    unwritten_read = ta.write(1, fed).read(0)
    fed_value = [[4.0, 5.0]]
    self.assertAllEqual([[0.0, 0.0]],
                        sess.run(unwritten_read, {fed: fed_value}))
def _testTensorArrayUnpackRead(self, tf_dtype):
  """Unstacks a vector, a matrix and an empty matrix and reads every slice."""
  with self.cached_session():
    convert = _make_converter(tf_dtype)
    ta = _make_ta(3, "foo", dtype=tf_dtype)

    def _unstack_and_read_all(value):
      """Unstacks `value` into a copy of `ta` and evaluates all 3 reads."""
      unstacked = ta.unstack(convert(value))
      return self.evaluate([unstacked.read(index) for index in range(3)])

    # A vector unstacks into scalars.
    vector = [1.0, 2.0, 3.0]
    for want, got in zip(vector, _unstack_and_read_all(vector)):
      self.assertAllEqual(convert(want), got)
    # A matrix unstacks into vectors.
    matrix = [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
    for want, got in zip(matrix, _unstack_and_read_all(matrix)):
      self.assertAllEqual(convert(want), got)
    # Unstacking an empty matrix should not cause an error.
    empty_rows = [[], [], []]
    for want, got in zip(empty_rows, _unstack_and_read_all(empty_rows)):
      self.assertAllEqual(convert(want), got)
def _testTensorArrayUnpackReadMaybeLegacy(self):
  """Runs the unstack/read check for each dtype in turn."""
  covered_dtypes = (dtypes.float32, dtypes.float64, dtypes.int32,
                    dtypes.int64, dtypes.complex64, dtypes.complex128,
                    dtypes.string, dtypes.bfloat16)
  for tf_dtype in covered_dtypes:
    self._testTensorArrayUnpackRead(tf_dtype)
def testTensorArrayUnpackRead(self):
  """Checks `unstack()` followed by `read()` across many dtypes."""
  self._testTensorArrayUnpackReadMaybeLegacy()
def _testTensorArraySplitRead(self, tf_dtype):
  """Splits an empty vector, a vector and a matrix, then reads every piece."""
  with self.cached_session():
    convert = _make_converter(tf_dtype)
    ta = _make_ta(3, "foo", dtype=tf_dtype)

    def _split_and_read_all(value, lengths):
      """Splits `value` by `lengths` into a copy of `ta`; evaluates 3 reads."""
      split_ta = ta.split(
          convert(value), lengths=constant_op.constant(lengths))
      return self.evaluate([split_ta.read(index) for index in range(3)])

    # Split an empty vector into three empty pieces.
    for got in _split_and_read_all([], [0, 0, 0]):
      self.assertAllEqual(convert([]), got)
    # Split a vector (the middle piece is empty).
    first, middle, last = _split_and_read_all([1.0, 2.0, 3.0], [2, 0, 1])
    self.assertAllEqual(convert([1.0, 2.0]), first)
    self.assertAllEqual(convert([]), middle)
    self.assertAllEqual(convert([3.0]), last)
    # Split a matrix along its rows (the middle piece has zero rows).
    first, middle, last = _split_and_read_all(
        [[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]], [2, 0, 1])
    self.assertAllEqual(convert([[1.0, 101.0], [2.0, 201.0]]), first)
    self.assertAllEqual(convert([]).reshape(0, 2), middle)
    self.assertAllEqual(convert([[3.0, 301.0]]), last)
@test_util.deprecated_graph_mode_only
def testTensorArraySplitRead(self):
  """Runs the split/read check for each dtype in turn."""
  covered_dtypes = (dtypes.float32, dtypes.float64, dtypes.int32,
                    dtypes.int64, dtypes.complex64, dtypes.complex128,
                    dtypes.string, dtypes.bfloat16)
  for tf_dtype in covered_dtypes:
    self._testTensorArraySplitRead(tf_dtype)
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradArrayWriteRead(self):
  """A TA and its gradient TA can be written and read independently."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3,
        infer_shape=False)
    # The gradient TA is taken from the original (unwritten) TensorArray.
    g_ta = ta.grad("grad")
    values = [[[4.0, 5.0]], [[1.0]], -3.0]
    grad_values = [[[5.0, 6.0]], [[2.0]], -2.0]
    for index, value in enumerate(values):
      ta = ta.write(index, value)
    for index, value in enumerate(grad_values):
      g_ta = g_ta.write(index, value)
    reads = [ta.read(index) for index in range(3)]
    grad_reads = [g_ta.read(index) for index in range(3)]
    fetched = session.run(reads + grad_reads)
    for want, got in zip(values + grad_values, fetched):
      self.assertAllEqual(want, got)
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradGrad(self):
  """The second derivative of read(write(x))**2 w.r.t. x is 2.

  Skipped unless control flow v2 is enabled, because legacy TensorArrays do
  not support double derivatives.
  """
  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    self.skipTest("Legacy TensorArray does not support double derivatives.")
  # `test_session()` is deprecated; `cached_session()` is its replacement.
  with self.cached_session() as session:
    x = constant_op.constant(4.0)
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=1,
        infer_shape=False)
    w0 = ta.write(0, x)
    r0 = w0.read(0)
    y = r0 * r0
    # dy/dx = 2 * x, so the second derivative is the constant 2.
    g1 = gradients_impl.gradients(ys=[y], xs=[x])
    g2 = gradients_impl.gradients(ys=[g1], xs=[x])
    self.assertAllEqual([2.0], session.run(g2))
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradArrayDynamicWriteRead(self):
  """The gradient of a dynamically sized TA takes the forward TA's size."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True,
        infer_shape=False)
    values = [[[4.0, 5.0]], [[1.0]], -3.0]
    grad_values = [[[5.0, 6.0]], [[2.0]], -2.0]
    for index, value in enumerate(values):
      ta = ta.write(index, value)
    # Get the gradient array only after the writes so the shape is known.
    g_ta = ta.grad("grad")
    size = ta.size()
    grad_size = g_ta.size()
    for index, value in enumerate(grad_values):
      g_ta = g_ta.write(index, value)
    fetches = ([ta.read(index) for index in range(3)] +
               [g_ta.read(index) for index in range(3)] +
               [size, grad_size])
    fetched = session.run(fetches)
    for want, got in zip(values + grad_values + [3, 3], fetched):
      self.assertAllEqual(want, got)
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorGradAccessTwiceReceiveSameObject(self):
  """Two `grad("grad")` calls on one TA address the same gradient storage."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    first_grad_ta = ta.grad("grad")
    second_grad_ta = ta.grad("grad")
    write_flow = first_grad_ta.write(0, [[4.0, 5.0]]).flow
    with ops.control_dependencies([write_flow]):
      # Write with one gradient handle, read with another copy of it.
      read_via_second = second_grad_ta.read(0)
    first_op_val, second_op_val, read_val = session.run(
        [first_grad_ta.handle.op, second_grad_ta.handle.op, read_via_second])
    # NOTE(review): `Session.run` fetches `None` for an `Operation`, so this
    # comparison is between two `None`s and always passes. The real check is
    # that a value written through one handle is readable through the other.
    self.assertAllEqual(first_op_val, second_op_val)
    self.assertAllEqual([[4.0, 5.0]], read_val)
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
  """Writing a wrong dtype, a negative index or an index >= size fails."""
  with self.session():
    ta = _make_ta(3, "foo", dtype=dtypes.float32)
    # TODO(b/129870929): Remove the last 2 checks (runtime checks) after
    # switching back from preferred_dtype= to dtype= in convert_to_tensor.
    # Also restrict the error check to only TypeError.
    error_msg_regex = (
        "("
        "Expected float32, got 'wrong_type_scalar' of type 'str' instead."
        "|"
        "Cannot convert provided value to EagerTensor. Provided value: "
        "wrong_type_scalar Requested dtype: float"
        "|"
        "TensorArray dtype is float.* but Op is trying to write dtype string"
        "|"
        "Invalid data types; op elements string but list elements float"
        ")")
    with self.assertRaisesRegex((TypeError, errors.InvalidArgumentError),
                                error_msg_regex):
      self.evaluate(ta.write(0, "wrong_type_scalar").flow)

    # Graph mode with control flow v2 reports index errors with different
    # messages than the other configurations.
    v2_graph = (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
                not context.executing_eagerly())

    # Test writing to a negative index, which is not allowed.
    if v2_graph:
      error_msg = "Trying to modify element -1 in a list with 3 elements."
    else:
      error_msg = "index -1"
    with self.assertRaisesOpError(error_msg):
      self.evaluate(ta.write(-1, 3.0).flow)

    # Test writing to too large an index.
    if v2_graph:
      error_msg = "Trying to modify element 3 in a list with 3 elements"
    else:
      error_msg = ("Tried to write to index 3 but array is not "
                   "resizeable and size is: 3")
    with self.assertRaisesOpError(error_msg):
      self.evaluate(ta.write(3, 3.0).flow)
def testTensorArrayReadWrongIndexOrDataTypeFails(self):
  """Reading a wrong dtype, a negative index or an index >= size fails."""
  with self.session():
    ta = _make_ta(3, "foo", dtype=dtypes.float32)
    written = ta.write(0, [[4.0, 5.0]])
    executing_eagerly = context.executing_eagerly()
    # A wrong-dtype read can only be expressed with the raw op while building
    # a graph without control flow v2.
    if not executing_eagerly and not control_flow_util.ENABLE_CONTROL_FLOW_V2:
      bad_read = gen_data_flow_ops.tensor_array_read_v3(
          handle=written.handle,
          index=0,
          dtype=dtypes.float64,
          flow_in=written.flow)
      with self.assertRaisesOpError(
          "TensorArray dtype is float but Op requested dtype double."):
        self.evaluate(bad_read)

    v2_graph = (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
                not executing_eagerly)
    # Reading from a negative index, or from an index >= size, is rejected.
    out_of_range_cases = (
        (-1, "Trying to access element -1 in a list with 3 elements.",
         "index -1"),
        (3, "Trying to access element 3 in a list with 3 elements.",
         "Tried to read from index 3 but array size is: 3"),
    )
    for index, v2_graph_msg, default_msg in out_of_range_cases:
      with self.assertRaisesOpError(v2_graph_msg if v2_graph else default_msg):
        self.evaluate(ta.read(index))
@test_util.disable_control_flow_v2("v2 allows multiple writes.")
@test_util.run_v1_only("v2 allows multiple writes.")
def testSkipEagerTensorArrayWriteMultipleFails(self):
  """Writing the same index of a legacy TensorArray twice fails."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    expected_error = ("Could not write to TensorArray index 2 because "
                      "it has already been written to.")
    with self.assertRaisesOpError(expected_error):
      double_write = ta.write(2, 3.0).write(2, 3.0)
      self.evaluate(double_write.flow)
def testTensorArrayConcatIncompatibleShapesFails(self):
  """concat() rejects scalar elements and elements of mismatched rank."""
  with self.session():

    def _fresh_ta():
      return tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          tensor_array_name="foo",
          size=3,
          infer_shape=False)

    scalars_then_vector = (
        _fresh_ta().write(0, 3.0).write(1, 4.0).write(2, [3.0]))
    with self.assertRaisesOpError(
        "Concat saw a scalar shape at index 0 but requires at least vectors"):
      self.evaluate(scalars_then_vector.concat())

    vectors_then_matrix = (
        _fresh_ta().write(0, [3.0]).write(1, [4.0]).write(2, [[3.0]]))
    # The exact error messages differ between eager execution and graph
    # construction as the former bubbles up the error from array_op.concat.
    if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
        not context.executing_eagerly()):
      error_msg = "Incompatible ranks"
    else:
      error_msg = "shape"
    with self.assertRaisesRegex(errors.InvalidArgumentError, error_msg):
      self.evaluate(vectors_then_matrix.concat())
def testTensorArraySplitIncompatibleShapesFails(self):
  """split() rejects bad lengths, scalar values and TA-size mismatches."""
  with self.session():
    in_eager_mode = context.executing_eagerly()
    v2_graph = control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode
    ta = _make_ta(3, "foo")

    # `lengths` must be a vector; in graph mode the scalar is fed at run time.
    with self.assertRaisesOpError(
        r"Expected lengths to be a vector, received shape: \[\]"):
      if in_eager_mode:
        self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
      else:
        lengths = array_ops.placeholder(dtypes.int64)
        ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})

    # The lengths must account for every row of the value.
    if v2_graph:
      error_msg = "Unused values in tensor. Length of tensor: 3 Values used: 1"
    else:
      error_msg = (r"Expected sum of lengths to be equal to values.shape\[0\], "
                   r"but sum of lengths is 1 and value's shape is: \[3\]")
    with self.assertRaisesOpError(error_msg):
      self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)

    # A scalar value cannot be split.
    ta = _make_ta(1, "baz")
    if v2_graph:
      with self.assertRaisesRegex(
          ValueError, "Shape must be at least rank 1 but is rank 0"):
        self.evaluate(ta.split(1.0, [1]).flow)
    else:
      with self.assertRaisesOpError(
          r"Expected value to be at least a vector, but received shape: \[\]"
      ):
        self.evaluate(ta.split(1.0, [1]).flow)

    # The size check against a non-resizeable TA is not run in v2 graph mode.
    if not v2_graph:
      ta = _make_ta(2, "buz")
      with self.assertRaisesOpError(
          r"TensorArray's size is not equal to the size of lengths "
          r"\(2 vs. 1\), and the TensorArray is not marked as "
          r"dynamically resizeable"):
        self.evaluate(ta.split([1.0], [1]).flow)
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
  """Repeated writes to a gradient TA are summed; on a plain TA they fail."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
    ta_grad = ta.grad("grad")

    def as_np(x):
      return np.asarray(x, dtype=dtype.as_numpy_dtype)

    double_written = ta.write(2, as_np(3.0)).write(2, as_np(4.0))
    grad_written = ta_grad
    for addend in (3.0, 4.0, 5.0):
      grad_written = grad_written.write(2, as_np(addend))
    # Assert that aggregation works correctly: 3 + 4 + 5 == 12.
    self.assertAllEqual(as_np(12.00), grad_written.read(2))

    # Assert that if multiple_writes_aggregate is not enabled,
    # multiple writes raise an exception.
    with self.assertRaisesOpError(
        r"TensorArray foo_.*: Could not write to TensorArray index 2 because "
        r"it has already been written to."):
      self.evaluate(double_written.flow)

    # Aggregating values of differing shapes causes an exception.
    mismatched = ta_grad.write(1, as_np(1.0)).write(1, as_np([1.0]))
    with self.assertRaisesOpError(
        r"Could not aggregate to TensorArray index 1 because the "
        r"existing shape is \[\] but the new input shape is \[1\]"):
      self.evaluate(mismatched.flow)
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
def testSkipEagerTensorArrayWriteGradientAddMultipleAdds(self):
  """Runs the gradient-aggregation check for integer, float and complex."""
  numeric_dtypes = (dtypes.int32, dtypes.int64, dtypes.float32,
                    dtypes.float64, dtypes.complex64, dtypes.complex128)
  for dtype in numeric_dtypes:
    self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
@test_util.disable_control_flow_v2("Low level legacy TA op test.")
@test_util.run_v1_only("Low level legacy TA op test.")
def testSkipEagerTensorArrayGradWithShapeKnownElementShape(self):
  """A grad TA with shape_to_prepend=[4, 5] expects [4, 5, 2, 3] elements."""
  with self.session() as sess:
    ta = tensor_array_ops.TensorArray(
        size=3,
        dtype=dtypes.float32,
        element_shape=tensor_shape.TensorShape([2, 3]))
    grad_handle, grad_flow = data_flow_ops.tensor_array_grad_with_shape(
        handle=ta.handle,
        flow_in=ta.flow,
        shape_to_prepend=tensor_shape.TensorShape([4, 5]),
        source="source")
    grad_ta = tensor_array_ops.TensorArray(
        dtypes.float32, handle=grad_handle, flow=grad_flow)
    value = array_ops.placeholder(dtypes.float32)
    grad_ta = grad_ta.write(0, value)
    read_value = grad_ta.read(0)
    # Static shape inference sees the [2, 3] element shape with two unknown
    # leading dimensions.
    self.assertAllEqual([None, None, 2, 3], read_value.shape.as_list())
    # Writing with the wrong shape should not work.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Could not write to TensorArray"):
      wrong_shape_value = np.random.random([2, 3])
      sess.run(read_value, feed_dict={value: wrong_shape_value})
    # Writing with the correct shape should work.
    right_shape_value = np.random.random([4, 5, 2, 3])
    self.assertAllClose(
        right_shape_value,
        sess.run(read_value, feed_dict={value: right_shape_value}))
@test_util.disable_control_flow_v2("Low level legacy TA op test.")
@test_util.run_v1_only("Low level legacy TA op test.")
def testSkipEagerTensorArrayGradWithShapeUnknownElementShape(self):
  """With no element shape, the grad TA's read has unknown static rank."""
  with self.session() as sess:
    ta = tensor_array_ops.TensorArray(
        size=3, dtype=dtypes.float32,
        element_shape=None)  # Note that element_shape is unknown
    grad_handle, grad_flow = data_flow_ops.tensor_array_grad_with_shape(
        handle=ta.handle,
        flow_in=ta.flow,
        shape_to_prepend=tensor_shape.TensorShape([4, 5]),
        source="source")
    grad_ta = tensor_array_ops.TensorArray(
        dtypes.float32, handle=grad_handle, flow=grad_flow)
    value = array_ops.placeholder(dtypes.float32)
    grad_ta = grad_ta.write(0, value)
    read_value = grad_ta.read(0)
    # Shape inference cannot determine the rank.
    self.assertIsNone(read_value.shape.ndims)
    # Writing a value of some shape should read back unchanged.
    fed_value = np.random.random([4, 5, 7])
    fetched = sess.run(read_value, feed_dict={value: fed_value})
    self.assertAllClose(fed_value, fetched)
def testMultiTensorArray(self):
  """Two independently named TensorArrays can be used in one computation."""
  with self.session():

    def _write_then_read(name, value):
      single = tensor_array_ops.TensorArray(
          size=1, dtype=dtypes.float32, tensor_array_name=name)
      return single.write(0, value).read(0)

    total = _write_then_read("foo", 4.0) + _write_then_read("bar", 5.0)
    self.assertAllClose(9.0, self.evaluate(total))
def _testTensorArrayGradientWriteReadType(self, dtype):
  """Gradients flow from reads back to the written values for `dtype`."""
  with self.cached_session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.as_dtype(dtype),
        tensor_array_name="foo",
        size=3,
        infer_shape=False)

    def c(x):
      return np.array(x, dtype=dtype)

    def _run_grads(ys, xs, grad_ys):
      return session.run(
          gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys))

    value_0 = constant_op.constant(c([[4.0, 5.0]]))
    value_1 = constant_op.constant(c(3.0))
    written = ta.write(0, value_0).write(1, value_1)
    r0 = written.read(0)
    r1 = written.read(1)
    r0_2 = written.read(0)

    # Gradients of individual components.
    grad_just_r0 = _run_grads([r0], [value_0], [c([[2.0, 3.0]])])
    self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0[0])

    # Two reads of index 0 contribute a summed gradient.
    grad_r0_r0_2 = _run_grads([r0, r0_2], [value_0],
                              [c([[2.0, 3.0]]), c([[1.0, -1.0]])])
    self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2[0])

    grad_just_r1 = _run_grads([r1], [value_1], [c(-2.0)])
    self.assertAllEqual(c(-2.0), grad_just_r1[0])

    # Combined gradients.
    combined = _run_grads(
        [r0, r0_2, r1], [value_0, value_1],
        [c([[2.0, 3.0]]), c([[1.0, -1.0]]), c(-2.0)])
    self.assertEqual(len(combined), 2)
    self.assertAllEqual(c([[3.0, 2.0]]), combined[0])
    self.assertAllEqual(c(-2.0), combined[1])
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientWriteRead(self):
  """Runs the write/read gradient check for real and complex float types."""
  float_types = (np.float32, np.float64, np.complex64, np.complex128)
  for dtype in float_types:
    self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):
  """Gradients from stack(), read(0) and concat() accumulate per element."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=2,
        clear_after_read=False)
    value_0 = constant_op.constant([-1.0, 1.0])
    value_1 = constant_op.constant([-10.0, 10.0])
    written = ta.write(0, value_0).write(1, value_1)
    stacked = written.stack()
    first = written.read(0)
    concatenated = written.concat()
    stack_grad = [[2.0, 3.0], [4.0, 5.0]]
    read_grad = [-0.5, 1.5]
    concat_grad = [20.0, 30.0, 40.0, 50.0]
    # Test gradient accumulation between read(0), stack(), and concat().
    with ops.control_dependencies([stacked, first, concatenated]):
      grads = gradients_impl.gradients(
          ys=[stacked, first, concatenated],
          xs=[value_0, value_1],
          grad_ys=[stack_grad, read_grad, concat_grad])
    value_0_grad, value_1_grad = self.evaluate(grads)
    self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], value_0_grad)
    self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], value_1_grad)
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientWritePackConcatAndRead(self):
  """Checks gradient accumulation across stack, read and concat."""
  self._testTensorArrayGradientWritePackConcatAndRead()
@test_util.disable_control_flow_v2("v2 does not support clear_after_read.")
@test_util.run_v1_only("v2 does not support clear_after_read.")
def testTensorArrayReadTwice(self):
  """A second read fails unless the TA sets clear_after_read=False."""
  with self.session():
    value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])

    # Without clear_after_read=False, reading index 0 again is an error.
    ta_readonce = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=2)
    unstacked_once = ta_readonce.unstack(value)
    first_read = unstacked_once.read(0)
    with self.assertRaisesOpError(
        r"Could not read index 0 twice because it was cleared after a "
        r"previous read \(perhaps try setting clear_after_read = false\?\)"):
      with ops.control_dependencies([first_read]):
        self.evaluate(unstacked_once.read(0))

    # With clear_after_read=False the same element can be read again.
    ta_readtwice = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=2,
        clear_after_read=False)
    unstacked_twice = ta_readtwice.unstack(value)
    first_read = unstacked_twice.read(0)
    with ops.control_dependencies([first_read]):
      second_read = unstacked_twice.read(0)
    self.assertAllEqual([1.0, -1.0], self.evaluate(second_read))
def _testTensorArrayGradientUnpackRead(self):
  """Gradients through unstack(), with index 0 read twice and summed."""
  with self.cached_session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=2,
        clear_after_read=False)
    value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
    unstacked = ta.unstack(value)
    reads = [unstacked.read(0), unstacked.read(0), unstacked.read(1)]
    # Combined gradients; the two read(0) gradients are aggregated.
    grad = gradients_impl.gradients(
        ys=reads,
        xs=[value],
        grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]])
    grad_vals = session.run(grad)
    self.assertEqual(len(grad_vals), 1)
    self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0])
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientUnpackRead(self):
  """Checks gradients through unstack() with repeated reads."""
  self._testTensorArrayGradientUnpackRead()
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientSplitConcat(self):
  """concat(split(value)) passes the upstream gradient through unchanged."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=2,
        infer_shape=False)
    value = constant_op.constant(
        [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
    round_trip = ta.split(value, [2, 1]).concat()
    upstream = [[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]]
    grad = gradients_impl.gradients(
        ys=[round_trip], xs=[value], grad_ys=[upstream])
    grad_vals = session.run(grad)
    self.assertEqual(len(grad_vals), 1)
    self.assertAllEqual(upstream, grad_vals[0])
def _testTensorArrayGradientDynamicUnpackRead(self):
  """Gradients through unstack() into a dynamically sized TensorArray.

  Each index is read exactly once, so the gradient w.r.t. `value` is the
  per-row upstream gradients stacked back together.
  """
  with self.cached_session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True)
    value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
    w = ta.unstack(value)
    r0 = w.read(0)
    r1 = w.read(1)
    # Test combined gradients of both reads. Each index is read only once,
    # so no per-index aggregation takes place here.
    grad = gradients_impl.gradients(
        ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
    grad_vals = session.run(grad)
    self.assertEqual(len(grad_vals), 1)
    self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradientDynamicUnpackRead(self):
  """Checks gradients through unstack() into a dynamic-size TensorArray."""
  self._testTensorArrayGradientDynamicUnpackRead()
def testCloseTensorArray(self):
  """Closing a freshly created TensorArray runs without error."""
  with self.session():
    fresh_ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    close_op = fresh_ta.close()
    self.evaluate(close_op)
def testSizeTensorArray(self):
  """size() reports the size given at construction."""
  with self.session():
    fresh_ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    self.assertAllEqual(3, self.evaluate(fresh_ta.size()))
def testWriteCloseTensorArray(self):
  """Closing a TensorArray after writes of different shapes succeeds."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3,
        infer_shape=False)
    written = ta.write(0, [[4.0, 5.0]]).write(1, [3.0])
    # Expected to run without problems.
    self.evaluate(written.close())
def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
  """Checks values and gradients of a while_loop that writes to a TA.

  Each of 3 iterations slices row `time` of `v0`, writes
  `v0[time] + var + state` into a TensorArray, then adds that row to
  `state`. The stacked output and its gradients w.r.t. `v0`, `state0` and
  `var` are compared against closed-form expectations.

  Args:
    dynamic_size: if True the TensorArray is created with size 0 and
      `dynamic_size=True`; otherwise with size 3.
    dtype: TensorFlow dtype used for all inputs.
  """
  np_dtype = dtype.as_numpy_dtype
  with self.cached_session():

    def func(v0, state0, var):
      output_ta = tensor_array_ops.TensorArray(
          dtype=dtype,
          tensor_array_name="foo",
          size=0 if dynamic_size else 3,
          dynamic_size=dynamic_size)
      start_time = array_ops.identity(0)

      def body(time, ta_t, state):
        row = array_ops.slice(
            v0, begin=array_ops_stack.stack([time, 0]), size=[1, -1])
        row = array_ops.squeeze(row)
        out = row + var + state
        state += row
        ta_t = ta_t.write(time, out)
        return (time + 1, ta_t, state)

      (_, final_ta, _) = while_loop.while_loop(
          cond=lambda time, unused_ta, unused_state: time < 3,
          body=body,
          loop_vars=(start_time, output_ta, state0),
          shape_invariants=(start_time.get_shape(),
                            tensor_shape.unknown_shape(),
                            tensor_shape.unknown_shape()),
          parallel_iterations=3)
      return final_ta.stack()

    v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
    state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
    init_val = np.arange(100, 105, dtype=np_dtype)
    var = variable_scope.get_variable(
        "var",
        shape=init_val.shape,
        dtype=np_dtype,
        initializer=init_ops.constant_initializer(init_val))

    vout = func(v0, state0, var)
    grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
    if context.executing_eagerly():
      grad_fn = backprop.gradients_function(func)
      v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
    else:
      v0_grad, state0_grad, var_grad = (
          gradients_impl.gradients([vout], [source], [grad_val])[0]
          for source in (v0, state0, var))
    self.evaluate(variables.global_variables_initializer())

    state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
        self.evaluate([state0, var, v0, vout, v0_grad, var_grad, state0_grad]))
    just_v0_grad_t = self.evaluate(v0_grad)

    # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
    # vout = [ v0[0] + var + state[0] |
    #          v0[1] + var + state[1] |
    #          v0[2] + var + state[2] ]
    #      = [ v0[0] + var + state0 |
    #          v0[1] + var + state0 + v0[0] |
    #          v0[2] + var + state0 + v0[0] + v0[1] ]
    #
    # d(vout[0])/d(v0) = [1 | 0 | 0 ]
    # d(vout[1])/d(v0) = [1 | 1 | 0 ]
    # d(vout[2])/d(v0) = [1 | 1 | 1 ]
    # d(vout)/d(var) = [1 | 1 | 1]
    # d(vout)/d(state0) = [ 1 | 1 | 1 ]
    state_per_time = np.array(
        [state0_t, state0_t + v0_t[0, :], state0_t + v0_t[0, :] + v0_t[1, :]])

    # Forward pass.
    self.assertAllClose(v0_t + var_t + state_per_time, vout_t)

    # Backward pass: row i of d/d(v0) sums the upstream rows i..2.
    expected_v0_grad_t = np.array([
        grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
        grad_val[1, :] + grad_val[2, :], grad_val[2, :]
    ])
    self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
    self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
    self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
    self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
def testWhileLoopWritePackGradients(self):
  """Gradients through a fixed-size TensorArray written in a while_loop."""
  self._testWhileLoopWritePackGradients(
      dynamic_size=False, dtype=dtypes.float32)
  # TODO(ebrevdo): also run with dynamic_size=False, dtype=dtypes.int64 once
  # While supports non-float32 gradients.
@test_util.run_deprecated_v1
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
  """Same as the fixed-size variant, but with a growable TensorArray."""
  self._testWhileLoopWritePackGradients(
      dynamic_size=True, dtype=dtypes.float32)
def testGradSerialTwoLoops(self):
  """Gradient flows through a TensorArray threaded through two loops.

  The first loop writes x, 2x, 4x, 8x, 16x into slots 0-4 (each slot is
  twice the previous one). The second loop fills the remaining slots with a
  constant zero. The gradient of the stacked result w.r.t. x is therefore
  1 + 2 + 4 + 8 + 16 = 31.
  """
  with self.session():

    def loop(x):
      num_steps = 100
      acc = tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          size=num_steps,
          clear_after_read=False,
          element_shape=tensor_shape.TensorShape([]))
      i = constant_op.constant(0, name="i")

      def doubling_body(step, ta):
        value = cond.cond(
            math_ops.equal(step, 0), lambda: x,
            lambda: math_ops.multiply(ta.read(step - 1), 2.0))
        return step + 1, ta.write(step, value)

      i1, acc1 = while_loop.while_loop(
          lambda step, ta: step < 5, doubling_body, [i, acc])
      zero = constant_op.constant(0.0)

      def zero_fill_body(step, ta):
        return step + 1, ta.write(step, zero)

      _, acc2 = while_loop.while_loop(
          lambda step, ta: step < num_steps, zero_fill_body, [i1, acc1])
      return acc2.stack()

    x = constant_op.constant(2.0, name="x")
    if context.executing_eagerly():
      grad = backprop.gradients_function(loop)(x)[0]
    else:
      grad = gradients_impl.gradients(loop(x), [x])[0]
    self.assertAllClose(31.0, self.evaluate(grad))
def testShapeAfterWhileLoop(self):
  """The element shape learned inside a while_loop is visible afterwards."""
  size = 10
  initial = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)

  def body(i, ta):
    return i + 1, ta.write(i, [[0.]])

  _, result = while_loop.while_loop(
      lambda i, _: i < size, body, [0, initial], parallel_iterations=1)
  self.assertIsNotNone(result.element_shape.dims)
@test_util.deprecated_graph_mode_only
def testSkipEagerSumOfTwoReadVariablesWithoutRepeatGrad(self):
  """Gradients of read(0) + read(1), evaluated separately and together."""
  with self.session() as session:
    base = np.arange(3 * 5, dtype=np.float32).reshape(3, 5)
    a = array_ops.identity(base + 1)
    b = array_ops.identity(base + 1 + 3 * 5)
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
    ta = ta.write(0, a, name="write_a").write(1, b, name="write_b")
    # c = a + b
    c = ta.read(0, name="read_a_0") + ta.read(1, name="read_b_0")
    g0 = -(base + 1)
    # d(a + b)/da == d(a + b)/db == 1, so both gradients must equal g0.
    grad_a = gradients_impl.gradients([c], [a], [g0])[0]
    grad_b = gradients_impl.gradients([c], [b], [g0])[0]

    # Each gradient evaluated in its own run.
    for single_grad in (grad_a, grad_b):
      single_grad_t, = session.run([single_grad])
      self.assertAllEqual(single_grad_t, g0)

    # Both gradients evaluated in a single run.
    for joint_grad_t in session.run([grad_a, grad_b]):
      self.assertAllEqual(joint_grad_t, g0)
def _grad_source_for_name(self, name):
  """Returns _GetGradSource() for a constant op created with `name`."""
  probe = constant_op.constant(0, name=name)
  return tensor_array_grad._GetGradSource(probe)
@test_util.deprecated_graph_mode_only
def testSkipEagerGetGradSource_Invalid(self):
  """Names without a `gradients` scope component raise ValueError."""
  for bad_name in ("", "foo", "foo/bar"):
    with self.assertRaises(ValueError):
      self._grad_source_for_name(bad_name)
@test_util.deprecated_graph_mode_only
def testSkipEagerGetGradSource_NoEnclosingScope(self):
  """Top-level `gradients*` scopes are found at any nesting depth."""
  cases = (
      ("gradients", "gradients:0"),
      ("gradients_0", "gradients_0:0"),
      ("gradients/foo", "gradients"),
      ("gradients_0/foo", "gradients_0"),
      ("gradients/foo/bar", "gradients"),
      ("gradients_0/foo/bar", "gradients_0"),
  )
  for op_name, expected in cases:
    self.assertEqual(expected, self._grad_source_for_name(op_name))
@test_util.deprecated_graph_mode_only
def testSkipEagerGetGradSource_EnclosingScope(self):
  """`gradients*` scopes nested under outer scopes keep the outer prefix."""
  cases = (
      ("foo/gradients", "foo/gradients:0"),
      ("foo/gradients_0", "foo/gradients_0:0"),
      ("foo/gradients/bar", "foo/gradients"),
      ("foo/gradients_0/bar", "foo/gradients_0"),
      ("foo/bar/gradients/baz", "foo/bar/gradients"),
      ("foo/bar/gradients_0/baz", "foo/bar/gradients_0"),
  )
  for op_name, expected in cases:
    self.assertEqual(expected, self._grad_source_for_name(op_name))
@test_util.deprecated_graph_mode_only
def testSkipEagerGetGradSource_NestedUsesInnermost(self):
  """With nested `gradients` scopes, the innermost one is the source."""
  source = self._grad_source_for_name("foo/gradients/bar/gradients_0/baz")
  self.assertEqual("foo/gradients/bar/gradients_0", source)
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteShape(self):
  """Reads carry the static shape of their write; mismatched writes raise."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=3)
    c0 = constant_op.constant([4.0, 5.0])
    w0 = ta.write(0, c0)
    r0 = w0.read(0)
    self.assertAllEqual(c0.get_shape(), r0.get_shape())

    # A second, same-shaped write on the same array leaves both elements'
    # static shapes intact. (The original rebuilt an unused `ta` here; the
    # write below has always targeted w0.)
    c1 = constant_op.constant([6.0, 7.0])
    w1 = w0.write(1, c1)
    r0 = w1.read(0)
    r1 = w1.read(1)
    self.assertAllEqual(c0.get_shape(), r0.get_shape())
    self.assertAllEqual(c1.get_shape(), r1.get_shape())

    # Writing a [3]-shaped value into an array whose element shape was
    # inferred as [2] is rejected at graph-construction time.
    c2 = constant_op.constant([4.0, 5.0, 6.0])
    with self.assertRaises(ValueError):
      w0.write(0, c2)
@test_util.deprecated_graph_mode_only
def testSkipEagerPartlyUnknownShape(self):
  """The array's static element shape merges every partly known write.

  Each step writes a value whose shape is only partly known, then reads
  index 0. Because the reported static shape is array-wide, it reflects the
  merge of all writes so far, not just element 0's own write.
  """
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, tensor_array_name="foo", size=6)
    c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
    w0 = ta.write(0, c0)
    r0 = w0.read(0)
    self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list())
    # Writing an identically shaped value leaves the shape unchanged.
    c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
    w1 = w0.write(1, c1)
    r1 = w1.read(0)
    self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list())
    # Writing less specific shape (doesn't change type.)
    c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None])
    w2 = w1.write(2, c2)
    r2 = w2.read(0)
    self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list())
    # Writing more specific shape in one dimension and less specific in
    # another. The merged result keeps the known 2 and the known 3.
    c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None])
    w3 = w2.write(3, c3)
    r3 = w3.read(0)
    self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list())
    # Writing partly defined shape using TensorArray.scatter.
    # The leading 2 of c4 is the number of scattered elements; each element
    # is [None, 4, 2, 3], which refines the second dimension to 4.
    c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3])
    w4 = w3.scatter([4, 5], c4)
    r4 = w4.read(0)
    self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list())
    # Writing fully defined shape using TensorArray.split.
    # Splitting [10, 4, 2, 3] into lengths [5, 5] yields [5, 4, 2, 3] pieces.
    c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3])
    w5 = w4.split(c5, constant_op.constant([5, 5]))
    r5 = w5.read(0)
    self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
def _testUnpackShape(self):
  """Checks static shapes after unstack() and write() on a growable array."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True,
        infer_shape=True)
    value = constant_op.constant(
        [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
    # Unstacking a [3, 2] value yields elements of static shape [2].
    w0 = ta.unstack(value)
    r0 = w0.read(0)
    self.assertAllEqual((2,), r0.get_shape())
    # Index 3 is past the unstacked length; dynamic_size=True permits this.
    c1 = constant_op.constant([4.0, 5.0])
    w1 = w0.write(3, c1)
    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
      # TensorArray v2 does not support clear_after_read.
      # The control dependency forces r0 to run first; reading index 0 again
      # must then fail because that read cleared the element (see message).
      with self.assertRaisesOpError(
          r"Could not read index 0 twice because it was cleared after a "
          r"previous read \(perhaps try setting clear_after_read = false\?\)"
      ):
        with ops.control_dependencies([r0]):
          self.evaluate(w1.read(0))
    # Index 1 came from unstack() and shares c1's [2] shape.
    r1 = w1.read(1)
    self.assertAllEqual(c1.get_shape(), r1.shape)
    # A [3]-shaped write conflicts with the inferred [2] element shape.
    c2 = constant_op.constant([4.0, 5.0, 6.0])
    with self.assertRaises(ValueError):
      w1.write(4, c2)
def testUnpackShape(self):
  """Public entry point for the unstack/write static-shape checks."""
  self._testUnpackShape()
@test_util.deprecated_graph_mode_only
def testSplitShape(self):
  """Checks the static element shapes split() can (and cannot) infer."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True,
        infer_shape=True)
    value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]])
    # Equal lengths [1, 1, 1]: every element has static shape [1, 2].
    w0 = ta.split(value, [1, 1, 1])
    r0 = w0.read(0)
    self.assertAllEqual((1, 2), r0.get_shape())
    ta1 = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo1",
        size=0,
        dynamic_size=True,
        infer_shape=True)
    # Unequal lengths [1, 2]: elements differ in their first dimension.
    w0 = ta1.split(value, [1, 2])
    r0 = w0.read(0)
    if context.executing_eagerly():
      # NOTE(review): the deprecated_graph_mode_only decorator suggests this
      # eager branch is unreachable here -- confirm before relying on it.
      self.assertEqual((1, 2), r0.get_shape())
      self.assertEqual((2, 2), w0.read(1).get_shape())
    else:
      # In graph mode no single static element shape exists, so it stays
      # unknown (and v1's element_shape attr is left unknown as well).
      self.assertEqual(r0.get_shape().ndims, None)
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertEqual(
            tensor_shape.TensorShape(
                ta1.handle.op.get_attr("element_shape")).ndims, None)
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteUnknownShape(self):
  """Writing a value of unknown shape leaves reads with unknown shape."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=3,
        infer_shape=True)
    unknown = array_ops.placeholder(dtypes.float32)
    element = ta.write(0, unknown).read(0)
    self.assertAllEqual(element.get_shape(), tensor_shape.unknown_shape())
def _testGradientWhenNotAllComponentsRead(self):
  """An unstacked component that is never read gets zero gradient."""
  with self.cached_session() as session:
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
    x = constant_op.constant([2.0, 3.0])
    first = ta.unstack(x).read(0)
    # first == x[0], so (d first/dx0, d first/dx1) == (1, 0).
    grads = gradients_impl.gradients(ys=[first], xs=[x], grad_ys=[1.0])
    self.assertAllEqual(session.run(grads)[0], [1.0, 0.0])
@test_util.deprecated_graph_mode_only
def testSkipEagerGradientWhenNotAllComponentsRead(self):
  """Graph-mode entry point for the partially-read gradient check."""
  self._testGradientWhenNotAllComponentsRead()
@test_util.deprecated_graph_mode_only
def testSkipEagerWriteButNotAllComponentsReadGrad(self):
  """A written element that is never read gets zero gradient."""
  with self.cached_session() as session:
    x0 = constant_op.constant(5.0)
    x1 = constant_op.constant(10.0)
    written = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=2).write(0, x0).write(1, x1)
    first = written.read(0)
    # first == x0, so (d first/dx0, d first/dx1) == (1, 0).
    grads = gradients_impl.gradients(ys=[first], xs=[x0, x1], grad_ys=[1.0])
    self.assertAllEqual(session.run(grads), [1.0, 0.0])
def _testTensorArrayUnpackDynamic(self):
  """unstack() followed by a write past the initial size of a growable array."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=3, dynamic_size=True)
    x = constant_op.constant([1.0, 2.0, 3.0])
    # Index 3 is beyond the initial size 3; dynamic_size lets the array grow.
    stacked = ta.unstack(x).write(3, 4.0).stack()
    self.assertAllEqual(np.array([1.0, 2.0, 3.0, 4.0]), self.evaluate(stacked))
    # Each element of x appears exactly once in the output.
    grad = gradients_impl.gradients(ys=[stacked], xs=[x])
    self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayUnpackDynamic(self):
  """V1-only entry point for the dynamic unstack check."""
  self._testTensorArrayUnpackDynamic()
@test_util.run_deprecated_v1
def testSkipEagerTensorArraySplitDynamic(self):
  """split() followed by a write past the initial size of a growable array."""
  with self.session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=3, dynamic_size=True)
    x = constant_op.constant([1.0, 2.0, 3.0])
    # Three length-1 pieces fill indices 0-2; writing index 3 grows the array.
    concatenated = ta.split(x, [1, 1, 1]).write(3, [4.0]).concat()
    self.assertAllEqual(
        np.array([1.0, 2.0, 3.0, 4.0]), self.evaluate(concatenated))
    grad = gradients_impl.gradients(ys=[concatenated], xs=[x])
    self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
def testStackShape(self):
  """stack() of a fixed-size array has static shape [size] + element shape."""

  @def_function.function
  def ta_stack():
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
    x = constant_op.constant([1.0, 2.0, 3.0])
    stacked = ta.write(0, x).stack()
    self.assertEqual(stacked.shape.as_list(), [3, 3])
    return stacked

  ta_stack()
def testReadShape(self):
  """read() inside a tf.function has the static shape of the written value."""

  @def_function.function
  def ta_read():
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
    x = constant_op.constant([1.0, 2.0, 3.0])
    element = ta.write(0, x).read(0)
    self.assertEqual(element.shape.as_list(), [3])
    return element

  ta_read()
def testGatherShape(self):
  """gather() propagates the static shape of `indices` when it is known."""

  def make_ta_gather(expected_first_dim):
    """Builds a gather function that checks the traced result's shape."""

    def ta_gather(indices):
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
      x = constant_op.constant([1.0, 2.0, 3.0])
      ta = ta.write(0, x)
      t = ta.gather(indices)
      self.assertEqual(t.shape.as_list(), [expected_first_dim, 3])
      return t

    return ta_gather

  # Tracing with the concrete argument [0] gives `indices` static shape [1].
  def_function.function(make_ta_gather(1))([0])
  # Here we force the shape of `indices` to be [None] during ta_gather's
  # compilation, so the leading result dimension is unknown.
  ta_gather_with_unknown_indices_shape = def_function.function(
      make_ta_gather(None),
      input_signature=[
          tensor_spec.TensorSpec(dtype=dtypes.int32, shape=[None])
      ])
  ta_gather_with_unknown_indices_shape([0])
def _testTensorArrayEvalEmpty(self):
  """Stacking an empty array with no element shape fails when evaluated."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=False)
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      expected_msg = ("Tried to stack elements of an empty list with "
                      "non-fully-defined element_shape")
    else:
      expected_msg = (
          "TensorArray has size zero, but element shape <unknown> is not "
          "fully defined. Currently only static shapes are supported when "
          "packing zero-size TensorArrays.")
    with self.assertRaisesOpError(expected_msg):
      ta.stack().eval()
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayEvalEmpty(self):
  """V1-only entry point for the empty-stack error check."""
  self._testTensorArrayEvalEmpty()
# This test is ill-defined in eager mode: unpacking an empty tensor gives an
# empty list, and eager mode has no equivalent of "mark_used".
def _testTensorArrayEvalEmptyWithDefault(self):
  """A size-0 array stacks/concats using the shape recorded by unstack()."""
  with self.cached_session():
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True)
    self.assertEqual(0, ta.size().eval())
    # Don't actually perform the pack. This stores the static shape.
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      # v2: keep the array returned by unstack() so stack/concat use it.
      ta = ta.unstack(array_ops.zeros([0, 3, 5]))
    else:
      # v1: the unstack() result is discarded; presumably the element shape
      # [3, 5] is still recorded on `ta`, and mark_used() suppresses the
      # unused-TensorArray check -- TODO confirm against tensor_array_ops.
      ta.unstack(array_ops.zeros([0, 3, 5])).mark_used()
    packed = ta.stack()
    concatenated = ta.concat()
    self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape)
    # Concatenating zero tensors along their first dimension gives a
    # first dimension of zero
    self.assertAllEqual([0, 5], self.evaluate(concatenated).shape)
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
  """V1-only entry point for the empty-with-default-shape check."""
  self._testTensorArrayEvalEmptyWithDefault()
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterReadAndGradients(self):
  """Scattered rows read back correctly and receive their upstream grads."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True)
    indices = constant_op.constant([1, 8])
    value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
    scattered = ta.scatter(indices, value)
    reads = [scattered.read(1), scattered.read(8)]
    # Gradients from read(1) and read(8) are combined into d/d(value).
    grad = gradients_impl.gradients(
        ys=reads, xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
    read_vals, grad_vals = session.run([reads, grad])
    self.assertEqual(len(read_vals), 2)
    self.assertEqual(len(grad_vals), 1)
    self.assertAllEqual([1.0, -1.0], read_vals[0])
    self.assertAllEqual([10.0, -10.0], read_vals[1])
    self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterPartialReadAndGradients(self):
  """Scattered rows that are never read receive zero gradient."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True)
    indices = constant_op.constant([1, 8])
    value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
    first_read = ta.scatter(indices, value).read(1)
    # Only index 1 (value row 0) is read, so value row 1 gets zero gradient.
    grad = gradients_impl.gradients(
        ys=[first_read], xs=[value], grad_ys=[[2.0, 3.0]])[0]
    read_val, grad_val = session.run([first_read, grad])
    self.assertAllEqual([1.0, -1.0], read_val)
    self.assertAllEqual([[2.0, 3.0], [0.0, 0.0]], grad_val)
def testScatterIntoExistingList(self):
  """Successive scatter() calls overwrite only the targeted indices."""
  ta = tensor_array_ops.TensorArray(
      dtype=dtypes.float32, tensor_array_name="foo", size=5)
  steps = (
      ([3, 4], array_ops.ones([2]), [0., 0., 0., 1., 1.]),
      ([1], array_ops.ones([1]), [0., 1., 0., 1., 1.]),
      ([0, 2], [5., 6.], [5., 1., 6., 1., 1.]),
  )
  for scatter_indices, scatter_value, expected in steps:
    ta = ta.scatter(indices=scatter_indices, value=scatter_value)
    self.assertAllEqual(ta.stack(), expected)
@test_util.run_v1_only("b/118890905")
def testTensorArrayWriteGatherAndGradients(self):
  """gather() returns selected rows; only those rows receive gradient."""
  with self.session() as session:
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name="foo",
        size=0,
        dynamic_size=True)

    def func(values):
      indices = constant_op.constant([1, 8])
      return ta.unstack(values).gather(indices)

    values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
    g = func(values)
    grad_ys = [[[2.0, 3.0], [4.0, 5.0]]]
    # Gradients flow back through gather() into the unstacked rows.
    if context.executing_eagerly():
      g_vals = [g]
      grad_vals = backprop.gradients_function(func)(
          values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32))
    else:
      grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
      g_vals, grad_vals = session.run([[g], grad])
    # Only rows 1 and 8 were gathered; the other 8 rows get zero gradient.
    expected_grad = np.zeros((10, 2))
    expected_grad[[1, 8]] = [[2.0, 3.0], [4.0, 5.0]]
    self.assertEqual(len(g_vals), 1)
    self.assertEqual(len(grad_vals), 1)
    self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
    self.assertAllEqual(expected_grad, grad_vals[0])
@test_util.disable_control_flow_v2("colocate_with not supported in v2.")
@test_util.run_v1_only("b/120545219")
def testSkipEagerTensorArrayGetsDeviceFromFirstWrite(self):
  """The TensorArray op is placed on the device of its first write.

  Only task:1 (the first-write device in every scenario) should run
  TensorArray ops. Other non-host devices must not run any.
  """
  with ops.device("/job:worker/task:0/cpu:0"):
    # this initial device will be ignored.
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
  with ops.device("/job:worker/task:1/cpu:0"):
    # the first write sets the op's device.
    ta = ta.write(0, 1.0)
  with ops.device("/job:worker/task:2/cpu:0"):
    # subsequent writes do not modify the op's device.
    ta = ta.write(1, 1.0)

  # The gradient TA will sit on the same device as the forward TA.
  ta_grad = ta.grad("grad")
  flows = [ta.flow, ta_grad.flow]

  # Similar tests for unpack and split
  with ops.device("/job:worker/task:0/cpu:0"):
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
  with ops.device("/job:worker/task:1/cpu:0"):
    ta = ta.unstack([1.0, 2.0])
  with ops.device("/job:worker/task:2/cpu:0"):
    ta = ta.write(2, 3.0)
  flows.append(ta.flow)

  with ops.device("/job:worker/task:0/cpu:0"):
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
  with ops.device("/job:worker/task:1/cpu:0"):
    ta = ta.split([1.0, 2.0], [1, 1])
  flows.append(ta.flow)

  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_metadata = config_pb2.RunMetadata()
  # Use a context manager so the remote session is closed instead of leaked.
  with session_lib.Session(self._workers[0].target) as session:
    session.run(flows, options=run_options, run_metadata=run_metadata)
  self.assertTrue(run_metadata.HasField("step_stats"))
  dev_stats = {d.device: d.node_stats
               for d in run_metadata.step_stats.dev_stats}
  for d in dev_stats:
    if "/task:1/" in d:
      self.assertTrue(
          [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
    elif "/host:CPU" not in d:
      self.assertFalse(
          [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
@test_util.disable_control_flow_v2("colocate_with not supported in v2.")
@test_util.run_v1_only("b/120545219")
def testSkipEagerTensorArrayGetsDeviceFromFirstWriteInWhileLoop(self):
  """First-write device placement also applies to writes in a while_loop.

  The array is created under task:0 but first written under task:1 inside
  the loop body, so the `TensorArray` node must run only on task:1.
  """
  with ops.device("/job:worker/task:0/cpu:0"):
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)

  def _body(i, ta_i):
    with ops.device("/job:worker/task:1/cpu:0"):
      return i + 1, ta_i.write(i, constant_op.constant(0.0))

  _, ta_out = while_loop.while_loop(
      lambda i, ta: i < 2, _body, loop_vars=[0, ta])

  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_metadata = config_pb2.RunMetadata()
  # Use a context manager so the remote session is closed instead of leaked.
  with session_lib.Session(self._workers[0].target) as session:
    session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
  self.assertTrue(run_metadata.HasField("step_stats"))
  dev_stats = {d.device: d.node_stats
               for d in run_metadata.step_stats.dev_stats}
  for d in dev_stats:
    if "/task:1/" in d:
      self.assertTrue(
          [s for s in dev_stats[d] if "TensorArray" == s.node_name])
    else:
      self.assertFalse(
          [s for s in dev_stats[d] if "TensorArray" == s.node_name])
@test_util.disable_control_flow_v2("colocate_with not supported in v2.")
@test_util.run_v1_only("b/120545219")
def testSkipEagerTensorArrayDisabledColocateWithFirstWriteCall(self):
  """With colocate_with_first_write_call=False the creation device is kept.

  Writes happen under task:1, but the `TensorArray` node must stay on the
  task:0 CPU where it was created.
  """
  with ops.device("/job:worker/task:0/cpu:0"):
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.float32, size=2, colocate_with_first_write_call=False)

  def _body(i, ta_i):
    with ops.device("/job:worker/task:1/cpu:0"):
      return i + 1, ta_i.write(i, constant_op.constant(0.0))

  _, ta_out = while_loop.while_loop(
      lambda i, ta: i < 2, _body, loop_vars=[0, ta])

  run_options = config_pb2.RunOptions(
      trace_level=config_pb2.RunOptions.FULL_TRACE)
  run_metadata = config_pb2.RunMetadata()
  # Use a context manager so the remote session is closed instead of leaked.
  with session_lib.Session(self._workers[0].target) as session:
    session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
  self.assertTrue(run_metadata.HasField("step_stats"))
  dev_stats = {d.device: list(d.node_stats)
               for d in run_metadata.step_stats.dev_stats}
  for d in dev_stats:
    if "/task:0/" in d and "CPU" in d:  # Skip any GPU node stats
      self.assertTrue(
          [s for s in dev_stats[d] if "TensorArray" == s.node_name])
    else:
      self.assertFalse(
          [s for s in dev_stats[d] if "TensorArray" == s.node_name])
def testTensorArrayIdentity(self):
  """identity() preserves dtype, contents and size, and honors control deps.

  Each identity() is created under a control dependency on a variable
  increment. The test checks the increments ran and that the new arrays
  still read back the written values and report their sizes.
  """
  with self.session():
    ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
                                       infer_shape=False)
    ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
                                       infer_shape=True)
    ta0 = ta0.write(0, 0.)
    ta1 = ta1.write(0, 1)
    v0 = variable_scope.get_variable(
        "v0", shape=(), initializer=init_ops.zeros_initializer())
    v1 = variable_scope.get_variable(
        "v1", shape=(), initializer=init_ops.zeros_initializer())
    # The increments must have run by the time anything derived from the
    # identity() arrays is evaluated.
    with ops.control_dependencies([v0.assign_add(1)]):
      ta0 = ta0.identity()
    with ops.control_dependencies([v1.assign_add(1)]):
      ta1 = ta1.identity()
    read0 = ta0.read(0)
    read1 = ta1.read(0)
    size0 = ta0.size()
    size1 = ta1.size()
    # Tests correct properties on new TensorArrays.
    self.assertEqual(dtypes.float32, ta0.dtype)
    self.assertEqual(dtypes.int32, ta1.dtype)
    if context.executing_eagerly():
      self.assertEqual(tensor_shape.TensorShape([]), read0.get_shape())
    else:
      # ta0 was built with infer_shape=False, so graph mode keeps its read
      # shape unknown.
      self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
    self.assertEqual(tensor_shape.TensorShape([]), read1.get_shape())
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0,
                                                        size1))
    # Tests that the control dependencies were added and executed.
    self.assertEqual(1, self.evaluate(v0))
    self.assertEqual(1, self.evaluate(v1))
    # Tests correct TensorArray.
    self.assertEqual(read0_v, 0)
    self.assertEqual(read1_v, 1)
    self.assertEqual(size0_v, 2)
    self.assertEqual(size1_v, 4)
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayGradYsInCorrectScope(self):
  """grad_ys created outside the tf.gradients name scope are still applied."""
  n_time = 1
  n_dim = 1
  x = constant_op.constant([[1.42]])
  dy = constant_op.constant([[2.42]])
  ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=n_time, element_shape=[n_dim])
  for step in range(n_time):
    ta = ta.write(index=step, value=x[step])
  y = ta.stack()
  # dy is outside of the gradients name scope; tf.gradients must
  # wrap it in the correct name scope.
  (dx,) = gradients_impl.gradients(ys=[y], xs=[x], grad_ys=[dy])
  with self.cached_session():
    vdx, vdy = self.evaluate([dx, dy])
  # y is a stacked copy of x, so dx must equal the supplied dy.
  self.assertAllClose(vdx, vdy)
@test_util.deprecated_graph_mode_only
def testSkipEagerTensorArrayInt64GPU(self):
  """Scatters and reads back int64 values with the session forced onto GPU."""
  if not test.is_gpu_available():
    # Report a skip; a bare `return` would be recorded as a passing test.
    self.skipTest("No GPU available")
  with self.session(force_gpu=True) as sess:
    value = array_ops.placeholder(dtypes.int64)
    ta = tensor_array_ops.TensorArray(dtype=dtypes.int64, size=2)
    ta = ta.scatter([0, 1], value)
    r0 = ta.read(0)
    r1 = ta.read(1)
    v0, v1 = sess.run([r0, r1], feed_dict={value: [-3, 100]})
    self.assertAllEqual(v0, -3)
    self.assertAllEqual(v1, 100)
@test_util.deprecated_graph_mode_only
def testTensorArrayScatterBfloat16GPU(self):
  """Scattering bfloat16 values works with the session forced onto GPU."""
  if not test.is_gpu_available():
    # Report a skip; a bare `return` would be recorded as a passing test.
    self.skipTest("No GPU available")
  # The session only needs to be the default for evaluation; no handle used.
  with self.session(force_gpu=True):
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.bfloat16, tensor_array_name="foo", size=5)
    ta = ta.scatter(
        indices=[3, 4], value=array_ops.ones([2], dtype=dtypes.bfloat16))
    self.assertAllEqual(ta.stack(), [0., 0., 0., 1., 1.])
def testInferShapeFalseValid(self):
  """With infer_shape=False, elements may differ in the unknown dimension."""
  ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=3, infer_shape=False, element_shape=[None, 10, 20])
  for index, rows in enumerate((50, 50, 1)):
    ta = ta.write(index, array_ops.ones([rows, 10, 20]))
  concatenated = ta.concat()
  self.assertAllEqual(concatenated, np.ones([101, 10, 20]))
def testInferShapeFalseInvalid(self):
  """Writes must still match the declared element_shape with infer_shape=False."""
  ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=2, infer_shape=False, element_shape=[None, 10, 20])
  ta = ta.write(0, array_ops.ones([50, 10, 20]))
  # [1, 20, 20] conflicts with the declared second dimension of 10.
  with self.assertRaises(ValueError):
    ta.write(1, array_ops.ones([1, 20, 20]))
def testInferShapeTrue(self):
  """With infer_shape=True the first write pins the full element shape."""
  ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=3, infer_shape=True, element_shape=[None, 10, 20])
  self.assertAllEqual((None, 10, 20), ta.element_shape.as_list())
  ta = ta.write(0, array_ops.ones([50, 10, 20]))
  # The unknown leading dimension is now fixed to 50.
  self.assertAllEqual((50, 10, 20), ta.element_shape.as_list())
  ta = ta.write(1, array_ops.ones([50, 10, 20]))
  # Inconsistent shapes: saw (1, 10, 20) but expected (50, 10, 20).
  with self.assertRaises(ValueError):
    ta.write(2, array_ops.ones([1, 10, 20]))
def testStackShapeOnEmpty(self):
  """Stacking an empty array with a known element shape gives [0] + shape."""
  empty = tensor_array_ops.TensorArray(
      dtypes.float32, size=0, element_shape=(5, 10), dynamic_size=True)
  stacked_shape = self.evaluate(empty.stack()).shape
  self.assertAllEqual([0, 5, 10], stacked_shape)
@test_util.run_deprecated_v1
def testSkipEagerStackOnPartiallyDefinedShape(self):
  """stack() keeps the known parts of a partially defined element shape."""
  empty = tensor_array_ops.TensorArray(
      dtypes.float32, size=0, element_shape=(5, None), dynamic_size=True)
  static_shape = empty.stack().shape.as_list()
  self.assertEqual([None, 5, None], static_shape)
def testStackShapeOnStaticSize(self):
  """A fixed size gives a known leading dimension even when mostly unwritten."""
  written = tensor_array_ops.TensorArray(dtypes.float32, size=42).write(0, [0])
  self.assertEqual([42, 1], written.stack().shape.as_list())
def testTensorArrayConcatFailsWhenMissingStepContainer(self):
  """Raw TensorArrayConcatV2 on handles no TensorArray op made fails cleanly."""

  @def_function.function
  def func():
    return data_flow_ops.TensorArrayConcatV2(
        handle=["a", "b"],
        flow_in=0.1,
        dtype=dtypes.int32,
        element_shape_except0=1,
    )

  with self.assertRaisesRegex(errors.NotFoundError,
                              "Container .* does not exist"):
    self.evaluate(func())
class TensorArrayBenchmark(test.Benchmark):
  """Benchmarks writing to a TensorArray inside a while_loop."""

  def _tensorArrayWriteInWhile(self):
    """Builds a loop writing 10000 scalar zeros; returns the stacked result."""
    size = 10000
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)
    (_, ta) = while_loop.while_loop(
        lambda i, _: i < size,
        lambda i, ta: (i + 1, ta.write(i, 0.)), [0, ta],
        parallel_iterations=1)
    return ta.stack()

  def _runOpBenchmark(self, op):
    """Benchmarks `op` in a fresh Session that is always closed afterwards."""
    with session_lib.Session() as sess:
      self.run_op_benchmark(sess, op)

  def _benchmarkWriteInWhile(self):
    ops.reset_default_graph()
    op = self._tensorArrayWriteInWhile()
    self._runOpBenchmark(op)

  def benchmarkWriteInWhile(self):
    self._benchmarkWriteInWhile()

  @test_util.enable_control_flow_v2
  def benchmarkWriteInWhileWithControlFlowV2(self):
    self._benchmarkWriteInWhile()

  def _benchmarkWriteInDatasetMap(self, num_parallel_calls=None):
    """Benchmarks the while_loop inside Dataset.map.

    Args:
      num_parallel_calls: Forwarded to `Dataset.map`; None (its default)
        means sequential mapping.
    """
    ds = dataset_ops.Dataset.from_tensors(array_ops.zeros([10])).repeat()
    ds = ds.map(lambda _: self._tensorArrayWriteInWhile(),
                num_parallel_calls=num_parallel_calls)
    op = ds.make_one_shot_iterator().get_next()
    self._runOpBenchmark(op)

  def benchmarkWriteInDatasetMapFn(self):
    self._benchmarkWriteInDatasetMap()

  def benchmarkWriteInDatasetParallelMapFn(self):
    self._benchmarkWriteInDatasetMap(num_parallel_calls=2)
if __name__ == "__main__":
  # Entry point when this file is executed directly.
  test.main()