tensorflow/python/ops/ragged/ragged_tensor_test.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for third_party.tensorflow.python.ops.ragged_tensor."""
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import full_type_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.platform import googletest
from tensorflow.python.util import nest
def int32array(values):
return np.array(values, dtype=np.int32)
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
#=============================================================================
# RaggedTensor class docstring examples
#=============================================================================
def testClassDocStringExamples(self):
# From section: "Component Tensors"
rt = RaggedTensor.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt
# From section: "Alternative Row-Partitioning Schemes"
values = [3, 1, 4, 1, 5, 9, 2, 6]
rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
rt3 = RaggedTensor.from_value_rowids(
values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
for rt in (rt1, rt2, rt3, rt4, rt5):
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt1, rt2, rt3, rt4, rt5
# From section: "Multiple Ragged Dimensions"
inner_rt = RaggedTensor.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
outer_rt = RaggedTensor.from_row_splits(
values=inner_rt, row_splits=[0, 3, 3, 5])
self.assertEqual(outer_rt.ragged_rank, 2)
self.assertAllEqual(outer_rt,
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del inner_rt, outer_rt
# From section: "Multiple Ragged Dimensions"
rt = RaggedTensor.from_nested_row_splits(
flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del rt
# From section: "Uniform Inner Dimensions"
rt = RaggedTensor.from_row_splits(
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
self.assertAllEqual(
rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
self.assertEqual(rt.shape.as_list(), [2, None, 3])
del rt
#=============================================================================
# RaggedTensorValue Constructor
#=============================================================================
def testRaggedTensorValueConstruction(self):
values = np.array(b'a b c d e f g'.split())
splits = np.array([0, 2, 5, 6, 6, 7], dtype=np.int64)
splits2 = np.array([0, 3, 5], dtype=np.int64)
# Test construction of a RaggedTensorValue with ragged_rank=1.
rt_value = ragged_tensor_value.RaggedTensorValue(values, splits)
self.assertEqual(rt_value.row_splits.dtype, np.int64)
self.assertEqual(rt_value.shape, (5, None))
self.assertLen(rt_value.nested_row_splits, 1)
self.assertAllEqual(splits, rt_value.row_splits)
self.assertAllEqual(values, rt_value.values)
self.assertAllEqual(splits, rt_value.nested_row_splits[0])
self.assertAllEqual(values, rt_value.flat_values)
# Test construction of a RaggedTensorValue with ragged_rank=2.
rt_value = ragged_tensor_value.RaggedTensorValue(
values=ragged_tensor_value.RaggedTensorValue(values, splits),
row_splits=splits2)
self.assertEqual(rt_value.row_splits.dtype, np.int64)
self.assertEqual(rt_value.shape, (2, None, None))
self.assertLen(rt_value.nested_row_splits, 2)
self.assertAllEqual(splits2, rt_value.row_splits)
self.assertAllEqual(splits, rt_value.values.row_splits)
self.assertAllEqual(splits2, rt_value.nested_row_splits[0])
self.assertAllEqual(splits, rt_value.nested_row_splits[1])
self.assertAllEqual(values, rt_value.values.values)
self.assertAllEqual(values, rt_value.flat_values)
#=============================================================================
# RaggedTensor Constructor (private)
#=============================================================================
def testRaggedTensorConstruction(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rp = RowPartition.from_row_splits(row_splits)
rt = RaggedTensor(values=values, row_partition=rp, internal=True)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testRaggedTensorConstructionErrors(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rp = RowPartition.from_row_splits(row_splits)
with self.assertRaisesRegex(ValueError,
'RaggedTensor constructor is private'):
RaggedTensor(values=values, row_partition=rp)
with self.assertRaisesRegex(
TypeError, r'type\(values\) must be one of: Tensor, RaggedTensor'):
RaggedTensor(values=range(7), row_partition=rp, internal=True)
with self.assertRaisesRegex(
TypeError, 'Argument `row_partition` must be a RowPartition'):
RaggedTensor(
values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True)
#=============================================================================
# RaggedTensor Factory Ops
#=============================================================================
def testFromValueRowIdsWithDerivedNRows(self):
# nrows is known at graph creation time.
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
# nrows is not known at graph creation time.
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)
rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
if context.executing_eagerly():
self.assertEqual(rt.shape.as_list(), [5, None])
else:
self.assertEqual(rt.shape.as_list(), [None, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(7, dtypes.int64)
rt = RaggedTensor.from_value_rowids(
values, value_rowids, nrows, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [7, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertAllEqual(
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(5, dtypes.int64)
rt = RaggedTensor.from_value_rowids(
values, value_rowids, nrows, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, nrows)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithEmptyValues(self):
rt = RaggedTensor.from_value_rowids([], [])
rt_nrows = rt.nrows()
self.assertEqual(rt.dtype, dtypes.float32)
self.assertEqual(rt.shape.as_list(), [0, None])
self.assertEqual(rt.ragged_rank, 1)
self.assertEqual(rt.values.shape.as_list(), [0])
self.assertEqual(rt.value_rowids().shape.as_list(), [0])
self.assertAllEqual(rt_nrows, 0)
self.assertAllEqual(rt, [])
def testFromRowSplits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_row_splits = rt.row_splits
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_row_splits, row_splits)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowSplitsWithDifferentSplitTypes(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
splits1 = [0, 2, 2, 5, 6, 7]
splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
rt1 = RaggedTensor.from_row_splits(values, splits1)
rt2 = RaggedTensor.from_row_splits(values, splits2)
rt3 = RaggedTensor.from_row_splits(values, splits3)
rt4 = RaggedTensor.from_row_splits(values, splits4)
rt5 = RaggedTensor.from_row_splits(values, splits5)
self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
def testFromRowSplitsWithEmptySplits(self):
err_msg = 'row_splits tensor may not be empty'
with self.assertRaisesRegex(ValueError, err_msg):
RaggedTensor.from_row_splits([], [])
def testFromRowStarts(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
rt = RaggedTensor.from_row_starts(values, row_starts, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_row_starts = rt.row_starts()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_starts, row_starts)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLimits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
rt = RaggedTensor.from_row_limits(values, row_limits, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_row_limits = rt.row_limits()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_limits, row_limits)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengths(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
rt = RaggedTensor.from_row_lengths(values, row_lengths, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
rt_values = rt.values
rt_row_lengths = rt.row_lengths()
rt_nrows = rt.nrows()
self.assertIs(rt_values, values)
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_lengths, row_lengths)
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengthsInt32(self):
rt = RaggedTensor.from_row_lengths([1, 2, 3, 4],
constant_op.constant([1, 0, 3],
dtype=dtypes.int32))
rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0])
self.assertAllEqual([2, 1, 0], rt2.row_lengths())
def testFromUniformRowLength(self):
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
a1 = RaggedTensor.from_uniform_row_length(values, 2)
a2 = RaggedTensor.from_uniform_row_length(values, 2, 8)
self.assertAllEqual(
a1,
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
self.assertAllEqual(a1, a2)
self.assertEqual(a1.shape.as_list(), [8, 2])
self.assertEqual(a2.shape.as_list(), [8, 2])
b1 = RaggedTensor.from_uniform_row_length(a1, 2)
b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
self.assertAllEqual(b1, [[[1, 2], [3, 4]], [[5, 6], [7, 8]],
[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
self.assertAllEqual(b1, b2)
self.assertEqual(b1.shape.as_list(), [4, 2, 2])
self.assertEqual(b2.shape.as_list(), [4, 2, 2])
c1 = RaggedTensor.from_uniform_row_length(b1, 2)
c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
self.assertAllEqual(c1, [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
[[[9, 10], [11, 12]], [[13, 14], [15, 16]]]])
self.assertAllEqual(c1, c2)
self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])
def testFromUniformRowLengthWithEmptyValues(self):
empty_values = []
a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
self.assertEqual(a.shape.as_list(), [10, 0])
b = RaggedTensor.from_uniform_row_length(a, 2)
self.assertEqual(b.shape.as_list(), [5, 2, 0])
# Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
c = RaggedTensor.from_uniform_row_length(empty_values, 0)
self.assertEqual(c.shape.as_list(), [0, 0])
d = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=0)
self.assertEqual(d.shape.as_list(), [0, 0])
def testFromUniformRowLengthWithPlaceholders(self):
ph_values = array_ops.placeholder_with_default([1, 2, 3, 4, 5, 6], [None])
ph_rowlen = array_ops.placeholder_with_default(3, None)
rt1 = RaggedTensor.from_uniform_row_length(ph_values, 3)
rt2 = RaggedTensor.from_uniform_row_length(ph_values, ph_rowlen)
rt3 = RaggedTensor.from_uniform_row_length([1, 2, 3, 4, 5, 6], ph_rowlen)
self.assertAllEqual(rt1, [[1, 2, 3], [4, 5, 6]])
self.assertAllEqual(rt2, [[1, 2, 3], [4, 5, 6]])
self.assertAllEqual(rt3, [[1, 2, 3], [4, 5, 6]])
if context.executing_eagerly():
self.assertEqual(rt1.shape.as_list(), [2, 3])
self.assertEqual(rt2.shape.as_list(), [2, 3])
self.assertEqual(rt3.shape.as_list(), [2, 3])
else:
self.assertEqual(rt1.shape.as_list(), [None, 3])
self.assertEqual(rt2.shape.as_list(), [None, None])
self.assertEqual(rt3.shape.as_list(), [None, None])
b = RaggedTensor.from_uniform_row_length(rt1, 2)
self.assertAllEqual(b, [[[1, 2, 3], [4, 5, 6]]])
# Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
ph_empty_values = array_ops.placeholder_with_default(
array_ops.zeros([0], dtypes.int64), [None])
ph_zero = array_ops.placeholder_with_default(0, [])
c = RaggedTensor.from_uniform_row_length(ph_empty_values, ph_zero)
if context.executing_eagerly():
self.assertEqual(c.shape.as_list(), [0, 0])
else:
self.assertEqual(c.shape.as_list(), [None, None])
def testFromNestedValueRowIdsWithDerivedNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_value_rowids = [
constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [4, None, None])
self.assertEqual(rt.ragged_rank, 2)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_values_values = rt_values.values
rt_values_value_rowids = rt_values.value_rowids()
self.assertIs(rt_values_values, values)
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowPartitions(self):
flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_row_splits = [[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]]
nested_row_partition = [
RowPartition.from_row_splits(constant_op.constant(x, dtypes.int64))
for x in nested_row_splits
]
rt = RaggedTensor._from_nested_row_partitions(
flat_values, nested_row_partition, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [4, None, None])
self.assertEqual(rt.ragged_rank, 2)
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_value_rowids = [
constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
nrows = [
constant_op.constant(6, dtypes.int64),
constant_op.constant(6, dtypes.int64)
]
rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids,
nrows)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [6, None, None])
self.assertEqual(rt.ragged_rank, 2)
rt_values = rt.values
rt_value_rowids = rt.value_rowids()
rt_nrows = rt.nrows()
rt_values_values = rt_values.values
rt_values_value_rowids = rt_values.value_rowids()
rt_values_nrows = rt_values.nrows()
self.assertIs(rt_values_values, values)
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertAllEqual(rt_nrows, nrows[0])
self.assertAllEqual(rt_values_nrows, nrows[1])
self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
[[b'f'], [b'g'], []], [], []])
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_value_rowids = [
constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
nrows = [constant_op.constant(6, dtypes.int64)]
with self.assertRaisesRegex(
ValueError, 'Argument `nested_nrows` must have the same length as '
'argument `nested_value_rowids`'):
RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows)
def testFromNestedValueRowIdsWithNonListInput(self):
with self.assertRaisesRegex(
TypeError, 'Argument `nested_value_rowids` must be a list of Tensors'):
RaggedTensor.from_nested_value_rowids(
[1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
with self.assertRaisesRegex(
TypeError, 'Argument `nested_nrows` must be a list of Tensors'):
RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]],
constant_op.constant([3, 3]))
def testFromNestedRowSplits(self):
flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_row_splits = [
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
]
rt = RaggedTensor.from_nested_row_splits(
flat_values, nested_row_splits, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [4, None, None])
self.assertEqual(rt.ragged_rank, 2)
rt_values = rt.values
rt_row_splits = rt.row_splits
rt_values_values = rt_values.values
rt_values_row_splits = rt_values.row_splits
self.assertIs(rt_values_values, flat_values)
self.assertIs(rt_row_splits, nested_row_splits[0])
self.assertIs(rt_values_row_splits, nested_row_splits[1])
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testWithRowSplits(self):
flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_row_splits = [
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
]
rt = RaggedTensor.from_nested_row_splits(
flat_values, nested_row_splits, validate=False)
rt = rt.with_row_splits_dtype(dtypes.int32)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [4, None, None])
self.assertEqual(rt.ragged_rank, 2)
rt_values = rt.values
rt_row_splits = rt.row_splits
rt_values_values = rt_values.values
rt_values_row_splits = rt_values.row_splits
self.assertAllEqual(rt_values_values, flat_values)
self.assertAllEqual(rt_row_splits, nested_row_splits[0])
self.assertAllEqual(rt_values_row_splits, nested_row_splits[1])
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowSplitsWithNonListInput(self):
with self.assertRaisesRegex(
TypeError, '`nested_row_splits` must be a list of Tensors'):
RaggedTensor.from_nested_row_splits(
[1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
def testFromValueRowIdsWithBadNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(5, dtypes.int64)
with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=array_ops.placeholder_with_default(value_rowids, None),
nrows=-2)
with self.assertRaisesRegex(
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
r'value_rowids\[-1\]=4'):
RaggedTensor.from_value_rowids(
values=values, value_rowids=value_rowids, nrows=2)
with self.assertRaisesRegex(
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
r'value_rowids\[-1\]=4'):
RaggedTensor.from_value_rowids(
values=values, value_rowids=value_rowids, nrows=4)
with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=array_ops.expand_dims(value_rowids, 1),
nrows=nrows)
with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=value_rowids,
nrows=array_ops.expand_dims(nrows, 0))
def testCondWithTensorsFromValueIds(self):
# b/141166460
rt = RaggedTensor.from_value_rowids([1, 2, 3], [0, 0, 2])
c = array_ops.placeholder_with_default(True, None)
result = cond.cond(c, lambda: rt, lambda: rt)
self.assertAllEqual(rt, result)
def testGraphMismatch(self):
if not context.executing_eagerly():
with ops.Graph().as_default():
values = constant_op.constant([1, 2, 3], dtypes.int64)
with ops.Graph().as_default():
splits = constant_op.constant([0, 2, 3], dtypes.int64)
with self.assertRaisesRegex(ValueError,
'.* must be from the same graph as .*'):
RaggedTensor.from_row_splits(values, splits)
@parameterized.named_parameters([
dict(
testcase_name='Rank0',
tensor='a'),
dict(
testcase_name='Rank1',
tensor=['a', 'b']),
])
def testFromTensorRankError(self, tensor):
with self.assertRaisesRegex(ValueError, 'must be greater than 1'):
RaggedTensor.from_tensor(tensor)
#=============================================================================
# Ragged Value & Row-Partitioning Tensor Accessors
#=============================================================================
def testRaggedTensorAccessors_2d(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
rt1 = RaggedTensor.from_row_splits(values, row_splits)
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
self.assertAllEqual(
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertEqual(rt.values.shape.dims[0].value, 7)
self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
self.assertAllEqual(rt.nrows(), 5)
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
self.assertAllEqual(rt.flat_values,
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertLen(rt.nested_row_splits, 1)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
def testRaggedTensorAccessors_3d_with_ragged_rank_1(self):
values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
row_lengths = constant_op.constant([2, 0, 3, 1, 1])
rt1 = RaggedTensor.from_row_splits(values, row_splits)
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
rt3 = RaggedTensor.from_row_lengths(values, row_lengths)
for rt in [rt1, rt2, rt3]:
self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
[[10, 11]], [[12, 13]]])
self.assertAllEqual(
rt.values,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertEqual(rt.values.shape.dims[0].value, 7)
self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
self.assertAllEqual(rt.nrows(), 5)
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
self.assertAllEqual(
rt.row_lengths(axis=2), [[2, 2], [], [2, 2, 2], [2], [2]])
self.assertAllEqual(
rt.flat_values,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertLen(rt.nested_row_splits, 1)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
self.assertLen(rt.nested_value_rowids(), 1)
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4])
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
nested_row_splits = [
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
]
nested_value_rowids = [
constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
rt1 = RaggedTensor.from_nested_row_splits(values, nested_row_splits)
rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
for rt in [rt1, rt2]:
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
self.assertAllEqual(
rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertEqual(rt.values.shape.dims[0].value, 5)
self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
self.assertAllEqual(rt.nrows(), 4)
self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5])
self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
self.assertAllEqual(rt.flat_values,
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertLen(rt.nested_row_splits, 2)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
self.assertLen(rt.nested_value_rowids(), 2)
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3])
self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4])
#=============================================================================
# RaggedTensor.shape
#=============================================================================
def testShape(self):
"""Tests for RaggedTensor.shape."""
rt1 = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
[0, 2, 5, 6, 6, 7])
self.assertEqual(rt1.shape.as_list(), [5, None])
rt2 = RaggedTensor.from_row_splits(
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]],
[0, 2, 5, 6, 6, 7])
self.assertEqual(rt2.shape.as_list(), [5, None, 2])
rt3 = RaggedTensor.from_row_splits(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], [0, 2, 2, 3])
self.assertEqual(rt3.shape.as_list(), [3, None, 2, 2])
rt4 = RaggedTensor.from_row_splits(rt3, [0, 1, 3, 3])
self.assertEqual(rt4.shape.as_list(), [3, None, None, 2, 2])
if not context.executing_eagerly():
rt5 = RaggedTensor.from_row_splits(
array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5])
self.assertIsNone(rt5.shape.ndims)
rt6 = RaggedTensor.from_row_splits(
[1, 2, 3], array_ops.placeholder(dtype=dtypes.int64))
self.assertEqual(rt6.shape.as_list(), [None, None])
def testGetShape(self):
rt = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
[0, 2, 5, 6, 6, 7])
self.assertEqual(rt.shape.as_list(), rt.get_shape().as_list())
#=============================================================================
# RaggedTensor.__str__
#=============================================================================
def testRaggedTensorStr(self):
values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
row_splits = [0, 2, 5, 6, 6, 7]
rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
splits_type = 'int64'
if context.executing_eagerly():
expected_repr = '<tf.RaggedTensor {}>'.format([[b'a', b'b'],
[b'c', b'd', b'e'], [b'f'],
[], [b'g']])
else:
expected_repr = (
'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
'shape=(7,), dtype=string), '
'row_splits=Tensor('
'"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",'
' shape=(6,), dtype={}))').format(splits_type)
self.assertEqual(repr(rt), expected_repr)
self.assertEqual(str(rt), expected_repr)
def testRaggedTensorValueStr(self):
values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
row_splits = [0, 2, 5, 6, 6, 7]
rt = ragged_tensor_value.RaggedTensorValue(
np.array(values), np.array(row_splits, dtype=np.int64))
expected_str = '<tf.RaggedTensorValue {}>'.format([[b'a', b'b'],
[b'c', b'd', b'e'],
[b'f'], [], [b'g']])
expected_repr = ("tf.RaggedTensorValue(values=array({}, dtype='|S1'), "
'row_splits=array({}))'.format(values, row_splits))
self.assertEqual(' '.join(str(rt).split()), expected_str)
self.assertEqual(' '.join(repr(rt).split()), expected_repr)
def testRaggedTensorStrWithZeroSizeInnerShape(self):
# Tests that b/226112826 is fixed.
if context.executing_eagerly():
rt = RaggedTensor.from_row_lengths(array_ops.zeros([9, 0]), [4, 3, 2])
expected_repr = (
'<tf.RaggedTensor [[[], [], [], []], [[], [], []], [[], []]]>')
self.assertEqual(' '.join(repr(rt).split()), expected_repr)
#=============================================================================
# RaggedTensor.with_values() and RaggedTensor.with_flat_values().
#=============================================================================
def testWithValues(self):
rt1 = ragged_factory_ops.constant([[1, 2], [3, 4, 5], [6], [], [7]])
rt2 = ragged_factory_ops.constant([[[1, 2], [3, 4, 5]], [[6]], [], [[],
[7]]])
rt1_plus_10 = rt1.with_values(rt1.values + 10)
rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]])
self.assertAllEqual(rt2_times_10,
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
self.assertAllEqual(rt1_expanded,
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
#=============================================================================
# Session.run
#=============================================================================
def testSessionRun(self):
if context.executing_eagerly():
return
rt1 = ragged_factory_ops.constant([[1, 2, 3], [4]])
rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]])
with self.test_session() as session:
result = session.run({'rt1': rt1, 'rt2': rt2})
self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
def testSessionRunFeed(self):
if context.executing_eagerly():
return
rt1 = RaggedTensor.from_row_splits(
array_ops.placeholder(dtypes.int32),
array_ops.placeholder(dtypes.int64))
rt2 = RaggedTensor.from_nested_row_splits(
array_ops.placeholder(dtypes.int32), [
array_ops.placeholder(dtypes.int64),
array_ops.placeholder(dtypes.int64)
])
rt1_feed_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]])
rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]])
with self.test_session() as session:
fetches = {'rt1': rt1, 'rt2': rt2}
feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val}
result = session.run(fetches, feed_dict=feeds)
self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
def testSessionPartialRunFeed(self):
if context.executing_eagerly():
return
# Placeholder inputs.
a = RaggedTensor.from_row_splits(
array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'),
array_ops.placeholder(dtypes.int64, name='a.row_splits'))
b = RaggedTensor.from_row_splits(
array_ops.placeholder(dtypes.int32, shape=[None], name='b.values'),
array_ops.placeholder(dtypes.int64, name='b.row_splits'))
c = array_ops.placeholder(dtypes.int32, shape=[], name='c')
# Feed values for placeholder inputs.
a_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]])
b_val = ragged_factory_ops.constant_value([[5, 4, 3], [2]])
c_val = 3
# Compute some values.
r1 = ragged_math_ops.reduce_sum(a * b, axis=1)
r2 = ragged_math_ops.reduce_sum(a + c, axis=1)
with self.test_session() as session:
handle = session.partial_run_setup([r1, r2], [a, b, c])
res1 = session.partial_run(handle, r1, feed_dict={a: a_val, b: b_val})
self.assertAllEqual(res1, [22, 8])
res2 = session.partial_run(handle, r2, feed_dict={c: c_val})
self.assertAllEqual(res2, [15, 7])
# Test case for GitHub issue 24679.
def testEagerForLoop(self):
if not context.executing_eagerly():
return
values = [[1., 2.], [3., 4., 5.], [6.]]
r = ragged_factory_ops.constant(values)
i = 0
for elem in r:
self.assertAllEqual(elem, values[i])
i += 1
def testConsumers(self):
if context.executing_eagerly():
return
a = RaggedTensor.from_row_splits(
array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'),
array_ops.placeholder(dtypes.int64, name='a.row_splits'),
validate=False)
ragged_math_ops.reduce_sum(a)
self.assertLen(a.consumers(), 1)
@parameterized.parameters([
{
'descr': 'from_value_rowids',
'factory': RaggedTensor.from_value_rowids,
'test': RaggedTensor.value_rowids,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'value_rowids': [0, 0, 1, 1, 2, 2],
},
'tensor_field': 'value_rowids',
'value_rowids': [0, 1, 2],
'nrows': 10
},
{
'descr': 'from_row_splits',
'factory': RaggedTensor.from_row_splits,
# row_splits is a property, not a function.
'test': (lambda rt: rt.row_splits),
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_splits': [0, 2, 4, 6],
},
'tensor_field': 'row_splits',
'row_splits': [0, 1, 2, 3]
},
{
'descr': 'from_row_lengths',
'factory': RaggedTensor.from_row_lengths,
'test': RaggedTensor.row_lengths,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_lengths': [2, 2, 2],
},
'tensor_field': 'row_lengths',
'row_lengths': [1, 1, 1],
},
# from_row_starts
{
'descr': 'from_row_starts',
'factory': RaggedTensor.from_row_starts,
'test': RaggedTensor.row_starts,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_starts': [0, 2, 4]
},
'tensor_field': 'row_starts',
'row_starts': [0, 1, 2]
},
# from_row_limits
{
'descr': 'from_row_limits',
'factory': RaggedTensor.from_row_limits,
'test': RaggedTensor.row_limits,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_limits': [2, 4, 6]
},
'tensor_field': 'row_limits',
'row_limits': [3]
},
# from_uniform_row_length
{
'descr': 'from_uniform_row_length',
'factory': RaggedTensor.from_uniform_row_length,
# One cannot extract uniform_row_length or nvals, so we return
# nvals//nrows = uniform_row_length, where nvals = 3
'test': (lambda rt: 3 // (rt.shape[0])),
'values': {
'values': [1, 2, 3, 4, 5, 6],
'uniform_row_length': 2
},
'tensor_field': 'uniform_row_length',
'uniform_row_length': 3
},
])
def testFactoryTypePreference(self, descr, test, factory, values,
tensor_field, **kwargs):
# When input tensors have shape information, some of these errors will be
# detected statically.
def op_cast(k, v):
if k == tensor_field:
return constant_op.constant(v, dtype=dtypes.int32)
else:
return v
value_copy = {k: op_cast(k, v) for k, v in values.items()}
rt = factory(**value_copy)
kw_copy = {k: v for k, v in kwargs.items()}
kw_copy['values'] = rt
rt2 = factory(**kw_copy)
self.assertAllEqual(kwargs[tensor_field], test(rt2))
@parameterized.parameters([
# from_value_rowids
{
'descr': 'bad rank for value_rowids',
'factory': RaggedTensor.from_value_rowids,
'values': [[1, 2], [3, 4]],
'value_rowids': [[1, 2], [3, 4]],
'nrows': 10
},
{
'descr': 'bad rank for nrows',
'factory': RaggedTensor.from_value_rowids,
'values': [1, 2, 3, 4],
'value_rowids': [1, 2, 3, 4],
'nrows': [10]
},
{
'descr': 'len(values) != len(value_rowids)',
'factory': RaggedTensor.from_value_rowids,
'values': [1, 2, 3, 4],
'value_rowids': [1, 2, 3, 4, 5],
'nrows': 10
},
{
'descr': 'negative value_rowid',
'factory': RaggedTensor.from_value_rowids,
'values': [1, 2, 3, 4],
'value_rowids': [-5, 2, 3, 4],
'nrows': 10
},
{
'descr': 'non-monotonic-increasing value_rowid',
'factory': RaggedTensor.from_value_rowids,
'values': [1, 2, 3, 4],
'value_rowids': [4, 3, 2, 1],
'nrows': 10
},
{
'descr': 'value_rowid > nrows',
'factory': RaggedTensor.from_value_rowids,
'values': [1, 2, 3, 4],
'value_rowids': [1, 2, 3, 4],
'nrows': 2
},
{
'descr': 'bad rank for values',
'factory': RaggedTensor.from_value_rowids,
'values': 10,
'value_rowids': [1, 2, 3, 4],
'nrows': 10
},
# from_row_splits
{
'descr': 'bad rank for row_splits',
'factory': RaggedTensor.from_row_splits,
'values': [[1, 2], [3, 4]],
'row_splits': [[1, 2], [3, 4]]
},
{
'descr': 'row_splits[0] != 0',
'factory': RaggedTensor.from_row_splits,
'values': [1, 2, 3, 4],
'row_splits': [2, 3, 4]
},
{
'descr': 'non-monotonic-increasing row_splits',
'factory': RaggedTensor.from_row_splits,
'values': [1, 2, 3, 4],
'row_splits': [0, 3, 2, 4]
},
{
'descr': 'row_splits[0] != nvals',
'factory': RaggedTensor.from_row_splits,
'values': [1, 2, 3, 4],
'row_splits': [0, 2, 3, 5]
},
{
'descr': 'bad rank for values',
'factory': RaggedTensor.from_row_splits,
'values': 10,
'row_splits': [0, 1]
},
# from_row_lengths
{
'descr': 'bad rank for row_lengths',
'factory': RaggedTensor.from_row_lengths,
'values': [1, 2, 3, 4],
'row_lengths': [[1, 2], [1, 0]]
},
{
'descr': 'negatve row_lengths',
'factory': RaggedTensor.from_row_lengths,
'values': [1, 2, 3, 4],
'row_lengths': [3, -1, 2]
},
{
'descr': 'sum(row_lengths) != nvals',
'factory': RaggedTensor.from_row_lengths,
'values': [1, 2, 3, 4],
'row_lengths': [2, 4, 2, 8]
},
{
'descr': 'bad rank for values',
'factory': RaggedTensor.from_row_lengths,
'values': 10,
'row_lengths': [0, 1]
},
# from_row_starts
{
'descr': 'bad rank for row_starts',
'factory': RaggedTensor.from_row_starts,
'values': [[1, 2], [3, 4]],
'row_starts': [[1, 2], [3, 4]]
},
{
'descr': 'row_starts[0] != 0',
'factory': RaggedTensor.from_row_starts,
'values': [1, 2, 3, 4],
'row_starts': [2, 3, 4]
},
{
'descr': 'non-monotonic-increasing row_starts',
'factory': RaggedTensor.from_row_starts,
'values': [1, 2, 3, 4],
'row_starts': [0, 3, 2, 4]
},
{
'descr': 'row_starts[0] > nvals',
'factory': RaggedTensor.from_row_starts,
'values': [1, 2, 3, 4],
'row_starts': [0, 2, 3, 5]
},
{
'descr': 'bad rank for values',
'factory': RaggedTensor.from_row_starts,
'values': 10,
'row_starts': [0, 1]
},
# from_row_limits
{
'descr': 'bad rank for row_limits',
'factory': RaggedTensor.from_row_limits,
'values': [[1, 2], [3, 4]],
'row_limits': [[1, 2], [3, 4]]
},
{
'descr': 'row_limits[0] < 0',
'factory': RaggedTensor.from_row_limits,
'values': [1, 2, 3, 4],
'row_limits': [-1, 3, 4]
},
{
'descr': 'non-monotonic-increasing row_limits',
'factory': RaggedTensor.from_row_limits,
'values': [1, 2, 3, 4],
'row_limits': [0, 3, 2, 4]
},
{
'descr': 'row_limits[0] != nvals',
'factory': RaggedTensor.from_row_limits,
'values': [1, 2, 3, 4],
'row_limits': [0, 2, 3, 5]
},
{
'descr': 'bad rank for values',
'factory': RaggedTensor.from_row_limits,
'values': 10,
'row_limits': [0, 1]
},
# from_uniform_row_length
{
'descr': 'rowlen * nrows != nvals (1)',
'factory': RaggedTensor.from_uniform_row_length,
'values': [1, 2, 3, 4, 5],
'uniform_row_length': 3
},
{
'descr': 'rowlen * nrows != nvals (2)',
'factory': RaggedTensor.from_uniform_row_length,
'values': [1, 2, 3, 4, 5],
'uniform_row_length': 6
},
{
'descr': 'rowlen * nrows != nvals (3)',
'factory': RaggedTensor.from_uniform_row_length,
'values': [1, 2, 3, 4, 5, 6],
'uniform_row_length': 3,
'nrows': 3
},
{
'descr': 'rowlen must be a scalar',
'factory': RaggedTensor.from_uniform_row_length,
'values': [1, 2, 3, 4],
'uniform_row_length': [2]
},
{
'descr': 'rowlen must be nonnegative',
'factory': RaggedTensor.from_uniform_row_length,
'values': [1, 2, 3, 4],
'uniform_row_length': -1
},
])
def testFactoryValidation(self, descr, factory, **kwargs):
# When input tensors have shape information, some of these errors will be
# detected statically.
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
self.evaluate(factory(**kwargs))
# Remove shape information (by wrapping tensors in placeholders), and check
# that we detect the errors when the graph is run.
if not context.executing_eagerly():
def wrap_arg(v):
return array_ops.placeholder_with_default(
constant_op.constant(v, dtype=dtypes.int64),
tensor_shape.TensorShape(None))
kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items())
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(factory(**kwargs))
#=============================================================================
# RaggedTensor Variant conversion
#=============================================================================
@parameterized.named_parameters(
{
'testcase_name': 'Shape_5_none',
'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
'ragged_rank': 1
}, {
'testcase_name': 'Shape_4_none_2',
'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
'ragged_rank': 1
}, {
'testcase_name': 'Shape_1_none_none',
'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
'ragged_rank': 2
})
def testRaggedToVariant(self, ragged_constant, ragged_rank):
rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank)
et = rt._to_variant()
self.assertEqual(et.shape.as_list(), [])
self.assertEqual(et.dtype, dtypes.variant)
@parameterized.parameters(
{
'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
'ragged_rank': 1,
'num_batched_elems': 5
}, {
'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
'ragged_rank': 1,
'num_batched_elems': 4
}, {
'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
'ragged_rank': 2,
'num_batched_elems': 2
})
def testRaggedToBatchedVariant(self, ragged_constant, ragged_rank,
num_batched_elems):
rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank)
et = rt._to_variant(batched_input=True)
self.assertEqual(et.shape.as_list(), [num_batched_elems])
self.assertEqual(et.dtype, dtypes.variant)
@parameterized.parameters(
# 2D test cases.
{
'ragged_constant': [[]],
'ragged_rank': 1,
},
{
'ragged_constant': [[1]],
'ragged_rank': 1,
},
{
'ragged_constant': [[1, 2]],
'ragged_rank': 1,
},
{
'ragged_constant': [[1], [2], [3]],
'ragged_rank': 1,
},
{
'ragged_constant': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
'ragged_rank': 1,
},
{
'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
'ragged_rank': 1,
},
# 3D test cases.
{
'ragged_constant': [[[]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1, 2]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1, 2], [3, 4]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
'ragged_rank': 2,
},
{
'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
'ragged_rank': 2,
},
# 4D test cases.
{
'ragged_constant': [[[[1, 2], [3, 4]]],
[[[0, 0], [0, 0]], [[5, 6], [7, 8]]], []],
'ragged_rank': 3,
},
# dtype `string`.
{
'ragged_constant': [['a'], ['b'], ['c']],
'ragged_rank': 1,
'dtype': dtypes.string,
},
{
'ragged_constant': [[['a', 'b'], ['c', 'd']]],
'ragged_rank': 2,
'dtype': dtypes.string,
},
{
'ragged_constant': [[[['a', 'b'], ['c', 'd']]],
[[['e', 'f'], ['g', 'h']], [['i', 'j'],
['k', 'l']]], []],
'ragged_rank': 3,
'dtype': dtypes.string,
})
def testVariantRoundTrip(self,
ragged_constant,
ragged_rank,
dtype=dtypes.int32):
rt = ragged_factory_ops.constant(
ragged_constant, ragged_rank=ragged_rank, dtype=dtype)
et = rt._to_variant()
round_trip_rt = RaggedTensor._from_variant(
et, dtype, output_ragged_rank=ragged_rank)
self.assertAllEqual(rt, round_trip_rt)
def testBatchedVariantRoundTripInputRaggedRankInferred(self):
ragged_rank = 1
rt = ragged_factory_ops.constant(
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
ragged_rank=ragged_rank)
batched_variant = rt._to_variant(batched_input=True)
nested_batched_variant = array_ops.reshape(batched_variant, [5, 2])
decoded_rt = RaggedTensor._from_variant(
nested_batched_variant,
dtype=dtypes.int32,
output_ragged_rank=ragged_rank + 1)
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
[5]],
[[6], [7]], [[8], [9]]])
self.assertAllEqual(decoded_rt, expected_rt)
def testBatchedVariantRoundTripWithInputRaggedRank(self):
ragged_rank = 1
rt = ragged_factory_ops.constant(
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
ragged_rank=ragged_rank)
batched_variant = rt._to_variant(batched_input=True)
nested_batched_variant = array_ops.reshape(batched_variant, [5, 2])
decoded_rt = RaggedTensor._from_variant(
nested_batched_variant,
dtype=dtypes.int32,
output_ragged_rank=ragged_rank + 1,
input_ragged_rank=ragged_rank - 1)
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
[5]],
[[6], [7]], [[8], [9]]])
self.assertAllEqual(decoded_rt, expected_rt)
def testUnbatchVariant(self): # b/141789000
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
batched = rt._to_variant(batched_input=True)
for i in range(4):
row = RaggedTensor._from_variant(
batched[i], dtype=dtypes.int32, output_ragged_rank=0)
self.assertAllEqual(rt[i], row)
def testUnbatchVariantInDataset(self):
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
ds = dataset_ops.Dataset.from_tensor_slices(rt)
if context.executing_eagerly():
for i, value in enumerate(ds):
self.assertAllEqual(rt[i], value)
else:
it = dataset_ops.make_one_shot_iterator(ds)
out = it.get_next()
with self.cached_session() as sess:
for i in range(3):
self.assertAllEqual(sess.run(rt[i]), out)
def testToVariantInvalidParams(self):
self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r'be rank 1 but is rank 0',
gen_ragged_conversion_ops.ragged_tensor_to_variant,
rt_nested_splits=[0, 1, 2],
rt_dense_values=[0, 1, 2],
batched_input=True)
self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r'be rank 1 but is rank 2',
gen_ragged_conversion_ops.ragged_tensor_to_variant,
rt_nested_splits=[[[0]], [[1]], [[2]]],
rt_dense_values=[0, 1, 2],
batched_input=True)
def testFromVariantInvalidParams(self):
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
batched_variant = rt._to_variant(batched_input=True)
nested_batched_variant = array_ops.reshape(batched_variant, [2, 2])
with self.assertRaisesRegex(ValueError,
r'`output_ragged_rank` \(1\) must be equal to'):
RaggedTensor._from_variant(
nested_batched_variant,
dtype=dtypes.int32,
output_ragged_rank=1,
input_ragged_rank=1)
def testUnbatchToTensor(self):
batched = ragged_factory_ops.constant([[0], [1], [2], [3]])
unbatched = [constant_op.constant(x) for x in [[0], [1], [2], [3]]]
batched_spec = type_spec.type_spec_from_value(batched)
# Note that the unbatched_spec is derived from the batched spec, so it can
# add back a ragged instead of a dense tensor.
unbatched_spec = batched_spec._unbatch()
batched_tensor_list = batched_spec._to_batched_tensor_list(batched)
unbatched_tensor_lists = zip(
*[array_ops_stack.unstack(tensor) for tensor in batched_tensor_list])
actual_unbatched = [
batched_spec._unbatch()._from_tensor_list(tensor_list)
for tensor_list in unbatched_tensor_lists]
self.assertLen(actual_unbatched, len(unbatched))
for x in actual_unbatched:
self.assertTrue(unbatched_spec.is_compatible_with(x))
for (actual, expected) in zip(actual_unbatched, unbatched):
self.assertAllEqual(actual, expected)
def testDatasetUnbatchTwice(self):
batched = ragged_factory_ops.constant([[[0], [1], [5]], [[2], [3]]])
ds = dataset_ops.Dataset.from_tensors(batched)
ds2 = ds.unbatch()
ds3 = ds2.unbatch()
if context.executing_eagerly():
value = next(iter(ds3))
self.assertAllEqual([0], value)
def testDatasetUnbatchToScalar(self):
batched = ragged_factory_ops.constant([[0], [1], [2], [3]])
ds = dataset_ops.Dataset.from_tensors(batched)
ds2 = ds.unbatch()
ds3 = ds2.unbatch()
if context.executing_eagerly():
value = next(iter(ds3))
self.assertAllEqual(0, value)
def testBatchToTensor(self):
batched = ragged_factory_ops.constant([[0], [1], [2], [3]])
unbatched = [constant_op.constant(x) for x in [[0], [1], [2], [3]]]
batched_spec = type_spec.type_spec_from_value(batched)
# Note that the unbatched_spec is derived from the batched spec, so it can
# add back a ragged instead of a dense tensor.
unbatched_spec = batched_spec._unbatch()
unbatched_tensor_lists = [unbatched_spec._to_tensor_list(x)
for x in unbatched]
batched_tensor_list = [array_ops_stack.stack(tensors)
for tensors in zip(*unbatched_tensor_lists)]
actual_batched = unbatched_spec._batch(4)._from_tensor_list(
batched_tensor_list)
self.assertAllEqual(actual_batched, batched)
def _testGradient(self, func, x, expected_grad, grad_y=None):
x = ragged_factory_ops.constant(x)
if grad_y is not None:
grad_y = ragged_factory_ops.constant(grad_y)
if context.executing_eagerly():
with backprop.GradientTape() as t:
t.watch(x)
y = func(x)
g = t.gradient(y, x, grad_y)
else:
y = func(x)
g = gradients_impl.gradients(ys=y, xs=x, grad_ys=grad_y)[0]
if expected_grad is None:
self.assertIsNone(g)
else:
g = ragged_tensor.convert_to_tensor_or_ragged_tensor(g)
self.assertAllClose(g, expected_grad)
@parameterized.named_parameters([
dict(
testcase_name='RaggedInput',
func=lambda x: math_ops.reduce_prod(x, axis=1),
x=[[1., 2.], [3.]],
expected=[[2., 1.], [1.]]),
dict(
testcase_name='RaggedOutput',
func=lambda x: ragged_concat_ops.stack([x, x[:1]]),
x=[3., 2.],
expected=[2., 1.]),
dict(
testcase_name='RaggedInputAndOutput',
func=lambda x: array_ops_stack.stack([x, x * x]),
x=[[1., 2.], [3.]],
expected=[[3., 5.], [7.]]),
dict(
testcase_name='RaggedOutputWithGradYs',
func=lambda x: ragged_concat_ops.stack([x, x[:1]]),
x=[3., 2.],
grad_ys=[[1., 1.], [1.]],
expected=[2., 1.]),
dict(
testcase_name='RaggedInputAndOutputWithGradYs',
func=lambda x: array_ops_stack.stack([x, x * x]),
x=[[1., 2.], [3.]],
grad_ys=[[[1., 1.], [1.]], [[1., 1.], [1.]]],
expected=[[3., 5.], [7.]]),
dict(
testcase_name='RaggedRank3',
func=lambda x: ragged_concat_ops.stack([x, (x * x)[:, 1:]]),
x=[[[1., 2.], [3., 4., 5.]], [[6.]]],
expected=[[[1.0, 1.0], [7.0, 9.0, 11.0]], [[1.0]]]),
dict(
testcase_name='RaggedIndexedSlices',
func=lambda x: ragged_gather_ops.gather(x, [0, 2]),
x=[[1., 2.], [3.], [4., 5., 6.]],
expected=[[1., 1.], [0.], [1., 1., 1.]]),
])
def testGradient(self, func, x, expected, grad_ys=None):
self._testGradient(func, x, expected, grad_ys)
def testHigherOrderGradient(self):
x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
with backprop.GradientTape() as t2:
t2.watch(x)
with backprop.GradientTape() as t1:
t1.watch(x)
y = x * x * x
dy_dx = t1.gradient(y, x)
d2y_dx2 = t2.gradient(dy_dx, x)
self.assertAllEqual(dy_dx, [[3.0, 12.0], [27.0]])
self.assertAllEqual(d2y_dx2, [[6.0, 12.0], [18.0]])
def testUnconnectedGradient(self):
x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
with backprop.GradientTape() as t:
t.watch(x)
y = ragged_factory_ops.constant([[2.0, 4.0], [6.0]])
self.assertIsNone(t.gradient(y, x))
def testStopGradient(self):
def func(x):
y = x * constant_op.constant([[1.], [3.]])
y = y.with_values(array_ops.stop_gradient(y.values))
z = x * y
return math_ops.reduce_sum(z)
self._testGradient(func, [[1., 2.], [3., 4., 5.]],
[[1., 2.], [9., 12., 15.]])
def testStopGradientNoneComponent(self):
def func(x):
y = x * constant_op.constant([[1.], [3.]])
y = y.with_values(array_ops.stop_gradient(y.values))
return y
self._testGradient(func, [[1., 2], [3, 4, 5]], None)
def testRaggedVariantGradients(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v = rt2._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 100., 100., 100., 1000.])
def testRaggedVariantGradientsEmptyRows(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 2, 2, 4, 7, 7, 8])
rt2 = rt1 * [[10], [20], [30], [40], [50], [60]]
v = rt2._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 30., 30., 40., 40., 40., 60.])
def testRaggedVariantSteps(self):
x = [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0]
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v = rt2._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
self.assertAllClose([30., 10., 40., 10., 100., 0., 200., 1000.],
rt3.flat_values)
def testRaggedVariantGradientsBatched(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v = rt2._to_variant(batched_input=True)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 100., 100., 100., 1000.])
def testRaggedVariantGradientsEmptyRowsBatched(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 2, 2, 4, 7, 7, 8])
rt2 = rt1 * [[10], [20], [30], [40], [50], [60]]
v = rt2._to_variant(batched_input=True)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 30., 30., 40., 40., 40., 60.])
def testRaggedVariantGradientsEmptyOutputBatched(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 0, 0, 0, 0, 0, 0])
rt2 = rt1 * [[10], [20], [30], [40], [50], [60]]
v = rt2._to_variant(batched_input=True)
rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
return rt3.flat_values
self._testGradient(func, [], [])
def testRaggedVariantGradientsBatchedAndSliced(self):
def func(x, i):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
rt2 = rt1 * [[10], [100], [1000]]
v_slice = rt2._to_variant(batched_input=True)[i]
return RaggedTensor._from_variant(
v_slice, dtype=rt2.dtype, output_ragged_rank=0)
self._testGradient(
functools.partial(func, i=0), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 10., 10., 0., 0., 0., 0.])
self._testGradient(
functools.partial(func, i=1), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 100., 100., 100., 0.])
self._testGradient(
functools.partial(func, i=2), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 0., 0., 0., 1000.])
def testRaggedVariantGradientsEmptyRowsBatchedAndSliced(self):
def func(x, i):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 2, 2, 4, 7, 7, 8])
rt2 = rt1 * [[10], [20], [30], [40], [50], [60]]
v_slice = rt2._to_variant(batched_input=True)[i]
return RaggedTensor._from_variant(
v_slice, dtype=rt2.dtype, output_ragged_rank=0)
self._testGradient(
functools.partial(func, i=0), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[10., 10., 0., 0., 0., 0., 0., 0.])
self._testGradient(
functools.partial(func, i=1), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 0., 0., 0., 0.])
self._testGradient(
functools.partial(func, i=2), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 30., 30., 0., 0., 0., 0.])
self._testGradient(
functools.partial(func, i=3), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 40., 40., 40., 0.])
self._testGradient(
functools.partial(func, i=4), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 0., 0., 0., 0.])
self._testGradient(
functools.partial(func, i=5), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0., 0., 0., 0., 0., 0., 0., 60.])
def testRaggedVariantGradientsRaggedRank0(self):
def func(x):
x2 = x * 2
v = gen_ragged_conversion_ops.ragged_tensor_to_variant(
[], x2, batched_input=False)
return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0)
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
def testRaggedVariantGradientsRaggedRank3(self):
def func(x):
x2 = x * 2
rt1 = RaggedTensor.from_nested_row_splits(
x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8]))
v = rt1._to_variant(batched_input=False)
rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3)
return rt3.flat_values
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
def testRaggedVariantGradientsViaMapFn(self):
rt = RaggedTensor.from_row_splits(
values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8])
def func(x):
def transform_row(row):
return math_ops.sqrt(
math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
self._testGradient(func, 3.0, 14.653377)
def testRaggedVariantGradientsEmptyRowsViaMapFn(self):
rt = RaggedTensor.from_row_splits(
values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 2, 2, 4, 7, 7, 8])
def func(x):
def transform_row(row):
return math_ops.sqrt(
math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
self._testGradient(func, 3.0, 17.206844)
def testRaggedVariantGradientsEmptyOutputViaMapFn(self):
rt = RaggedTensor.from_row_splits(
values=[], row_splits=[0, 0, 0, 0])
def func(x):
def transform_row(row):
return math_ops.sqrt(
math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
self._testGradient(func, 3.0, 0.0)
def testRaggedVariantGradientsViaMapFnReduce(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
return map_fn.map_fn(
math_ops.reduce_max,
rt1,
fn_output_signature=tensor_lib.TensorSpec((), x.dtype))
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
def testRaggedVariantGradientsEmptyRowsViaMapFnReduce(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 2, 2, 4, 7, 7, 8])
return map_fn.map_fn(
math_ops.reduce_max,
rt1,
fn_output_signature=tensor_lib.TensorSpec((), x.dtype))
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
def testRaggedVariantGradientsEmptyOutputViaMapFnReduce(self):
def func(x):
rt1 = RaggedTensor.from_row_splits(
values=x, row_splits=[0, 0, 0, 0])
return map_fn.map_fn(
math_ops.reduce_max,
rt1,
fn_output_signature=tensor_lib.TensorSpec((), x.dtype))
self._testGradient(func, [], [])
def testRaggedVariantGradientsErrors(self):
if context.executing_eagerly():
return
rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2])
v1 = rt._to_variant()
v2 = array_ops_stack.stack([array_ops_stack.stack([v1])])
y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3)
with self.assertRaisesRegex(
ValueError, 'Unable to compute gradient: RaggedTensorToVariant '
'can currently only generate 0D or 1D output.'):
gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values)
def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
"""Check that two numpy arrays are equal.
For arrays with dtype=object, check values recursively to see if a and b
are equal. (c.f. `np.array_equal`, which checks dtype=object values using
object identity.)
Args:
a: A numpy array.
b: A numpy array.
msg: Message to display if a != b.
"""
if isinstance(a, np.ndarray) and a.dtype == object:
self.assertEqual(a.dtype, b.dtype, msg)
self.assertEqual(a.shape, b.shape, msg)
self.assertLen(a, len(b), msg)
for a_val, b_val in zip(a, b):
self.assertNumpyObjectTensorsRecursivelyEqual(a_val, b_val, msg)
else:
self.assertAllEqual(a, b, msg)
@parameterized.named_parameters([
('Shape_2_R',
[[1, 2], [3, 4, 5]],
np.array([int32array([1, 2]), int32array([3, 4, 5])], dtype=object)),
('Shape_2_2',
[[1, 2], [3, 4]],
np.array([[1, 2], [3, 4]])),
('Shape_2_R_2',
[[[1, 2], [3, 4]], [[5, 6]]],
np.array([int32array([[1, 2], [3, 4]]), int32array([[5, 6]])],
dtype=object)),
('Shape_3_2_R',
[[[1], []], [[2, 3], [4]], [[], [5, 6, 7]]],
np.array([[int32array([1]), int32array([])],
[int32array([2, 3]), int32array([4])],
[int32array([]), int32array([5, 6, 7])]], dtype=object)),
('Shape_0_R',
ragged_factory_ops.constant_value([], ragged_rank=1, dtype=np.int32),
np.zeros([0, 0], dtype=np.int32)),
('Shape_0_R_2',
ragged_factory_ops.constant_value([], ragged_rank=1,
inner_shape=(2,), dtype=np.int32),
np.zeros([0, 0, 2], dtype=np.int32)),
]) # pyformat: disable
def testRaggedTensorNumpy(self, rt, expected):
if isinstance(rt, list):
rt = ragged_factory_ops.constant(rt, dtype=dtypes.int32)
else:
rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
if context.executing_eagerly():
actual = rt.numpy()
self.assertNumpyObjectTensorsRecursivelyEqual(
expected, actual, 'Expected %r, got %r' % (expected, actual))
else:
with self.assertRaisesRegex(ValueError, 'only supported in eager mode'):
rt.numpy()
@parameterized.parameters([
([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
([[[1, 2, 3]]], 1, [1, 1, None]),
([[[1, 2, 3]]], 1, [1, 1, 3]),
])
def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape):
rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
rt1._set_shape(shape)
rt1.shape.assert_is_compatible_with(shape)
if shape is not None:
self.assertIsNot(rt1.shape.rank, None)
for a, b in zip(rt1.shape, shape):
if b is not None:
self.assertEqual(a, b)
@parameterized.parameters([
([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
([[[1, 2, 3]]], 1, [1, 1, None]),
([[[1, 2, 3]]], 1, [1, 1, 3]),
])
def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape):
rt2 = nest.map_structure(
lambda x: array_ops.placeholder_with_default(x, None),
ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank),
expand_composites=True)
rt2._set_shape(shape)
rt2.shape.assert_is_compatible_with(shape)
if shape is not None:
self.assertIsNot(rt2.shape.rank, None)
for a, b in zip(rt2.shape, shape):
if b is not None:
self.assertEqual(a, b)
def testRaggedTensorSetShapeUniformRowLength(self):
rt = [[[1], [2], [3]], [[4], [5], [6]]]
rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1)
rt1._set_shape([2, 3, 1])
rt2 = nest.map_structure(
lambda x: array_ops.placeholder_with_default(x, None),
rt1,
expand_composites=True)
rt2._set_shape([2, 3, 1])
def testRaggedTensorSetShapeInconsistentShapeError(self):
rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]],
ragged_rank=1)
self.assertEqual(rt.shape.as_list(), [2, 3, 1])
with self.assertRaises(ValueError):
rt._set_shape([None, None, 5])
with self.assertRaisesRegex(ValueError, 'Inconsistent size'):
rt._set_shape([None, 5, None])
with self.assertRaises(ValueError):
rt._set_shape([5, None, None])
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def assertAllTensorsEqual(self, list1, list2):
self.assertLen(list1, len(list2))
for (t1, t2) in zip(list1, list2):
self.assertAllEqual(t1, t2)
def testConstruction(self):
spec1 = RaggedTensorSpec(ragged_rank=1)
self.assertIsNone(spec1._shape.rank)
self.assertEqual(spec1._dtype, dtypes.float32)
self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
self.assertEqual(spec1._ragged_rank, 1)
self.assertIsNone(spec1.shape.rank)
self.assertEqual(spec1.dtype, dtypes.float32)
self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
self.assertEqual(spec1.ragged_rank, 1)
spec2 = RaggedTensorSpec(shape=[None, None, None])
self.assertEqual(spec2._shape.as_list(), [None, None, None])
self.assertEqual(spec2._dtype, dtypes.float32)
self.assertEqual(spec2._row_splits_dtype, dtypes.int64)
self.assertEqual(spec2._ragged_rank, 2)
with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'):
RaggedTensorSpec()
with self.assertRaisesRegex(TypeError, '`ragged_rank` must be an int'):
RaggedTensorSpec(ragged_rank=constant_op.constant(1))
with self.assertRaisesRegex(
ValueError,
r'Argument `ragged_rank` \(2\) must be less than rank \(2\).'):
RaggedTensorSpec(ragged_rank=2, shape=[None, None])
def testValueType(self):
spec1 = RaggedTensorSpec(ragged_rank=1)
self.assertEqual(spec1.value_type, RaggedTensor)
spec2 = RaggedTensorSpec(ragged_rank=0)
self.assertEqual(spec2.value_type, tensor_lib.Tensor)
@parameterized.parameters([
(RaggedTensorSpec(ragged_rank=1),
(tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64)),
(RaggedTensorSpec(shape=[5, None, None]),
(tensor_shape.TensorShape([5, None, None]), dtypes.float32,
2, dtypes.int64)),
(RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.int32),
(tensor_shape.TensorShape([5, None, None]), dtypes.int32, 2,
dtypes.int64)),
(RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
(tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int32)),
]) # pyformat: disable
def testSerialize(self, rt_spec, expected):
serialization = rt_spec._serialize()
# TensorShape has an unconventional definition of equality, so we can't use
# assertEqual directly here. But repr() is deterministic and lossless for
# the expected values, so we can use that instead.
self.assertEqual(repr(serialization), repr(expected))
@parameterized.parameters([
(RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), [
tensor_lib.TensorSpec([5, 3], dtypes.float32),
]),
(RaggedTensorSpec(ragged_rank=1), [
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec([None], dtypes.int64)
]),
(RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), [
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec([None], dtypes.int32),
]),
(RaggedTensorSpec(ragged_rank=2), [
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec([None], dtypes.int64),
tensor_lib.TensorSpec([None], dtypes.int64),
]),
(RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.string), [
tensor_lib.TensorSpec([None], dtypes.string),
tensor_lib.TensorSpec([6], dtypes.int64),
tensor_lib.TensorSpec([None], dtypes.int64),
]),
])
def testComponentSpecs(self, rt_spec, expected):
self.assertEqual(rt_spec._component_specs, expected)
@parameterized.parameters([
{
'rt_spec': RaggedTensorSpec(ragged_rank=0),
'rt': [1.0, 2.0, 3.0],
'components': [[1.0, 2.0, 3.0]]
},
{
'rt_spec': RaggedTensorSpec(ragged_rank=1),
'rt': [[1.0, 2.0], [3.0]],
'components': [[1.0, 2.0, 3.0], [0, 2, 3]]
},
{
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]],
'components': [[1.0, 2.0, 3.0, 4.0], [0, 2, 4], [0, 2, 3, 3, 4]]
},
])
def testToFromComponents(self, rt_spec, rt, components):
rt = ragged_factory_ops.constant(rt)
actual_components = rt_spec._to_components(rt)
self.assertAllTensorsEqual(actual_components, components)
rt_reconstructed = rt_spec._from_components(actual_components)
self.assertAllEqual(rt, rt_reconstructed)
@parameterized.parameters([
{
'flat_value_spec': tensor_lib.TensorSpec(None, dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec(None, dtypes.int64),
},
{
'flat_value_spec': tensor_lib.TensorSpec([None,], dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec(None, dtypes.int64),
},
{
'flat_value_spec': tensor_lib.TensorSpec(None, dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec([None,], dtypes.int64),
},
{
'flat_value_spec': tensor_lib.TensorSpec([None,], dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec([None,], dtypes.int64),
},
{
'flat_value_spec': tensor_lib.TensorSpec([4,], dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec(None, dtypes.int64),
},
{
'flat_value_spec': tensor_lib.TensorSpec(None, dtypes.float32),
'row_splits_spec': tensor_lib.TensorSpec([3,], dtypes.int64),
},
])
def testToFromComponentsStaticUnknownShape(self, flat_value_spec,
row_splits_spec):
rt_spec = RaggedTensorSpec(shape=[2, None], ragged_rank=1)
tester = self
@def_function.function(input_signature=[flat_value_spec, row_splits_spec])
def test_fn(flat_value, row_splits):
# Apply static shape information saved in rt_spec to rt.
rt = rt_spec._from_components([flat_value, row_splits])
tester.assertEqual(rt.shape.as_list(), [2, None])
return rt + ragged_factory_ops.constant([[1.0, 1.0, 1.0], [1.0]])
result = test_fn([1.0, 2.0, 3.0, 4.0], [0, 3, 4])
expected_result = ragged_factory_ops.constant([[2.0, 3.0, 4.0], [5.0]])
self.assertAllEqual(result, expected_result)
@test_util.run_v1_only('RaggedTensorValue is deprecated in v2')
def testFromNumpyComponents(self):
spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32)
rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])])
self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue)
self.assertAllEqual(rt1, [[1, 2], [3]])
spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
rt2 = spec2._from_components(
[np.array([1, 2, 3]),
np.array([0, 2, 3]),
np.array([0, 0, 2, 3])])
self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])
spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32)
rt3 = spec3._from_components([np.array([1, 2, 3])])
self.assertIsInstance(rt3, np.ndarray)
self.assertAllEqual(rt3, [1, 2, 3])
@parameterized.parameters([
RaggedTensorSpec(ragged_rank=0, shape=[5, 3]),
RaggedTensorSpec(ragged_rank=1),
RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
RaggedTensorSpec(ragged_rank=2, dtype=dtypes.string),
RaggedTensorSpec(shape=[5, None, None]),
])
def testFlatTensorSpecs(self, rt_spec):
self.assertEqual(rt_spec._flat_tensor_specs,
[tensor_lib.TensorSpec(None, dtypes.variant)])
@parameterized.parameters([
(dtypes.float32, full_type_pb2.TFT_FLOAT),
(dtypes.string, full_type_pb2.TFT_STRING),
])
def testFullTypesForFlatTensors(self, dt, ft):
rt_spec = RaggedTensorSpec(ragged_rank=2, dtype=dt)
full_type_list = fulltypes_for_flat_tensors(rt_spec)
expect = [
full_type_pb2.FullTypeDef(
type_id=full_type_pb2.TFT_RAGGED,
args=[full_type_pb2.FullTypeDef(type_id=ft)])
]
self.assertEqual(len(rt_spec._flat_tensor_specs), len(full_type_list))
self.assertEqual(expect, full_type_list)
@parameterized.named_parameters([
{
'testcase_name': 'RaggedRank0',
'rt_spec': RaggedTensorSpec(ragged_rank=0),
'rt': [1.0, 2.0, 3.0],
},
{
'testcase_name': 'RaggedRank1',
'rt_spec': RaggedTensorSpec(ragged_rank=1),
'rt': [[1.0, 2.0], [3.0]]
},
{
'testcase_name': 'RaggedRank2',
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
},
])
def testToFromTensorList(self, rt_spec, rt):
rt = ragged_factory_ops.constant(rt)
tensor_list = rt_spec._to_tensor_list(rt)
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
self.assertAllEqual(rt, rt_reconstructed)
@parameterized.named_parameters([
# TODO(b/141789000) Test ragged_rank=0 when support is added.
{
'testcase_name': 'RaggedRank1',
'rt_spec': RaggedTensorSpec(ragged_rank=1),
'rt': [[1.0, 2.0], [3.0]]
},
{
'testcase_name': 'RaggedRank2',
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
},
])
def testToFromBatchedTensorList(self, rt_spec, rt):
rt = ragged_factory_ops.constant(rt)
tensor_list = rt_spec._to_batched_tensor_list(rt)
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
self.assertAllEqual(rt, rt_reconstructed)
first_row = rt_spec._unbatch()._from_tensor_list(
[t[0] for t in tensor_list])
self.assertAllEqual(rt[0], first_row)
def testToFromBatchedTensorListPreservesUniformRowLengths(self):
rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]), ragged_rank=2)
rt_spec = rt._type_spec
tensor_list = rt_spec._to_batched_tensor_list(rt)
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
self.assertAllEqual(rt, rt_reconstructed)
self.assertTrue(rt.shape.is_fully_defined())
self.assertTrue(rt_reconstructed.shape.is_fully_defined())
self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list())
@parameterized.parameters([
(RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
(RaggedTensorSpec([4, None], dtypes.float32, 1), None,
RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
(RaggedTensorSpec([2], dtypes.float32,
-1), 32, RaggedTensorSpec([32, 2], dtypes.float32, 0)),
])
def testBatch(self, spec, batch_size, expected):
self.assertEqual(spec._batch(batch_size), expected)
@parameterized.parameters([
(RaggedTensorSpec([32, None, None], dtypes.float32, 2),
RaggedTensorSpec([None, None], dtypes.float32, 1)),
(RaggedTensorSpec([None, None, None], dtypes.float32, 2),
RaggedTensorSpec([None, None], dtypes.float32, 1)),
(RaggedTensorSpec([32, 2], dtypes.float32, 0),
RaggedTensorSpec([2], dtypes.float32, -1)),
(RaggedTensorSpec([32, None, 4], dtypes.float32, 1, dtypes.int32),
RaggedTensorSpec([None, 4], dtypes.float32, 0, dtypes.int32)),
]) # pyformat: disable
def testUnbatch(self, spec, expected):
self.assertEqual(spec._unbatch(), expected)
def testIsCompatibleWith(self):
spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2)
spec2 = RaggedTensorSpec(None, dtypes.float32, 2)
spec3 = RaggedTensorSpec(None, dtypes.int32, 1)
spec4 = RaggedTensorSpec([None], dtypes.int32, 0)
self.assertTrue(spec1.is_compatible_with(spec2))
self.assertFalse(spec1.is_compatible_with(spec3))
self.assertFalse(spec1.is_compatible_with(spec4))
self.assertFalse(spec2.is_compatible_with(spec3))
self.assertFalse(spec2.is_compatible_with(spec4))
self.assertFalse(spec3.is_compatible_with(spec4))
self.assertTrue(spec4.is_compatible_with(constant_op.constant([1, 2, 3])))
if __name__ == '__main__':
googletest.main()