tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import dataclasses
import functools
import itertools
import multiprocessing.pool
import pickle
import platform
import re
import sys
import time
import weakref
from absl.testing import parameterized
import numpy
from tensorflow.core.function import trace_type
from tensorflow.core.function.capture import capture_container
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.checkpoint.checkpoint import Checkpoint
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
from tensorflow.python.eager.polymorphic_function import polymorphic_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gen_training_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
def total_function_cache(defined):
return defined._list_all_concrete_functions() # pylint: disable=protected-access
def _example_indexed_slices_with_dense_shape():
return indexed_slices.IndexedSlices(
constant_op.constant([1, 2]), constant_op.constant([0, 1]),
constant_op.constant([2]))
def _example_indexed_slices_without_dense_shape():
return indexed_slices.IndexedSlices(
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
def _spec_for_value(value):
"""Returns the (nested) TypeSpec for a value."""
if nest.is_nested(value):
return nest.map_structure(_spec_for_value, value)
elif isinstance(value, (tensor_lib.Tensor, composite_tensor.CompositeTensor)):
return type_spec.type_spec_from_value(value)
else:
return value
# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
def dummy_tf_decorator(method):
def wrapper(*args, **kwargs):
return method(*args, **kwargs)
return tf_decorator.make_decorator(method, wrapper)
def undecorated_function(x):
return x * 3.
class _HasDecoratedMethod(object):
@polymorphic_function.function
def f(self, x):
return x * 3.
class FunctionBenchmark(test.Benchmark):
"""Benchmark the tf.function implementation."""
def benchmark_repeat_captures_property_access(self):
n_iters = 1000000
n_captures = 100
vs = []
for _ in range(n_captures):
vs.append(variables.Variable(1.0))
def f():
result = 0
for idx in range(n_captures):
result += vs[idx]
return result
pf = polymorphic_function.function(f)
g = pf.get_concrete_function().graph
start_time = time.time()
for _ in range(n_iters):
temp = g.captures # pylint: disable=unused-variable
duration = time.time() - start_time
self.report_benchmark(iters=n_iters, wall_time=duration / float(n_iters))
@dataclasses.dataclass
class MaskedTensor:
mask: bool
value: tensor_lib.Tensor
def __tf_flatten__(self):
metadata = (self.mask,)
components = (self.value,)
return metadata, components
@classmethod
def __tf_unflatten__(cls, metadata, leaves):
mask = metadata[0]
value = leaves[0]
return MaskedTensor(mask=mask, value=value)
@dataclasses.dataclass
class MaskedTensorPair:
masks: list[bool]
value1: MaskedTensor
value2: MaskedTensor
def __tf_flatten__(self):
metadata = (self.masks,)
components = (self.value1, self.value2)
return metadata, components
@classmethod
def __tf_unflatten__(cls, metadata, leaves):
masks = metadata[0]
value1, value2 = leaves
return MaskedTensorPair(masks=masks, value1=value1, value2=value2)
# TODO(mdan): Organize these tests.
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
def testBasic(self):
matmul = polymorphic_function.function(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t, transpose_a=True)
sq2 = matmul(sq, t, transpose_a=True)
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
def testPythonFunctionNotCallable(self):
with self.assertRaisesRegex(TypeError, 'is not a callable object'):
polymorphic_function.function(1)
def testOnExitCallback(self):
values = []
def append_1():
values.append(1)
def append_2():
values.append(2)
def g(x):
old_values = list(values)
ops.add_exit_callback_to_default_func_graph(append_1)
self.assertEqual(old_values, values)
return x + 1
tf_g = polymorphic_function.function(g)
def f(x):
old_values = list(values)
ops.add_exit_callback_to_default_func_graph(append_2)
self.assertEqual(old_values, values)
return tf_g(x)
tf_f = polymorphic_function.function(f)
self.assertEmpty(values)
tf_f(constant_op.constant(1.0))
self.assertEqual(values, [1, 2]) # Once for g, once for f.
tf_f(constant_op.constant([1.0])) # force a retrace
self.assertEqual(values, [1, 2, 1, 2]) # And again.
def testCannotAddExitCallbackWhenNotInFunctionScope(self):
with self.assertRaisesRegex(RuntimeError, 'when not building a function.'):
ops.add_exit_callback_to_default_func_graph(lambda: None)
def testVariable(self):
v1 = variables.Variable(1.0)
add = polymorphic_function.function(lambda x, v: x + v1 + v)
v2 = variables.Variable(1.0)
x = constant_op.constant(1.0)
r = add(x, v2)
self.assertEqual(3.0, self.evaluate(r))
def testVariableOnly(self):
v = variables.Variable(1.0)
add = polymorphic_function.function(lambda x: x.assign_add(1.0))
r1 = add(v)
self.assertEqual(2.0, self.evaluate(r1))
c = constant_op.constant(1.0)
with self.assertRaisesRegex(AttributeError, 'no attribute'):
add(c)
def testVariableMultiFunction(self):
@polymorphic_function.function
def second(dup_var, dup_var_2, some_const):
return dup_var + dup_var_2 + some_const
@polymorphic_function.function
def first(dup_var, some_const):
return second(dup_var, dup_var, some_const)
my_const = constant_op.constant(1)
my_var = variables.Variable(2, dtype=dtypes.int32)
self.assertEqual(second(my_var, my_var, my_const).numpy(), 5)
self.assertEqual(first(my_var, my_const).numpy(), 5)
@test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
def testPackedVariable(self):
with ops.device('/cpu:0'):
v0_0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
v0_1 = resource_variable_ops.ResourceVariable(2.0)
v1_0 = resource_variable_ops.ResourceVariable(3.0)
with ops.device('/cpu:2'):
v1_1 = resource_variable_ops.ResourceVariable(4.0)
packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])
# TODO(b/145922293): use ResourceVariable.assign_add and
# ResourceVariable.read_value directly once we support packing multiple
# ResourceVariable into one ResourceVariable.
@polymorphic_function.function
def read_var():
resource_variable_ops.assign_add_variable_op(packed_var_0,
constant_op.constant(5.0))
resource_variable_ops.assign_add_variable_op(packed_var_1,
constant_op.constant(6.0))
with ops.device('/cpu:0'):
read0 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
with ops.device('/cpu:1'):
read1 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
read2 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
with ops.device('/cpu:2'):
read3 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
return read0, read1, read2, read3
arg_attrs = read_var.get_concrete_function().function_def.arg_attr
self.assertLen(arg_attrs, 2)
self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
compat.as_bytes(packed_var_0.device))
self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
compat.as_bytes(packed_var_1.device))
self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
def testImplementsAttributeBasic(self):
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name = f.signature.name
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(attributes_lib.IMPLEMENTS,
f.attr, f)
else:
present += 1
self.assertEqual(
f.attr[attributes_lib.IMPLEMENTS].s,
'func'.encode('ascii'), f)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testImplementsAttributeAssertsOnSideInput(self):
with context.graph_mode(), self.cached_session():
z = array_ops.zeros(0)
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y + z)
a = array_ops.ones((1,))
b = array_ops.ones((1,))
with self.assertRaisesRegex(AssertionError,
'variables are always captured'):
v(a, b)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertEmpty(functions)
def testImplementsAttributeWorksWithGradientTape(self):
add = lambda x, y: x + y**2
add = polymorphic_function.function(experimental_implements='MyFunc')(add)
x = variables.Variable(3.0)
y = variables.Variable(2.0)
with backprop.GradientTape() as tape:
g = add(x, y)
dg_dy, dg_dx = tape.gradient(g, [y, x])
self.assertEqual(dg_dy.numpy(), 4.0)
self.assertEqual(dg_dx.numpy(), 1.0)
def testImplementsAttributeWorksOnVariables(self):
with context.graph_mode(), self.cached_session():
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable((1.0,))
b = variables.Variable((1.0,))
r1 = v(a, b)
_ = v(a, a)
functions = ops.get_default_graph().as_graph_def().library.function
# Verify that we created only one function
self.assertLen(functions, 1)
# Verify that self.evaluate() reads the current values.
a.initializer.run()
b.initializer.run()
self.assertEqual(self.evaluate(r1), 2)
self.evaluate(a.assign_add([1]))
self.assertEqual(self.evaluate(r1), 3)
def testImplementsAttributeWorksOnConstants(self):
with context.graph_mode(), self.cached_session():
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, 2.)
r2 = v(2., a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 1)
self.assertLen(functions[0].signature.input_arg, 2)
# Verify that self.evaluate() reads the current values.
a.initializer.run()
self.assertEqual(self.evaluate(r1), 3)
self.assertEqual(self.evaluate(r2), 3)
def testImplementsAttributeSpecializes(self):
with context.graph_mode(), self.cached_session():
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, [2.])
r2 = v([2., 2], a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 2)
# Ensure that all parameters are still there and haven't been inlined!
self.assertLen(functions[0].signature.input_arg, 2)
self.assertLen(functions[1].signature.input_arg, 2)
# Verify that self.evaluate() reads the current values.
a.initializer.run()
numpy.testing.assert_equal(self.evaluate(r1), [3.])
numpy.testing.assert_equal(self.evaluate(r2), [3., 3.])
def testImplementsWorksWithTensorSpec(self):
v = polymorphic_function.function(
experimental_implements='func')(lambda x, y: x + y)
v = v.get_concrete_function(
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32))
x = v(1., 2.)
self.assertEqual(x.numpy(), 3.)
def testImplementsAttributeAsNameAttrList(self):
implements_attr = (
'name: "embedding_matmul" attr { key: "key1" value { i: 2 } '
'} attr { key: "key2" value { b: false } }')
v = polymorphic_function.function(
experimental_implements=implements_attr)(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name = f.signature.name
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(attributes_lib.IMPLEMENTS,
f.attr, f)
else:
present += 1
attr_value = f.attr[attributes_lib.IMPLEMENTS]
self.assertIsNotNone(attr_value.func, f)
self.assertEqual(attr_value.func.name, 'embedding_matmul')
name_attrs = attr_value.func.attr
self.assertLen(name_attrs, 2)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testDisableACDAttribute(self):
v = resource_variable_ops.ResourceVariable(1.0)
def foo(x, y):
nonlocal v
t = v.read_value()
v.assign_add(x + y)
return t
with_acd = polymorphic_function.function(foo)
without_acd = polymorphic_function.function(
foo, experimental_attributes={'_disable_acd': True}
)
with_acd_control_outputs = with_acd.get_concrete_function(
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
).graph.control_outputs
without_acd_control_outputs = without_acd.get_concrete_function(
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
).graph.control_outputs
self.assertLen(with_acd_control_outputs, 2)
self.assertEmpty(without_acd_control_outputs)
def testReduceTracingWithNestedTFFunction(self):
v = resource_variable_ops.ResourceVariable([1., 2.])
@polymorphic_function.function(reduce_retracing=True)
def inner_test_fn(x):
x.assign_add([2., 2.])
return x
@polymorphic_function.function(reduce_retracing=True)
def test_fn(x):
x.assign_add([1., 1.])
return inner_test_fn(x)
with backprop.GradientTape() as tape:
y = test_fn(v)
grad = tape.gradient(y, v)
self.assertAllEqual(y, [4., 5.])
self.assertAllEqual(grad, [1., 1.])
with backprop.GradientTape() as tape:
y = test_fn(v)
grad = tape.gradient(y, v)
self.assertAllEqual(y, [7., 8.])
self.assertAllEqual(grad, [1., 1.])
def testInputShapeRelaxationOnInstanceMethod(self):
# Test that reduce_retracing is passed during
# instance method bounding.
unknown_dim = [False]
class Foo:
@polymorphic_function.function(reduce_retracing=True)
def func(self, a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
foo = Foo()
foo.func(constant_op.constant([]))
self.assertFalse(unknown_dim[0])
foo.func(constant_op.constant([1.0]))
self.assertTrue(unknown_dim[0])
foo.func(constant_op.constant([1.0, 2.0]))
self.assertTrue(unknown_dim[0])
def testInputShapeFunctionRelaxationWithRaggedTensors(self):
traced_type_spec = [None]
@polymorphic_function.function(reduce_retracing=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
check_trace( # Initial call gets traced.
ragged_factory_ops.constant([[1], [2, 3, 4]]),
ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
check_trace( # Input TypeSpec is the same -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
check_trace( # Even if component tensor shapes change -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
check_trace( # Different TypeSpec shape (nrows): relax & retrace
ragged_factory_ops.constant([[1], [2], [3]]),
ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
check_trace( # Different nrows again: relax & retrace
ragged_factory_ops.constant([[1], [2], [3], [4]]), None)
check_trace( # Different nrows yet again: not retrace
ragged_factory_ops.constant([[1]]), None)
check_trace( # Different ragged_rank: retrace
ragged_factory_ops.constant([[[1]]]),
ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
check_trace( # Different ragged_rank again: retrace & relax
ragged_factory_ops.constant([[[1]], [[2]]]),
ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))
def testInputShapeFunctionRelaxationWithStructuredTensors(self):
traced_type_spec = [None]
@polymorphic_function.function(reduce_retracing=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
# If we have TypeSpecs that differ in ways other than just their shape,
# then retrace each time.
check_trace(
structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
fields={'a': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0))
check_trace(
structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
fields={'b': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0))
check_trace(
structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
fields={'c': tensor_lib.TensorSpec((1,), dtypes.int32)}, rank=0))
# But if we call again with only shape different, then do relax:
check_trace( # relax & retrace
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
fields={'a': tensor_lib.TensorSpec((None,), dtypes.int32)},
rank=0))
check_trace( # use relaxed graph
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), None)
check_trace( # use relaxed graph
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
None)
def testInputShapeFunctionRelaxationWithDatasetIterators(self):
# For dataset iterators, the TypeSpec includes type information that's
# not derivable from the component tensors. Make sure that the TypeSpec
# shapes get relaxed as appropriate.
traced_type_spec = [None]
@polymorphic_function.function(reduce_retracing=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
check_trace( # shape=[1, 2]: retrace
dataset_ops.make_one_shot_iterator(ds_1_2),
iterator_ops.IteratorSpec(
tensor_lib.TensorSpec([1, 2], dtypes.float32)))
check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph)
dataset_ops.make_one_shot_iterator(ds_1_2), None)
check_trace( # shape=[2, 2]: relax to [None, 2] and retrace
dataset_ops.make_one_shot_iterator(ds_2_2),
iterator_ops.IteratorSpec(
tensor_lib.TensorSpec([None, 2], dtypes.float32)))
check_trace( # shape=[3, 2]: no retrace (use the [None, 2] graph)
dataset_ops.make_one_shot_iterator(ds_3_2), None)
check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph)
dataset_ops.make_one_shot_iterator(ds_4_2), None)
check_trace( # shape=[2, 1]: relax to [None, None] and retrace
dataset_ops.make_one_shot_iterator(ds_2_1),
iterator_ops.IteratorSpec(
tensor_lib.TensorSpec([None, None], dtypes.float32)))
def testCapturesVariables(self):
a = variables.Variable(1.0, trainable=False)
b = variables.Variable(1.0)
cc = [None]
@polymorphic_function.function
def f():
c = cc[0]
if c is None:
c = cc[0] = variables.Variable(1.)
return a + b + c + 1
cf = f.get_concrete_function()
c = cc[0]
captured_variables = {v.ref() for v in (a, b, c)}
trainable_variables = {v.ref() for v in (b, c)}
self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
self.assertEqual({v.ref() for v in cf.trainable_variables},
trainable_variables)
self.assertEqual(cf.variables, cf.graph.variables)
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
def testNestedShapeFunctionRelaxation(self):
traced_shape = None
# The inner function will go through shape relaxation because the shapes it
# receives will be [1], [2], [3], ...
@polymorphic_function.function(reduce_retracing=True)
def bar(x_shape):
nonlocal traced_shape
traced_shape = x_shape._shape_tuple()
return x_shape
# The outer function will not go through shape relaxation because the shapes
# it receives will be [1], [[1]], [[[1]]], ...
@polymorphic_function.function(reduce_retracing=True)
def foo(ones):
return bar(array_ops.shape(ones))
self.assertAllEqual(self.evaluate(foo(array_ops.ones([1]))), [1])
self.assertEqual(traced_shape, (1,))
for rank in range(2, 6):
x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
self.assertAllEqual(x_shape, [1] * rank)
self.assertEqual(traced_shape, (None,))
def testNoHash(self):
@polymorphic_function.function()
def f(_):
return 1.0
with self.assertRaisesRegex(
TypeError, r'Could not generate a generic TraceType'):
f(set([]))
def testBasicGraphMode(self):
matmul = polymorphic_function.function(math_ops.matmul)
@polymorphic_function.function
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out = sq(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedInputsGraphMode(self):
matmul = polymorphic_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@polymorphic_function.function
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out = a_times_b(pair({'a': t}, {'b': t}))
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputsGraphMode(self):
matmul = polymorphic_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@polymorphic_function.function()
def pairs_mul(pair_a, pair_b):
return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))
a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])
out = pairs_mul(pair(a, b), pair(b, a))
expected = pair(
math_ops.matmul(a, b).numpy(),
math_ops.matmul(b, a).numpy())
self.assertAllClose(out, expected)
def testNestedFunctionGraphNotOutOfDate(self):
@polymorphic_function.function
def f():
return constant_op.constant(1.)
class _Model(object):
@polymorphic_function.function
def g(self):
self.f = f.get_concrete_function()
model = _Model()
model.g()
concrete = model.f
weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
weak_g = weakref.ref(model.g)
del model
self.assertIsNone(weak_g())
self.assertIsNone(weak_g_graph())
self.assertIsNotNone(concrete.graph.outer_graph)
self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)
def testBasicGraphFunction(self):
matmul = polymorphic_function.function(math_ops.matmul)
@polymorphic_function.function
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testGetConcreteFunctionThreadSafety(self):
@polymorphic_function.function
def sq():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
return math_ops.matmul(t, t)
concrete_functions = []
def thread_func(_):
cf = sq.get_concrete_function()
concrete_functions.append(cf)
num_threads = 100
pool = multiprocessing.pool.ThreadPool(num_threads)
_ = pool.map(thread_func, list(range(num_threads)))
self.assertLen(set(concrete_functions), 1)
def testGetConcreteFunctionThreadSafetyWithArgs(self):
@polymorphic_function.function
def add_100(*args):
return math_ops.add_n(args)
p = multiprocessing.pool.ThreadPool(2)
args = (constant_op.constant(1.),) * 100
f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
# I see about len(args) + max(0, len(args) - 3) arguments expected.
f1(*args)
del f2
def testInputSpecGraphFunction(self):
matmul = polymorphic_function.function(math_ops.matmul)
@polymorphic_function.function
def sq(a):
return matmul(a, a)
sq_op = sq.get_concrete_function(
tensor_lib.TensorSpec((None, None), dtypes.float32))
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out1 = sq_op(t1)
self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out2 = sq_op(t2)
self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
def testNestedInputSpecGraphFunction(self):
matmul = polymorphic_function.function(math_ops.matmul)
@polymorphic_function.function
def sq(mats):
((a, b),) = mats
return matmul(a, b)
sq_op_autonamed = sq.get_concrete_function([(
tensor_lib.TensorSpec((None, None), dtypes.float32),
tensor_lib.TensorSpec((None, None), dtypes.float32),
)])
self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
sq_op = sq.get_concrete_function([(
tensor_lib.TensorSpec((None, None), dtypes.float32, name='first_mat'),
tensor_lib.TensorSpec((None, None), dtypes.float32, name='second_mat'),
)])
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
out = sq_op(first_mat=t1, second_mat=t2)
self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
self.assertAllEqual(
sq_op_autonamed(t1, t2),
math_ops.matmul(t1, t2).numpy())
def testExecutingStatelessDefunConcurrently(self):
@polymorphic_function.function
def stateless(x):
return math_ops.multiply(2.0, x)
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
outputs = [float(out) for out in pool.map(stateless, inputs)]
expected = [float(2.0 * x) for x in inputs]
self.assertSequenceEqual(outputs, expected)
def testExecutingManyStatelessDefunsConcurrently(self):
@polymorphic_function.function
def stateless(x):
del x
return math_ops.multiply(2.0, 2.0)
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
objects = [object() for _ in range(100)]
outputs = [float(out) for out in pool.map(stateless, objects)]
expected = [4.0] * 100
self.assertSequenceEqual(outputs, expected)
@test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
def testExecutingStatefulDefunConcurrently(self):
v = resource_variable_ops.ResourceVariable(1.0)
@polymorphic_function.function
def stateful(x):
v.assign(x)
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(0.0)] * 100
pool.map(stateful, inputs)
self.assertEqual(float(v.read_value()), 0.0)
def testExecutingManyStatefulDefunsConcurrently(self):
v = resource_variable_ops.ResourceVariable(1.0)
@polymorphic_function.function
def stateful(x):
del x
return v.assign(0.0)
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
pool.map(stateful, [object() for _ in range(100)])
self.assertEqual(float(v.read_value()), 0.0)
def testShareRendezvous(self):
# Disable grappler from inlining the functions. Note we run the send & recv
# in graph mode since with eager mode the function should automatically be
# inlined.
context.context().set_optimizer_experimental_options(
{'disable_meta_optimizer': True})
cpu = '/device:CPU:0'
signature = [tensor_lib.TensorSpec([], dtypes.int32)]
@polymorphic_function.function
def send():
x = constant_op.constant(1)
gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
return x
send._shared_rendezvous = True # pylint: disable=protected-access
@polymorphic_function.function(input_signature=signature)
def send_body(n):
send()
return n - 1
@polymorphic_function.function
def recv():
return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)
recv._shared_rendezvous = True # pylint: disable=protected-access
@polymorphic_function.function(input_signature=signature)
def recv_body(n):
recv()
return n - 1
@polymorphic_function.function(input_signature=signature)
def cond_fn(n):
return n > 0
# Instead of calling the send & recv functions directly we want to call them
# through a functional while to ensure the rendezvous is shared across the
# while boundary.
@polymorphic_function.function
def fn(n):
functional_ops.While([n], cond_fn.get_concrete_function(),
send_body.get_concrete_function())
return functional_ops.While([n], cond_fn.get_concrete_function(),
recv_body.get_concrete_function())
# Use a graph context since functions will not be automatically inlined
with context.graph_mode(), self.cached_session():
self.evaluate(fn(2))
def disabled_testRandomSeed(self):
@polymorphic_function.function
def f():
return random_ops.random_normal(())
random_seed.set_random_seed(1)
x = f()
self.assertNotEqual(x, f())
random_seed.set_random_seed(1)
self.assertAllEqual(f(), x)
def testNestedInputsGraphFunction(self):
matmul = polymorphic_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@polymorphic_function.function
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = a_times_b.get_concrete_function(
pair(
dict(a=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'a')),
dict(b=tensor_lib.TensorSpec([2, 2], dtypes.float32, 'b'))))
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(a=t, b=t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputGraphFunction(self):
matmul = polymorphic_function.function(math_ops.matmul)
@polymorphic_function.function
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, (tensor_shape.TensorShape([2, 2]), {
'b': tensor_shape.TensorShape([])
}))
self.assertEqual(sq_op.output_dtypes, (dtypes.float32, {
'b': dtypes.float32
}))
(a, b) = sq_op(t)
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
def testZipStrictBuiltin(self):
major, minor, _ = platform.python_version_tuple()
if not (major == '3' and int(minor) >= 10):
self.skipTest('strict zip is only supported in Python 3.10+')
@polymorphic_function.function
def foo(x):
return list(zip([x], [x], strict=True))
self.assertEqual(foo(2)[0][0].numpy(), 2)
def testGraphFunctionNoneOutput(self):
@polymorphic_function.function
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@polymorphic_function.function
def add_int32s():
return x + x
self.assertEqual(2, int(add_int32s()))
def testDefunReadVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@polymorphic_function.function
def f():
return v.read_value()
self.assertEqual(1.0, float(f()))
def testDefunAssignAddVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
x = constant_op.constant(2.0)
@polymorphic_function.function
def test_assign_add():
v.assign_add(x)
return v.read_value()
self.assertEqual(3.0, float(test_assign_add()))
@test_util.run_in_graph_and_eager_modes
def testTensorInitializationInFunctionRaisesError(self):
@polymorphic_function.function
def tensor_init():
with self.assertRaisesRegex(ValueError, 'could not be lifted out'):
resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
tensor_init()
@test_util.run_in_graph_and_eager_modes
def testCallableTensorInitializationInFunction(self):
@polymorphic_function.function
def tensor_init():
self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(value), 2.0)
@test_util.also_run_as_tf_function
def testInitScopeTensorInitializationInFunction(self):
@polymorphic_function.function
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
# Note: this variable bypasses tf.function's variable creation
# requirements by bypassing variable_creator_scope by using
# ResourceVariable instead of Variable.
self.v = resource_variable_ops.ResourceVariable(const)
return self.v.read_value()
value = tensor_init()
self.assertAllEqual(value, 2.0)
@test_util.run_in_graph_and_eager_modes
def testGetConcreteFunctionCreatesVariables(self):
v_holder = []
@polymorphic_function.function
def tensor_init():
if not v_holder:
v_holder.append(variables.Variable(5.))
return v_holder[0].read_value()
concrete = tensor_init.get_concrete_function()
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(5., self.evaluate(concrete()))
self.assertAllEqual(5., self.evaluate(tensor_init()))
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# We do not return v directly since the tensor conversion function of
# ResourceVariable returns the read value and not the resource itself.
return v._handle
compiled = polymorphic_function.function(f)
var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testShapeInferenceForMoreSpecificInput(self):
def f(a):
return array_ops.reshape(a, [-1, 3])
signature = [tensor_lib.TensorSpec(None, dtypes.float32)]
compiled = polymorphic_function.function(f, input_signature=signature)
@polymorphic_function.function
def use_f():
inputs = array_ops.zeros([10, 10, 3])
self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)
use_f()
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
with context.graph_mode():
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# We do not return v directly since the tensor conversion function of
# ResourceVariable returns the read value and not the resource itself.
return v._handle
compiled = polymorphic_function.function(f)
var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
v = variables.Variable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = polymorphic_function.function(f)
compiled()
def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
with context.graph_mode():
tensor_list = list_ops.empty_tensor_list(
element_dtype=dtypes.float32,
element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
tensor_list = list_ops.tensor_list_push_back(tensor_list,
constant_op.constant(1.0))
tensor_list = list_ops.tensor_list_push_back(tensor_list,
constant_op.constant(2.0))
def f():
tl, value = list_ops.tensor_list_pop_back(
tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
return tl
compiled = polymorphic_function.function(f)
output_tensor_list = compiled()
_, value = list_ops.tensor_list_pop_back(
output_tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
def testRunMetadata(self):
@polymorphic_function.function
def f(x):
return x * x
with ops.device('cpu:0'):
context.enable_run_metadata()
f(constant_op.constant(1.0))
run_metadata = context.export_run_metadata()
context.disable_run_metadata()
self.assertLen(run_metadata.partition_graphs, 1)
def testGraphModeCaptureVariable(self):
with context.graph_mode(), self.cached_session():
class HasAVar:
def __init__(self):
self.v = resource_variable_ops.ResourceVariable(1.0)
def call(self):
return self.v * 2
o = HasAVar()
self.evaluate(variables.global_variables_initializer())
call = polymorphic_function.function(o.call)
op = call()
self.assertAllEqual(self.evaluate(op), 2.0)
def testGraphModeManyFunctions(self):
with ops.Graph().as_default(), self.cached_session():
@polymorphic_function.function
def f(x):
return x * x
@polymorphic_function.function
def g(x):
return f(x) + 1
self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)
def testDict(self):
@polymorphic_function.function
def f(x):
return {'name': x + 1}
self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)
def testWeakrefInputsRejected(self):
@polymorphic_function.function
def f(x):
return x
class Dummy:
pass
o = Dummy()
wr = weakref.ref(o)
with self.assertRaisesRegex(TypeError, 'weakref'):
f(wr)
def testTensorConversionWithDefun(self):
@polymorphic_function.function
def f(x):
return math_ops.add(x, constant_op.constant(3))
self.assertAllEqual(5, f(constant_op.constant(2)))
def testTensorConversionCall(self):
@polymorphic_function.function
def f(x):
return math_ops.add(x, constant_op.constant(3))
@polymorphic_function.function
def g(x):
return f(f(x))
self.assertAllEqual(8, g(constant_op.constant(2)))
def testCallShape(self):
@polymorphic_function.function
def f(x):
return x + 1
@polymorphic_function.function
def g(x):
x = f(x)
self.assertEqual(x.shape.as_list(), [])
return None
g(constant_op.constant(1.0))
def testNestedDefunWithNoOutputAndTapedInput(self):
three = resource_variable_ops.ResourceVariable(3.0, name='v')
@polymorphic_function.function
def f(x):
# This function intentionally takes a taped variable as input,
# but does not return any values
math_ops.add(x, three)
@polymorphic_function.function
def g(x):
y = math_ops.add(x, three)
f(y)
g(three)
def testGatherResourceWithDefun(self):
with ops.device('cpu:0'):
v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
defined = polymorphic_function.function(sum_gather)
self.assertAllEqual(sum_gather(), defined())
@parameterized.named_parameters([
('IndexedSlicesWithDenseShape',
_example_indexed_slices_with_dense_shape,),
('IndexedSlicesWithoutDenseShape',
_example_indexed_slices_without_dense_shape,),
('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
('RaggedTensorRaggedRank2',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
('SparseTensor', sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
]) # pyformat: disable
def testReturnCompositeTensorWithDefun(self,
factory_fn,
factory_kwargs={},
input_signature=None):
input_ct = factory_fn(**factory_kwargs)
@polymorphic_function.function(input_signature=input_signature)
def f():
return input_ct
output_ct = f()
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
@parameterized.named_parameters([
('IndexedSlicesWithDenseShape',
_example_indexed_slices_with_dense_shape,),
('IndexedSlicesWithoutDenseShape',
_example_indexed_slices_without_dense_shape,),
('RaggedTensorRaggedRank1',
ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
('RaggedTensorRaggedRank2',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
('SparseTensor',
sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
('RaggedTensorRaggedRank1WithSignature',
ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
[ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
('RaggedTensorRaggedRank2WithSignature',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
[ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
('SparseTensorWithSignature',
sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
[sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
]) # pyformat: disable
def testCompositeAsArgumentTensorWithDefun(self,
factory_fn,
factory_kwargs={},
input_signature=None):
input_ct = factory_fn(**factory_kwargs)
@polymorphic_function.function(input_signature=input_signature)
def f(x):
return x
output_ct = f(input_ct)
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
def testTracedCompositeDiscardsShapeInfo(self):
# SparseTensorSpec intentionally excludes info about the number of elements
# that are in a sparse tensor (which is recorded as st.indices.shape[0] and
# st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes
# info about the total number of values in a RaggedTensor (stored as
# rt.values.shape[0]). This test checks that the placeholders created by
# tf.function() properly mask this shape info.
@polymorphic_function.function
def f(rt, st):
self.assertEqual(st.indices.shape.as_list()[:1], [None])
self.assertEqual(st.values.shape.as_list(), [None])
return (rt, st)
rt = ragged_factory_ops.constant([[1, 2], [3]])
st = sparse_tensor.SparseTensor([[0]], [0], [10])
f(rt, st)
@test_util.run_gpu_only
def testFunctionOnDevice(self):
x = constant_op.constant([1.]).gpu()
f = polymorphic_function.function(math_ops.add)
y = f(x, x).cpu()
self.assertAllEqual(y, [2.])
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testOpInFunctionWithConflictingResourceInputs(self):
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0],
name='cpu')
v_also_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0],
name='also_cpu')
with ops.device('/gpu:0'):
v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0],
name='gpu')
@polymorphic_function.function
def resource_apply_adam():
gen_training_ops.resource_apply_adam(
v_cpu.handle,
v_gpu.handle,
v_also_cpu.handle,
1.0, # beta1_power
1.0, # beta2_power
1.0, # learning_rate
1.0, # beta1
1.0, # beta2
1.0, # epsilon,
[1.0, 1.0, 1.0], # grad
False) # use_locking
return 1
with self.assertRaisesRegex(
errors.InvalidArgumentError,
'Cannot place the graph because a reference or resource edge connects '
'colocation groups with incompatible assigned devices'):
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.evaluate(resource_apply_adam())
@test_util.run_gpu_only
def testFunctionHandlesInputsOnDifferentDevices(self):
# The Reshape op requires the shape tensor to be placed in host memory.
reshape = polymorphic_function.function(array_ops.reshape)
value = constant_op.constant([1., 2.]).gpu()
shape = constant_op.constant([2, 1])
reshaped = reshape(value, shape).cpu()
self.assertAllEqual(reshaped, [[1], [2]])
@test_util.run_gpu_only
def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
# The Reshape op requires the shape tensor to be placed in host memory.
reshape = polymorphic_function.function(array_ops.reshape)
value = constant_op.constant([1., 2.])
shape = constant_op.constant([2, 1]).gpu()
reshape(value, shape) # No error is raised
def testNoneOutput(self):
@polymorphic_function.function
def my_function(_):
return None
self.assertAllEqual(my_function(1), None)
def testNestedFunctions(self):
# TensorFlow function (which is what would be used in TensorFlow graph
# construction).
@tf_function.Defun(dtypes.int32, dtypes.int32)
def add(a, b):
return math_ops.add(a, b)
@polymorphic_function.function
def add_one(x):
return add(x, 1)
self.assertAllEqual(3, add_one(constant_op.constant(2)))
def testVariableCaptureInNestedFunctions(self):
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
@polymorphic_function.function
def inner_read():
return v.read_value()
@polymorphic_function.function
def outer():
return inner_read()
self.assertEqual(1, int(outer()))
def testReturnCapturedEagerTensor(self):
t = constant_op.constant(1)
@polymorphic_function.function
def read():
return t
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@polymorphic_function.function
def read():
return t
self.assertEqual(1, int(self.evaluate(read())))
def testConcreteFunctionType(self):
y = constant_op.constant(1)
@polymorphic_function.function
def foo(x):
return {'input': x, 'capture': y}
cf = foo.get_concrete_function(tensor_lib.TensorSpec([], dtypes.int32))
x = constant_op.constant(2)
output = cf(x)
self.assertEqual(set(output.keys()), {'input', 'capture'})
self.assertEqual(output['input'].numpy(), 2)
self.assertEqual(output['capture'].numpy(), 1)
parameters = list(cf.function_type.parameters.values())
self.assertLen(parameters, 1)
self.assertEqual(parameters[0].name, 'x')
self.assertEqual(
parameters[0].type_constraint,
tensor_lib.TensorSpec([], dtypes.int32),
)
captures = cf.function_type.captures
self.assertLen(captures, 1)
self.assertEqual(captures[id(y)], tensor_lib.TensorSpec([], dtypes.int32))
output = cf.function_type.output
self.assertEqual(output, trace_type.from_value({'input': x, 'capture': y}))
def testSequenceInputs(self):
clip_by_global_norm = polymorphic_function.function(
clip_ops.clip_by_global_norm)
t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
clipped_list, global_norm = clip_by_global_norm(t_list,
constant_op.constant(.2))
for t in clipped_list:
self.assertIsInstance(t, tensor_lib.Tensor)
self.assertIsInstance(global_norm, tensor_lib.Tensor)
def testNestedSequenceInputs(self):
def my_op(inputs):
a, b, c = inputs
e, f = b
g, h = e
return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
my_eager_op = polymorphic_function.function(my_op)
ret = my_eager_op([
constant_op.constant(1),
[(constant_op.constant(2), constant_op.constant(3)),
constant_op.constant(4)],
constant_op.constant(5)
])
self.assertLen(ret, 2)
self.assertAllEqual(ret[0][0], 2)
self.assertAllEqual(ret[0][1][0][0], 8)
self.assertAllEqual(ret[0][1][0][1], 4)
self.assertIsInstance(ret[0][1][0], tuple)
self.assertAllEqual(ret[0][1][1], 6)
self.assertAllEqual(ret[0][2], 10)
self.assertAllEqual(ret[1], 15)
def testVariableNamesRespectNameScopesWithDefun(self):
@polymorphic_function.function
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable(0.0, name='bar')
self.assertEqual(v.name, 'foo/bar:0')
create_variable()
def testVariableNamesRespectNameScopesWithDefunInGraph(self):
with context.graph_mode():
@polymorphic_function.function
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
self.assertEqual(v.name, 'foo/bar:0')
with ops.get_default_graph().as_default():
create_variable()
@test_util.run_in_graph_and_eager_modes
def testVariablesPlacedOnOutsideDevice(self):
class _Obj(object):
def __init__(self):
self.v = None
@polymorphic_function.function
def f(self):
if self.v is None:
self.v = variables.Variable(1.)
return self.v + 1.
has_device = _Obj()
with ops.device('cpu:0'):
has_device.f()
self.assertIn('CPU', has_device.v.device)
@test_util.run_in_graph_and_eager_modes
def testCallingGraphFunctionOnDifferentDevice(self):
def func():
return constant_op.constant(0)
defined = polymorphic_function.function(func)
with ops.device('cpu:0'):
cpu_graph_function = defined.get_concrete_function()
with ops.device('cpu:0'):
self.assertEqual(
self.evaluate(cpu_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
with ops.device(None):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
default_graph_function = defined.get_concrete_function()
self.assertEqual(
self.evaluate(default_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(default_graph_function()))
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testColocateWithRespected(self):
# TODO(b/113291792): Use multiple CPUs instead of a GPU.
with ops.device('cpu:0'):
x = array_ops.identity(1.0)
with ops.device('gpu:0'):
y = array_ops.identity(1.0)
@polymorphic_function.function
def foo():
return test_ops.device_placement_op()
with ops.colocate_with(x):
self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
with ops.colocate_with(y):
self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
@parameterized.parameters([(True), (False)])
def testVariablesAreTracked(self, reduce_retracing):
v = resource_variable_ops.ResourceVariable(1.0)
def foo(x):
return v * x
defined = polymorphic_function.function(
foo, reduce_retracing=reduce_retracing)
x = constant_op.constant([1.0])
self.assertEqual(1., self.evaluate(defined(x)))
v.assign(2.)
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testInputSignatureMustBeSequenceOfTensorSpecs(self):
def foo(a, b):
del a
del b
# Signatures must consist exclusively of `TensorSpec` objects.
signature = [(2, 3), tensor_lib.TensorSpec([2, 3], dtypes.float32)]
with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'):
polymorphic_function.function(foo, input_signature=signature)
@test_util.run_in_graph_and_eager_modes
def testInputsIncompatibleWithSignatureRaisesError(self):
def foo(a):
return a
signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)]
defined = polymorphic_function.function(foo, input_signature=signature)
# Valid call
defined(array_ops.ones([2]))
# Invalid shapes.
with self.assertRaisesRegex(
TypeError, r'Can not cast .*dtype=tf.int32.* to .*dtype=tf.float32.*'
):
defined(array_ops.ones([3], dtype=dtypes.int32))
# Invalid shapes.
with self.assertRaisesRegex(TypeError, 'Can not cast.*'):
defined(array_ops.ones([3]))
with self.assertRaisesRegex(TypeError, 'Can not cast.*'):
defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
with self.assertRaisesRegex(TypeError, 'too many positional arguments'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
defined()
with self.assertRaisesRegex(
TypeError, r'Can not cast .*shape=\(3,\).* to .*shape=\(2,\).*'
):
defined.get_concrete_function(
tensor_lib.TensorSpec(shape=(3,), dtype=dtypes.float32))
def testMismatchedConcreteSignatureRaisesError(self):
@polymorphic_function.function
def run_test():
@polymorphic_function.function
def f(x):
return x
with self.assertRaisesRegex(
TypeError, 'Binding inputs to tf.function failed .*'):
f.get_concrete_function(1)(constant_op.constant(1))
f.get_concrete_function(constant_op.constant(1))(1)
with self.assertRaisesRegex(
TypeError, 'Binding inputs to tf.function failed .*'):
f.get_concrete_function(1)(2)
run_test()
def testInputSignatureConversionWithDefaultArg(self):
def foo(a, training=True):
if training:
return a
else:
return -1.0 * a
signature = [
tensor_lib.TensorSpec([], dtypes.float32),
tensor_lib.TensorSpec([], dtypes.bool),
]
defined = polymorphic_function.function(foo, input_signature=signature)
a = constant_op.constant(1.0)
self.assertAllEqual(a.numpy(), defined(a))
self.assertAllEqual(a.numpy(), defined(a, training=True))
self.assertAllEqual(-a.numpy(), defined(a, training=False))
def testVariableSpecWithInputSignature(self):
def f(v):
v.assign_add(1)
signature = [
resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
]
with self.assertRaisesRegex(TypeError,
"input_signature doesn't support VariableSpec"):
polymorphic_function.function(f, input_signature=signature)
def testDefuningInstanceMethod(self):
integer = constant_op.constant(2, dtypes.int64)
class Foo:
def one(self, tensor):
return tensor
@polymorphic_function.function
def two(self, tensor, other=integer):
return self.one(tensor), other
foo = Foo()
t = constant_op.constant(1.0)
one, two = foo.two(t)
self.assertEqual(one.numpy(), 1.0)
self.assertEqual(two.numpy(), 2)
def testDefuningInstanceMethodWithDefaultArgument(self):
integer = constant_op.constant(2, dtypes.int64)
class Foo:
@polymorphic_function.function
def func(self, other=integer):
return other
foo = Foo()
self.assertEqual(foo.func().numpy(), int(integer))
def testPythonCallWithSideEffects(self):
state = []
@polymorphic_function.function
def side_effecting_function():
state.append(0)
side_effecting_function()
self.assertAllEqual(state, [0])
# The second invocation should call the graph function, which shouldn't
# trigger the list append.
side_effecting_function()
self.assertAllEqual(state, [0])
# Whereas calling the python function directly should create a side-effect.
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
def testFunctionWithNestedFunctionCallAndSideEffects(self):
v1 = variables.Variable(1.0)
v2 = variables.Variable(1.0)
@polymorphic_function.function
def add_one(a):
a.assign_add(1.0)
# Grappler will inline calls to `add_one` into the function body, we check
# that all side-effects were executed.
@polymorphic_function.function
def side_effecting_function(a, b):
add_one(a)
add_one(b)
return a + b
result = side_effecting_function(v1, v2)
self.assertEqual(result.numpy(), 4.0)
def testRegisterConcreteFunction(self):
@polymorphic_function.function
def py_add(x, y):
return math_ops.add(x, y)
py_add(array_ops.ones([]), array_ops.ones([]))
add = py_add.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32))
@polymorphic_function.function
def py_composite(x, y):
return x, add(x, y)
py_composite(array_ops.ones([]), array_ops.ones([]))
composite = py_composite.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32))
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
composite.add_to_graph()
composite.add_gradient_functions_to_graph()
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 6)
# two sets of functions, each of them are (inference, forward, backward)
functions = list(graph._functions.values())
captured_function_names = [
f.cached_definition.signature.name for f in functions
]
expected_func_name_regex = [
'.*inference.*py_composite.*',
'.*inference.*py_add.*',
'.*forward.*py_composite.*',
'.*forward.*py_add.*',
'.*inference.*backward.*py_composite.*',
'.*inference.*backward.*py_add.*',
]
for expected, found in zip(expected_func_name_regex,
captured_function_names):
self.assertRegex(found, expected)
composite_t, composite_double = composite(t, t)
double = add(t, t)
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
# Make sure the pre registered function is used, and no other function
# is added.
self.assertLen(graph._functions, 6)
def testEagerCaptures(self):
with context.eager_mode():
large_tensor = array_ops.ones(shape=(256,))
self.assertGreater(256, capture_container._EAGER_CONST_THRESHOLD)
small_tensor = array_ops.ones(shape=(4,))
self.assertLessEqual(4, capture_container._EAGER_CONST_THRESHOLD)
v = resource_variable_ops.ResourceVariable(0.0)
for captured, op_type in [(large_tensor, 'Placeholder'),
(small_tensor, 'Const'), (v, 'Placeholder')]:
@polymorphic_function.function
def test_fn():
return captured + 1 # pylint: disable=cell-var-from-loop
g = test_fn.get_concrete_function().graph
internal_captures = g.internal_captures
self.assertLen(internal_captures, 1)
self.assertEqual(internal_captures[0].op.type, op_type)
@parameterized.parameters([(True), (False)])
def testVariableAliasIdInStructuredInputSignature(self, reduce_retracing):
@polymorphic_function.function(reduce_retracing=reduce_retracing)
def foo(v1, v2):
return v1 + v2
v1 = resource_variable_ops.ResourceVariable(1.0)
v2 = resource_variable_ops.ResourceVariable(2.0)
graph_function = foo.get_concrete_function(v1, v1)
args_sig, _ = graph_function.graph.structured_input_signature
expected_spec = resource_variable_ops.VariableSpec([], alias_id=0)
self.assertLen(args_sig, 2)
self.assertEqual(args_sig[0], expected_spec)
self.assertEqual(args_sig[1], expected_spec)
graph_function = foo.get_concrete_function(v1, v2)
args_sig, _ = graph_function.graph.structured_input_signature
expected_spec1 = resource_variable_ops.VariableSpec([], alias_id=0)
expected_spec2 = resource_variable_ops.VariableSpec([], alias_id=1)
self.assertLen(args_sig, 2)
self.assertEqual(args_sig[0], expected_spec1)
self.assertEqual(args_sig[1], expected_spec2)
def testStructuredSignatureAndMultipleVariables(self):
self.skipTest('b/209081027: Enable this test after Variable becomes a '
'CompositeTensor and Variable gets expand to handle tensor.')
@polymorphic_function.function
def foo(v1, v2):
return v1 + v2
v1 = resource_variable_ops.ResourceVariable(1.0)
v2 = resource_variable_ops.ResourceVariable(2.0)
graph_function = foo.get_concrete_function(v1, v1)
self.assertAllEqual(graph_function(v1, v1), 2.0)
with self.assertRaises(TypeError):
graph_function(v1, v2)
def _total_function_cache_def_func(self, defined):
return defined._list_all_concrete_functions() # pylint: disable=protected-access
@parameterized.parameters([(True), (False)])
def testVariableRetracingOnDtypeChanges(self, reduce_retracing):
@polymorphic_function.function(reduce_retracing=reduce_retracing)
def defined(a, b):
return a + b
x1 = resource_variable_ops.ResourceVariable(0.0)
x2 = resource_variable_ops.ResourceVariable(0.0)
defined(x1, x2)
self.assertLen(self._total_function_cache_def_func(defined), 1)
# Should expect retracing for new dtypes
y1 = resource_variable_ops.ResourceVariable(0)
y2 = resource_variable_ops.ResourceVariable(1)
defined(y1, y2)
self.assertLen(self._total_function_cache_def_func(defined), 2)
def testVariableRetracingDtypeShape(self):
@polymorphic_function.function
def defined(a, b):
return a + b
x1 = resource_variable_ops.ResourceVariable(0.0)
x2 = resource_variable_ops.ResourceVariable(0.0)
defined(x1, x2)
self.assertLen(self._total_function_cache_def_func(defined), 1)
y1 = resource_variable_ops.ResourceVariable([0.0, 1.0])
y2 = resource_variable_ops.ResourceVariable([0.0, 1.0])
defined(y1, y2)
self.assertLen(self._total_function_cache_def_func(defined), 2)
z1 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
z2 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
defined(z1, z2)
self.assertLen(self._total_function_cache_def_func(defined), 3)
def testFunctionModifiesInputList(self):
# Tests on `list` methods that do in place modification, except `list.sort`
# since it cannot even be "defunned" in the first place
def get_list():
return [constant_op.constant(0.), constant_op.constant(1.)]
expected_msg = '.*() should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def append(l):
l.append(constant_op.constant(0.))
append(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def extend(l):
l.extend([constant_op.constant(0.)])
extend(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def insert(l):
l.insert(0, constant_op.constant(0.))
insert(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def pop(l):
l.pop()
pop(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def reverse(l):
l.reverse()
reverse(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def remove(l):
l.remove(l[0])
remove(get_list())
# `list.clear` is a method that is in Py3 but not Py2
if sys.version.startswith('3'):
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def clear(l):
l.clear()
clear(get_list())
# One last test for keyword arguments
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def kwdappend(**kwargs):
l = kwargs['l']
l.append(constant_op.constant(0.))
kwdappend(l=get_list())
def testFunctionModifiesInputDict(self):
def get_dict():
return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
expected_msg = '.* should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def clear(m):
m.clear()
clear(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def pop(m):
m.pop('t1')
pop(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def popitem(m):
m.popitem()
popitem(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def update(m):
m.update({'t1': constant_op.constant(3.)})
update(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@polymorphic_function.function
def setdefault(m):
m.setdefault('t3', constant_op.constant(3.))
setdefault(get_dict())
def testFunctionModifiesInputNest(self):
with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):
@polymorphic_function.function
def modify(n):
n[0]['t1'].append(constant_op.constant(1.))
nested_input = [{
't1': [constant_op.constant(0.),
constant_op.constant(1.)],
},
constant_op.constant(2.)]
modify(nested_input)
with self.assertRaisesRegex(ValueError,
'modify_same_flat.* should not modify'):
# The flat list doesn't change whereas the true structure changes
@polymorphic_function.function
def modify_same_flat(n):
n[0].append(n[1].pop(0))
nested_input = [[constant_op.constant(0.)],
[constant_op.constant(1.),
constant_op.constant(2.)]]
modify_same_flat(nested_input)
def testFunctionStackInErrorMessage(self):
if context.executing_eagerly():
# TODO(b/122736651): Remove this skipTest once fixed.
self.skipTest('Error interpolation is not working when function is '
'invoked without PartitionedCallOp.')
@polymorphic_function.function()
def fn3(x):
return x + 2
@polymorphic_function.function()
def fn2(x):
check_ops.assert_equal(fn3(x), 3)
return 2
@polymorphic_function.function()
def fn(x):
return fn2(x)
with self.assertRaises(errors.InvalidArgumentError) as cm:
fn(2)
e = cm.exception
self.assertIn('fn -> fn2', e.message)
self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
self.assertNotIn('fn3', e.message)
@test_util.run_gpu_only
def testFunctionIsNotPinned(self):
"""Tests that functions aren't pinned to the CPU by the eager runtime."""
seed1, seed2 = 79, 25
shape = constant_op.constant([4, 7])
dtype = dtypes.float32
@polymorphic_function.function
def func():
with ops.device('GPU:0'):
return gen_random_ops.random_standard_normal(
shape, dtype=dtype, seed=seed1, seed2=seed2)
with ops.device('GPU:0'):
x = func()
self.assertRegex(x.device, 'GPU')
def testLimitedRetracingWithCompositeTensors(self):
trace_count = [0]
@polymorphic_function.function
def f(x):
trace_count[0] += 1
return x
for i in range(10):
f(ragged_factory_ops.constant([[1, 2], [i]]))
f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
self.assertEqual(trace_count[0], 3)
def testCompositeTensorsWithReducedRetracing(self):
inp = ragged_factory_ops.constant([[1, 2], [3]])
@polymorphic_function.function(reduce_retracing=True)
def f(x):
return x
output = f(inp)
self.assertTrue(math_ops.reduce_all(math_ops.equal(inp, output)))
def testMultipleInputsWithReducedRetracing(self):
tensor1 = ragged_factory_ops.constant([[1, 2], [3]])
tensor2 = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
variable1 = variables.Variable(1.0)
variable2 = variables.Variable(2.0)
@polymorphic_function.function(reduce_retracing=True)
def f(a, b, c, d):
return [a, b, c, d]
output = f(tensor1, tensor2, variable1, variable2)
self.assertTrue(math_ops.reduce_all(math_ops.equal(tensor1, output[0])))
self.assertTrue(math_ops.reduce_all(math_ops.equal(tensor2, output[1])))
self.assertTrue(math_ops.reduce_all(math_ops.equal(variable1, output[2])))
self.assertTrue(math_ops.reduce_all(math_ops.equal(variable2, output[3])))
def test_concrete_function_shape_mismatch(self):
@polymorphic_function.function
def f(argument_name):
return argument_name + 1.
f_concrete = f.get_concrete_function(constant_op.constant([1.]))
# Calling a function from eager doesn't do any shape checking above what
# kernels do while executing.
self.assertAllEqual([2., 3.],
f_concrete(constant_op.constant([1., 2.])).numpy())
@polymorphic_function.function
def g():
f_concrete(constant_op.constant([1., 2.]))
with self.assertRaisesRegex(
TypeError,
r'Can not cast TensorSpec\(shape=\(2,\).* to TensorSpec\(shape=\(1,\)',
):
g()
@test_util.run_in_graph_and_eager_modes
def test_shape_inference_with_symbolic_shapes(self):
@polymorphic_function.function
def _uses_symbolic_shapes(w, x, y):
x = array_ops.identity(x, name='name_collision')
x = array_ops.transpose(x, [1, 0, 2])
x_batch = array_ops.shape(x)[0]
y_batch = array_ops.shape(y)[0]
y *= w
n = y_batch // x_batch
return array_ops.reshape(y, [n, x_batch, -1])
conc = _uses_symbolic_shapes.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32))
@polymorphic_function.function
def _call_concrete():
c = constant_op.constant(1.)
array_ops.identity(c, name='name_collision')
output1 = conc(
array_ops.ones([2]), array_ops.ones([5, 4, 2]),
array_ops.ones([20, 2]))
self.assertEqual([5, 4, 2], output1.shape)
output2 = conc(
array_ops.ones([3]), array_ops.ones([5, 4, 3]),
array_ops.ones([40, 3]))
self.assertEqual([10, 4, 3], output2.shape)
return output1, output2
output1, output2 = _call_concrete()
self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
self.assertEqual((10, 4, 3), self.evaluate(output2).shape)
def testAutoGraphContext(self):
@polymorphic_function.function
def test_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
prev_status = ag_ctx.control_status_ctx().status
test_fn()
self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)
@test_util.disable_tfrt('b/170435618')
def testCancelBeforeFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
@polymorphic_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
c_mgr.start_cancel()
with self.assertRaises(errors.CancelledError):
cancelable_func()
@test_util.disable_tfrt('b/170435618')
def testCancelBlockedFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
@polymorphic_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
def cancel_thread():
time.sleep(0.5)
c_mgr.start_cancel()
t = self.checkedThread(cancel_thread)
t.start()
with self.assertRaises(errors.CancelledError):
cancelable_func()
t.join()
@test_util.disable_tfrt('b/170435618')
def testCancelAfterFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
q.enqueue(37)
@polymorphic_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
self.assertAllEqual(37, cancelable_func().numpy())
# Cancellation after the function executes is a no-op.
c_mgr.start_cancel()
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedTensorInputs(self):
@polymorphic_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = constant_op.constant(1000)
b = constant_op.constant(200)
c = constant_op.constant(30)
d = {'a': a, 'b': b}
e = (c, 4)
# Test different argument signatures when constructing the concrete func.
for cf in [
f.get_concrete_function(d, e),
f.get_concrete_function(d, y=e),
f.get_concrete_function(y=e, x=d),
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
]:
# Test different calling conventions when calling the concrete func.
for output in [
cf(d, e), # structured signature
cf(d, y=e), # structured signature w/ kwarg
cf(y=e, x=d), # structured signature w/ 2 kwargs
cf(a, b, c), # flat signature
]:
self.assertIsInstance(output, tuple)
self.assertLen(output, 2)
self.assertAllEqual(output[0], 1200)
self.assertAllEqual(output[1], 34)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedNonTensorInputs(self):
@polymorphic_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
b = (50, 3)
for cf in [ # argument y is bound to non-Tensor value (50, 3).
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 1253)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNonTensorStringInputs(self):
@polymorphic_function.function
def f(x, y):
return string_ops.string_join([x, y])
a = constant_op.constant('a')
b = 'b'
cf = f.get_concrete_function(a, b)
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output, b'ab')
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
@polymorphic_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 3000, 'b': 200, 'c': 9000}
b = (constant_op.constant(30), 4)
for cf in [ # argument x is bound to non-tensor value `a`
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a, b), cf(a, y=b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 3234)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
@polymorphic_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 5000, 'b': 500}
b = (50, 5)
cf = f.get_concrete_function(a, b)
for output in [cf(), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 5555)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionMethodWithVarargs(self):
float32_scalar = tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32)
class MyModel(module.Module):
@polymorphic_function.function(
input_signature=[float32_scalar, float32_scalar])
def add(self, *arg):
return math_ops.add(*arg)
m = MyModel()
cf = m.add.get_concrete_function()
cf(-12.0, 3.0)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
# Check that keyword-only arguments are sorted appropriately, so that they
# feed the right tensor into each input.
@polymorphic_function.function
def g(**kwargs):
return string_ops.reduce_join(
string_ops.reduce_join(
ops.convert_to_tensor(sorted(kwargs.items())),
axis=1,
separator='='),
axis=0,
separator=', ')
s = constant_op.constant('s')
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
self.assertAllEqual(
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
def testSameConcreteFunctionDifferentKwargOrder(self):
@polymorphic_function.function
def foo(**kwargs):
return kwargs['a'] + math_ops.cast(kwargs['b'], dtypes.float32)
foo(a=constant_op.constant(1.0), b=constant_op.constant(1))
foo(b=constant_op.constant(1), a=constant_op.constant(1.0))
self.assertLen(total_function_cache(foo), 1)
def testEmptyInputSignatures(self):
class Foo:
@polymorphic_function.function(input_signature=[])
def bar_none(self):
return 1
@polymorphic_function.function(input_signature=[])
def bar_one(self, x=0):
return x
@polymorphic_function.function(input_signature=[])
def bar_two(self, x=0, y=1):
return x + y
foo = Foo()
self.assertEqual(foo.bar_none.input_signature, ())
self.assertEqual(foo.bar_one.input_signature, ())
self.assertEqual(foo.bar_two.input_signature, ())
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (1, constant_op.constant(2)),
call_args=lambda: (1,),
error=r'missing a required argument: \'y\'',
),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
call_args=lambda: (1, 2),
error=r'missing a required argument: \'varargs_0\'',
),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2, 3),
error=r'too many positional arguments',
),
dict(
testcase_name='MissingKeywordOnlyArg',
conc_args=lambda: (1, 2),
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
call_args=lambda: (1, 2),
error=r'missing a required( keyword-only)? argument: \'c\'',
),
dict(
testcase_name='ExtraKeywordArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
error=r'got an unexpected keyword argument',
),
dict(
testcase_name='ExpectedRaggedGotNest',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: ({'a': constant_op.constant([1, 2, 3])}, 5),
error=(
r'Binding inputs .* failed .* don\'t have the same nested'
r' structure'
),
),
dict(
testcase_name='WrongRaggedRank',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: (ragged_factory_ops.constant([[[1]]]), 5),
error=(
r'Binding inputs .* failed .* don\'t have the same nested'
r' structure'
),
),
dict(
testcase_name='WrongRaggedDType',
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
call_args=lambda: (ragged_factory_ops.constant([[1.0]]), 5),
error=(
r'Binding inputs .* failed.*dtype=tf.float32.* to'
r' .*dtype=tf.int32.*'
),
),
dict(
testcase_name='ExpectedDictGotTensor',
conc_args=lambda: (
{'a': constant_op.constant(1), 'b': constant_op.constant(1)},
),
call_args=lambda: (constant_op.constant(1), 5),
error=r'Binding inputs .* failed .*Can not cast .*Tensor.* to a Dict',
),
dict(
testcase_name='ExpectedTupleGotTensor',
conc_args=lambda: (
(constant_op.constant(1), constant_op.constant(2)),
),
call_args=lambda: (constant_op.constant(1), 5),
error=r'Binding inputs .* failed .*Can not cast .*Tensor.* to tuple',
),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0), 5),
exception=(
TypeError,
errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError,
),
),
dict(
testcase_name='ExpectedIntGotDifferentInt',
conc_args=lambda: (5,),
call_args=lambda: (8, 5),
error=r'Binding inputs .* failed .*Can not cast 8 to .*5',
),
dict(
testcase_name='ExpectedIntGotTensor',
conc_args=lambda: (5,),
call_args=lambda: (constant_op.constant(6), 5),
error=r'Binding inputs .* failed .*Can not cast .*Tensor.* to .*5',
),
dict(
testcase_name='TwoValuesForArgument',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'x': 3},
error=r'got an unexpected keyword argument \'x\'',
),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the structrued signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@polymorphic_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegex(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='TwoValuesForArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {
'x': constant_op.constant(1),
'y': constant_op.constant(1)
},
error=r"func\(x, y\) got two values for 'x'"),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
error=r'func\(x, y\) takes 2 .* got 3'),
dict(
testcase_name='UnexpectedKeywordArg',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x\) got unexpected keyword arguments: c'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, varargs_0\) missing required '
r'arguments: varargs_0'),
dict(
testcase_name='MissingKeywordArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': constant_op.constant(1)},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, c\) missing required arguments: c'),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (5, constant_op.constant(2)),
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
r'a Tensor; got int \(5\)'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(
ValueError,
errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='MissingKeywordArgNestPiece',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionFlatSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the flat signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@polymorphic_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegex(exception, error):
self.evaluate(conc._call_with_flat_signature(call_args, call_kwargs)) # pylint: disable=protected-access
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionAmbiguousSignature(self):
# When both the flat & structured signatures are applicable, but they
# give different results, we use the structured signature. Note: we expect
# this to be extremely rare.
@polymorphic_function.function
def f(x, y):
return x * 10 + y
conc = f.get_concrete_function(
x=tensor_lib.TensorSpec(None, dtypes.int32, name='y'),
y=tensor_lib.TensorSpec(None, dtypes.int32, name='x'))
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
self.assertAllEqual(result, 56)
def testPrettyPrintedSignature(self):
@polymorphic_function.function
def func(x, kangaroo=None, octopus=7):
del octopus, kangaroo
return x
scalar = constant_op.constant(5)
vector = constant_op.constant([10, 10, 20])
concrete_fn = func.get_concrete_function(scalar, vector)
summary = (
'(x: TensorSpec(shape=(), dtype=tf.int32, name=None), kangaroo:'
' TensorSpec(shape=(3,), dtype=tf.int32, name=None), octopus:'
' Literal[7]) -> TensorSpec(shape=(), dtype=tf.int32, name=None)'
)
details = (
'Input Parameters:\n'
+ ' x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(),'
' dtype=tf.int32, name=None)\n'
+ ' kangaroo (POSITIONAL_OR_KEYWORD):'
' TensorSpec(shape=(3,), dtype=tf.int32, name=None)\n'
+ ' octopus (POSITIONAL_OR_KEYWORD): Literal[7]\n'
+ 'Output Type:\n'
+ ' TensorSpec(shape=(), dtype=tf.int32, name=None)\n'
+ 'Captures:\n'
+ ' None'
)
self.assertEqual(
concrete_fn.pretty_printed_signature(verbose=False), summary
)
self.assertEqual(
concrete_fn.pretty_printed_signature(verbose=True), details
)
self.assertRegex(repr(concrete_fn), r'<ConcreteFunction .* at .*')
self.assertEqual(str(concrete_fn), 'ConcreteFunction {}'.format(details))
def testPrettyPrintedExplicitSignatureWithKeywordArg(self):
@polymorphic_function.function(
input_signature=[tensor_lib.TensorSpec(None)])
def fn(a, b=1):
return a + b
concrete_fn = fn.get_concrete_function()
self.assertEqual(
concrete_fn.pretty_printed_signature(False),
'(a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b:'
' Literal[1]) -> TensorSpec(shape=<unknown>, dtype=tf.float32,'
' name=None)',
)
self.assertEqual(
concrete_fn.pretty_printed_signature(True),
'Input Parameters:\n'
+ ' a (POSITIONAL_OR_KEYWORD):'
' TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)\n'
+ ' b (POSITIONAL_OR_KEYWORD): Literal[1]\n'
+ 'Output Type:\n'
+ ' TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)\n'
+ 'Captures:\n'
+ ' None',
)
def testPrettyPrintedSignatureLoadedNamedTuple(self):
Point = collections.namedtuple('Point', ['x', 'y'])
@polymorphic_function.function
def fn(b, a): # pylint: disable=unused-argument
return 1.
b = Point(
x=constant_op.constant(1., dtype=dtypes.float32),
y=constant_op.constant(1., dtype=dtypes.float32))
a = Point(
x=constant_op.constant(1, dtype=dtypes.int32),
y=constant_op.constant(1, dtype=dtypes.int32))
mod = module.Module()
f = fn.get_concrete_function(b, a)
save(mod, '/tmp/f', signatures=f)
loaded = load('/tmp/f')
printed = loaded.signatures['serving_default'].pretty_printed_signature()
self.assertEqual(
printed,
'Input Parameters:\n'
+ " a (KEYWORD_ONLY): TensorSpec(shape=(), dtype=tf.int32, name='a')\n"
+ ' a_1 (KEYWORD_ONLY): TensorSpec(shape=(),'
" dtype=tf.int32, name='a_1')\n"
+ ' b (KEYWORD_ONLY): TensorSpec(shape=(),'
" dtype=tf.float32, name='b')\n"
+ ' b_1 (KEYWORD_ONLY):'
" TensorSpec(shape=(), dtype=tf.float32, name='b_1')\n"
+ 'Output Type:\n'
+ " Dict[['output_0', TensorSpec(shape=(), dtype=tf.float32,"
" name='output_0')]]\n"
+ 'Captures:\n'
+ ' None',
)
@test_util.run_in_graph_and_eager_modes
def testIndexedSlicesAsGradientsForConcreteFunctions(self):
@polymorphic_function.function
def summing_rnn(inputs):
return math_ops.reduce_sum(inputs, axis=1)
@polymorphic_function.function
def gradients(inputs):
with backprop.GradientTape() as tape:
tape.watch(inputs)
hidden = summing_rnn(inputs)
hidden = array_ops.gather(hidden, constant_op.constant([0]))
loss = math_ops.reduce_mean(hidden)
return tape.gradient(loss, inputs)
gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised
def testWithExtraWrapper(self):
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@polymorphic_function.function
@dummy_tf_decorator
def add(self, x, y, z=1):
if self.var is None:
return x + y + z
foo = Foo()
self.assertEqual(foo.add(2, 3).numpy(), 6)
@parameterized.parameters([
(polymorphic_function.function, dummy_tf_decorator),
(dummy_tf_decorator, polymorphic_function.function),
(polymorphic_function.function, polymorphic_function.function)
])
def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@decorator1
@decorator2
def add1(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError, 'multiple values for argument'):
foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
def testWithExtraWrapperMissingArgs(self):
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@polymorphic_function.function
@dummy_tf_decorator
def add1(self, x, y):
if self.var is None:
return x + y
@polymorphic_function.function
@dummy_tf_decorator
def add2(self, x, y):
if self.var is None:
return x + y
@polymorphic_function.function
@polymorphic_function.function
def add3(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'y\''):
foo.add1(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'x\''):
foo.add1(y=2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'y\''):
foo.add2(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'x\''):
foo.add2(y=2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'y\''):
foo.add3(2) # pylint: disable=no-value-for-parameter
with self.assertRaisesRegex(TypeError,
'missing a required argument: \'x\''):
foo.add3(y=2) # pylint: disable=no-value-for-parameter
def testMissingArgsTfFunctionedMethod(self):
class A:
def func(self, position_arg1, position_arg2):
return position_arg1, position_arg2
@polymorphic_function.function
def decorated_method(self, position_arg1, position_arg2):
return position_arg1, position_arg2
a_instance = A()
tf_method_pos = polymorphic_function.function(a_instance.func)
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
tf_method_pos(position_arg2='foo')
# tf.function-decorated instance methods need to be tested because of
# the __get__ method implementation.
tf_func_decorated_method = polymorphic_function.function(
a_instance.decorated_method)
tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
tf_func_decorated_method(position_arg2='bar')
def testMissingArgsTfFunctionedObject(self):
class A:
def __call__(self, position_arg1, position_arg2):
return position_arg1, position_arg2
a_instance = A()
# A tf.function-decorated callable object needs to be tested because of
# the special inspect results.
tf_func_obj = polymorphic_function.function(a_instance)
tf_func_obj(position_arg1=1, position_arg2=2)
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
tf_func_obj(position_arg2='bar')
def testMissingArgsTfFunctionedFunctions(self):
def func_pos(position_arg1, position_arg2):
return position_arg1, position_arg2
def func_with_default(position_arg, named_arg=None):
return position_arg, named_arg
def func_pos_3args(position_arg1, position_arg2, position_arg3):
return position_arg1, position_arg2, position_arg3
tf_func_pos = polymorphic_function.function(func_pos)
with self.assertRaisesRegex(
TypeError, 'missing a required argument'):
tf_func_pos(position_arg2='foo')
tf_func_with_default = polymorphic_function.function(func_with_default)
tf_func_with_default(position_arg='bar')
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
tf_func_with_default(named_arg='foo')
tf_func_pos_3args = polymorphic_function.function(func_pos_3args)
with self.assertRaisesRegex(TypeError, 'missing a required argument'):
tf_func_pos_3args(position_arg2='foo')
def testShapeInferencePropagateConstNestedStack(self):
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec((None, None), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
old_shape = array_ops.shape(x)
new_shape = array_ops_stack.stack([old_shape[0], s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedUnstackStack(self):
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec((None, None), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
s0, _ = array_ops_stack.unstack(array_ops.shape(x), axis=0)
new_shape = array_ops_stack.stack([s0, s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedConcat(self):
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec((), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@polymorphic_function.function()
def g():
y = f(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
def testShapeInferencePropagateConstDoubleNested(self):
@polymorphic_function.function(input_signature=[
tensor_lib.TensorSpec((), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
tensor_lib.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@polymorphic_function.function()
def g():
y = polymorphic_function.function(f)(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
@test_util.run_v2_only
def testControlDependencyAfterInline(self):
v = variables.Variable(0.)
@polymorphic_function.function
def assign():
return v.assign(1.)
@polymorphic_function.function
def assign_add():
return v.assign_add(1.)
@polymorphic_function.function
def f():
check_ops.assert_equal_v2(assign(), 1.)
check_ops.assert_equal_v2(assign_add(), 2.)
# We don't have a way to inspect the inlined graph in Python, so we run it
# multiple times to have more confidence the dependency is correct.
for _ in range(30):
f()
@test_util.run_v2_only
def testReadInFuncWriteOutside(self):
# Run many times since we are testing for a potential race condition.
for _ in range(30):
# pylint: disable=cell-var-from-loop
v = variables.Variable(1.)
@polymorphic_function.function
def add_one():
return v + 1.
@polymorphic_function.function
def get_v_plus_one():
v_plus_one = add_one()
v.assign_add(2.0)
return v_plus_one
self.assertAllEqual(get_v_plus_one(), 2.0)
def testOpExpandErrorMessage(self):
@polymorphic_function.function
def test_fn():
if array_ops.constant(False):
return array_ops.constant(1)
else:
return script_ops.eager_py_func(
func=lambda: array_ops.constant([2.]), inp=(), Tout=dtypes.int32)
error_pattern = re.compile(r'Graph execution error.*test_fn', re.DOTALL)
with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern):
test_fn()
def testNoVariables(self):
@polymorphic_function.function
def fn(x):
return 2 * x
self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
def testFailIfVariablesAreCreatedMoreThanOnce(self):
@polymorphic_function.function
def fn(x):
return variables.Variable(1.0) + x
with self.assertRaises(ValueError):
fn(1.0)
def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
state = []
@polymorphic_function.function
def fn(x):
state.append(variables.Variable(1.0))
return state[-1] + x
with self.assertRaises(ValueError):
fn(1.0)
def testRange(self):
@polymorphic_function.function
def f(unused_x):
return 1.0
self.assertAllEqual(f(range(5)), 1.0)
def testCorrectVariableCreation(self):
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0))
return state[0] * x
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
def testFunctionInitializer(self):
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(lambda: 2.0))
return state[0] * x
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
def testFunctionMultipleVariableInitializer(self):
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(lambda: 2.0))
state.append(variables.Variable(lambda: 5.0))
return state[0] * x, state[1] * x
self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
def testFunctionInitializationFunction(self):
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0))
return state[0] * x
init_fn = fn.get_initialization_function(constant_op.constant(1.0))
self.assertLen(state, 1)
self.assertFalse(
resource_variable_ops.var_is_initialized_op(state[0].handle))
init_fn()
self.assertEqual(state[0].numpy(), 2.0)
def testVariableInitializerNotConstant(self):
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0 * x))
return state[0] * x
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
def testLegacyGraphModeVariables(self):
with ops.Graph().as_default(), self.test_session() as sess:
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0))
return state[0] * x
result = fn(3.0)
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 2.0)
self.assertAllEqual(self.evaluate(result), 6.0)
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
with ops.Graph().as_default(), self.test_session() as sess:
state = []
@polymorphic_function.function
def fn(x):
if not state:
two = constant_op.constant(2.0)
four = two * two
two_again = math_ops.sqrt(four)
state.append(variables.Variable(two_again + four))
return state[0] * x
result = fn(3.0)
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 6.0)
self.assertAllEqual(self.evaluate(result), 18.0)
def testLegacyGraphModeInputDependentInitializerFails(self):
with ops.Graph().as_default():
state = []
@polymorphic_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0 * x))
return state[0] * x
with self.assertRaisesRegex(lift_to_graph.UnliftableError,
r'transitively.* mul .* x'):
fn(constant_op.constant(3.0))
def testMethod(self):
class MyModel:
def __init__(self):
self.var = None
@polymorphic_function.function
def apply(self, x):
if self.var is None:
self.var = variables.Variable(2.0)
return self.var * x
m0 = MyModel()
self.assertAllEqual(m0.apply(3.0), 6.0)
# Calling twice to exercise that we do not recreate variables.
m0.var.assign(3.0)
self.assertAllEqual(m0.apply(3.0), 9.0)
m1 = MyModel()
self.assertAllEqual(m1.apply(3.0), 6.0)
def testMethodExtensionType(self):
class MaskedTensorExtensionType(extension_type.ExtensionType):
values: tensor_lib.Tensor
mask: tensor_lib.Tensor
@polymorphic_function.function
def with_default(self, default_value):
return array_ops.where_v2(self.mask, self.values, default_value)
@polymorphic_function.function
def sum(self):
# Use a loop & conditional to test that autograph works correctly.
result = 0
for i in range(array_ops.size(self.values)):
if self.mask[i]:
result += self.values[i]
return result
mt = MaskedTensorExtensionType([1, 2, 3], [True, False, True])
self.assertAllEqual(mt.with_default(-1), [1, -1, 3])
self.assertAllEqual(mt.sum(), 4)
def test_functools_partial(self):
self.assertAllClose(
3.,
polymorphic_function.function(
functools.partial(lambda x, y: x + y,
1.))(constant_op.constant(2.)))
def test_functools_partial_new_default(self):
def f(x=3, y=7):
return x + y
func = polymorphic_function.function(functools.partial(f, y=6))
self.assertEqual(func().numpy(), 9)
self.assertEqual(func(y=8).numpy(), 11)
def test_functools_partial_keywords(self):
def f(x, y):
return x + y
func = polymorphic_function.function(
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
self.assertAllEqual(func(), [0.0])
def test_functools_partial_single_positional(self):
def f(x, y):
return x + y
func = polymorphic_function.function(
functools.partial(f, constant_op.constant(1)))
self.assertAllEqual(func(5), 6)
def test_complicated_partial_with_defaults(self):
def identity(*args):
return args
def dynamic_unroll(core_fn,
input_sequence,
initial_state,
sequence_length=None,
parallel_iterations=1,
swap_memory=False):
del core_fn
self.assertIs(None, sequence_length)
self.assertEqual(1, parallel_iterations)
self.assertTrue(swap_memory)
return input_sequence, initial_state
input_sequence = random_ops.random_uniform([1, 1, 1])
initial_state = random_ops.random_uniform([1, 1])
func = polymorphic_function.function(
functools.partial(dynamic_unroll, identity, swap_memory=True))
func(input_sequence, initial_state)
def test_unspecified_default_argument(self):
wrapped = polymorphic_function.function(
lambda x, y=2: x + y,
input_signature=[tensor_lib.TensorSpec((), dtypes.int32)])
self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())
def test_concrete_function_from_signature(self):
@polymorphic_function.function(
input_signature=[tensor_lib.TensorSpec(None, dtypes.float32)])
def compute(x):
return 2. * x
concrete = compute.get_concrete_function()
self.assertAllClose(1., concrete(constant_op.constant(0.5)))
concrete = compute.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32))
self.assertAllClose(4., concrete(constant_op.constant(2.)))
signature_args, _ = concrete.structured_input_signature
self.assertEqual(signature_args,
(tensor_lib.TensorSpec(
None, dtypes.float32, name='x'),))
def testInputSignatureMissingTensorSpecsMethod(self):
class MyModule(module.Module):
def f1(self, arg1, arg2, arg3):
pass
def f2(self, arg1, arg2, arg3, **kwargs):
pass
def f3(self, arg1, arg2, arg3, arg4=4, **kwargs):
pass
def f4(self, arg1, arg2, arg3, *args):
pass
def f5(self, arg1, arg2, arg3, *args, **kwargs):
pass
def f6(self, arg1, arg4=4, **kwargs):
return arg1 + arg4
m = MyModule()
tf_func_dec = polymorphic_function.function(
input_signature=(tensor_lib.TensorSpec([], dtypes.int32),))
error_message = 'input_signature missing type constraint'
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(m.f1)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(m.f2)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(m.f3)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(m.f4)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(m.f5)(1, 2, 3)
self.assertEqual(tf_func_dec(m.f6)(1).numpy(), 5)
def testInputSignatureMissingTensorSpecsFunction(self):
tf_func_dec = polymorphic_function.function(
input_signature=(tensor_lib.TensorSpec([], dtypes.int32),))
error_message = 'input_signature missing type constraint'
# pylint: disable=unused-argument
def f1(arg1, arg2, arg3):
pass
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(f1)(1, 2, 3)
def f2(arg1, arg2, arg3, **kwargs):
pass
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(f2)(1, 2, 3)
def f3(arg1, arg2, arg3, arg4=4, **kwargs):
pass
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(f3)(1, 2, 3)
def f4(arg1, arg2, arg3, *args):
pass
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(f4)(1, 2, 3)
def f5(arg1, arg2, arg3, *args, **kwargs):
pass
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(f5)(1, 2, 3)
# pyline: enable=unused-argument
def f6(arg1, arg4=4, **kwargs):
return arg1 + arg4
self.assertEqual(tf_func_dec(f6)(1).numpy(), 5)
def testInputSignatureMissingTensorSpecsLambdaFunction(self):
tf_func_dec = polymorphic_function.function(
input_signature=(tensor_lib.TensorSpec([], dtypes.int32),))
error_message = 'input_signature missing type constraint'
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(lambda ar1, arg2, arg3: None)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(lambda arg1, arg2, arg3, **kwargs: None)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(lambda arg1, arg2, arg3, arg4=4, **kwargs: None)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(lambda arg1, arg2, arg3, *args: None)(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(lambda arg1, arg2, arg3, *args, **kwargs: None)(1, 2, 3)
self.assertEqual(
tf_func_dec(lambda arg1, arg4=4, **kwargs: arg1 + arg4)(1).numpy(), 5)
@parameterized.named_parameters(('_method', 'method'),
('_function', 'function'),
('_lambda_function', 'lambda_function'))
def testInputSignaturePartialFuncMissingTensorSpecs(self, func_type):
if func_type == 'method':
class MyModule(module.Module):
def f(self, arg1, arg2, arg3, arg4=4):
return arg1 + arg2 + arg3 + arg4
f = MyModule().f
elif func_type == 'function':
def f(arg1, arg2, arg3, arg4=4):
return arg1 + arg2 + arg3 + arg4
else: # lambda_function
f = lambda arg1, arg2, arg3, arg4=4: arg1 + arg2 + arg3 + arg4
error_message = 'input_signature missing type constraint'
tf_func_dec = polymorphic_function.function(
input_signature=(tensor_lib.TensorSpec([], dtypes.int32),)
)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(functools.partial(f, 1))(2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(functools.partial(f, arg4=5))(1, 2, 3)
with self.assertRaisesRegex(TypeError, error_message):
tf_func_dec(functools.partial(f, 1, arg4=5))(2, 3)
self.assertAllEqual(
tf_func_dec(functools.partial(f, 1, 2, arg4=5))(3),
array_ops.constant(11),
)
@test_util.run_in_graph_and_eager_modes
def test_variable_naming(self):
class HasVars(module.Module):
def __init__(self):
self.x = None
self.y = None
self.z = None
@polymorphic_function.function
def make_x(self):
if self.x is None:
self.x = variables.Variable(1., name='v')
def make_y(self):
if self.y is None:
self.y = variables.Variable(1., name='v')
def make_z(self):
if self.z is None:
with ops.name_scope('z_scope', skip_on_eager=False):
self.z = variables.Variable(1., name='z')
root = HasVars()
root.make_x()
root.make_y()
root.make_z()
self.assertEqual('v:0', root.x.name)
self.assertEqual('z_scope/z:0', root.z.name)
def test_concrete_function_keyword_arguments(self):
@polymorphic_function.function
def f(x):
return x
conc = f.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32, 'y'))
conc(y=constant_op.constant(3.0))
signature_args, _ = conc.structured_input_signature
self.assertEqual('y', signature_args[0].name)
# If name is not specified, the previously named one will be returned.
conc = f.get_concrete_function(tensor_lib.TensorSpec(None, dtypes.float32))
conc(x=constant_op.constant(3.0))
signature_args, _ = conc.structured_input_signature
self.assertEqual('y', signature_args[0].name)
# New name will return updated signature.
conc = f.get_concrete_function(
tensor_lib.TensorSpec(None, dtypes.float32, 'z')
)
conc(x=constant_op.constant(3.0))
signature_args, _ = conc.structured_input_signature
self.assertEqual('z', signature_args[0].name)
@polymorphic_function.function
def g(x):
return x[0]
conc = g.get_concrete_function(
[tensor_lib.TensorSpec(None, dtypes.float32, 'z'), 2])
conc(z=constant_op.constant(3.0))
signature_args, _ = conc.structured_input_signature
self.assertEqual('z', signature_args[0][0].name)
def testRuntimeErrorNotSticky(self):
@polymorphic_function.function
def fail(i):
control_flow_assert.Assert(math_ops.equal(i, 0), ['ick'])
fail(constant_op.constant(0)) # OK
with self.assertRaises(errors.InvalidArgumentError):
fail(constant_op.constant(1)) # InvalidArgument: "ick"
fail(constant_op.constant(0)) # OK
def testUnderscoreName(self):
@polymorphic_function.function
def f(_):
return _ + _
self.assertAllEqual(2.0, f(constant_op.constant(1.0)))
def test_serialization_signature_cache(self):
@polymorphic_function.function
def f(x, y):
return x, y
f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))
signatures_args = set()
concrete_functions = f._list_all_concrete_functions_for_serialization()
for concrete_function in concrete_functions:
args, kwargs = concrete_function.structured_input_signature
signatures_args.add(args)
self.assertEqual(dict(), kwargs)
self.assertEqual(
signatures_args,
set(((tensor_lib.TensorSpec([1, 2], dtypes.float32, name='x'),
tensor_lib.TensorSpec([1], dtypes.float32, name='y')),
(tensor_lib.TensorSpec([1, 3], dtypes.int32, name='x'),
tensor_lib.TensorSpec([1], dtypes.int32, name='y')))))
@test_util.assert_no_garbage_created
def testFunctionReferenceCycles(self):
fn = polymorphic_function.function(lambda x: 2. * x)
fn(constant_op.constant(4.0))
weak_fn = weakref.ref(fn)
del fn
# Tests that the weak reference we made to the function is now dead, which
# means the object has been deleted. This should be true as long as the
# function itself is not involved in a reference cycle.
self.assertIs(None, weak_fn())
@test_util.assert_no_garbage_created
def testMethodReferenceCycles(self):
has_decorated_method = _HasDecoratedMethod()
has_decorated_method.f(constant_op.constant(5.))
weak_fn = weakref.ref(has_decorated_method.f)
del has_decorated_method
# Tests that the weak reference we made to the function is now dead, which
# means the object has been deleted. This should be true as long as the
# function itself is not involved in a reference cycle.
self.assertIs(None, weak_fn())
@test_util.assert_no_new_pyobjects_executing_eagerly()
def testErrorMessageWhenGraphTensorIsPassedToEager(self):
@polymorphic_function.function
def failing_function():
a = constant_op.constant(1.)
with ops.init_scope():
_ = a + a
with self.assertRaisesRegex(
TypeError,
re.compile('polymorphic_function_test.*out of scope', re.DOTALL)):
failing_function()
def testSymbolicTensorIllegalCaptureCallTimeError(self):
x = None
@polymorphic_function.function
def f1(a):
nonlocal x
x = a
return a
@polymorphic_function.function
def f2(b):
return b + x
f1(constant_op.constant(1))
with self.assertRaisesRegex(
TypeError,
re.compile('polymorphic_function_test.*out of scope', re.DOTALL)):
f2(constant_op.constant(2))
def testSymbolicTensorIllegalCaptureTraceTimeError(self):
@polymorphic_function.function
def f(inputs):
num_steps, _ = inputs.shape[:2]
outputs = []
for t in math_ops.range(num_steps):
outputs.append(inputs[t])
return outputs
with self.assertRaisesRegex(errors.InaccessibleTensorError, 'out of scope'):
f(array_ops.zeros(shape=(8, 42, 3)))
def testNonUniqueNamesGetConcreteFunction(self):
@polymorphic_function.function
def non_unique_arg_names(x, **kwargs):
a, b, c = x
d = kwargs['d']
return a + b + c + d
concrete = non_unique_arg_names.get_concrete_function(
(tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32),
tensor_lib.TensorSpec(None, dtypes.float32)),
d=tensor_lib.TensorSpec(None, dtypes.float32))
self.assertAllClose(
10.,
concrete(x=constant_op.constant(1.),
x_1=constant_op.constant(2.),
x_2=constant_op.constant(3.),
d=constant_op.constant(4.)))
self.assertAllClose(
10.,
concrete(constant_op.constant(1.),
constant_op.constant(2.),
constant_op.constant(3.),
constant_op.constant(4.)))
def testDuplicatedSanitizedNames(self):
@polymorphic_function.function
def foo(**kwargs):
return kwargs['a_b'] + kwargs['a/b']
error_message = 'Name collision after sanitization.'
with self.assertRaisesRegex(ValueError, error_message):
foo(**{'a_b': 1, 'a/b': 2})
def testVariableCreatorScope(self):
created_variables = []
captured_variables = []
@polymorphic_function.function
def f():
if not created_variables:
created_variables.append(variables.Variable(1.))
return created_variables[0] + 1.
def capture_creator(next_creator, **kwargs):
created = next_creator(**kwargs)
captured_variables.append(created)
return created
with variable_scope.variable_creator_scope(capture_creator):
f()
self.assertEqual(created_variables, captured_variables)
def testVarAlreadyInitializedNoClobbering(self):
v_holder = []
@polymorphic_function.function
def add_var(x):
if not v_holder:
v = variables.Variable([1., 2.])
v_holder.append(v)
already_initialized = variables.Variable(3.)
with ops.init_scope():
already_initialized.assign(10.)
v_holder.append(already_initialized)
return v_holder[0] + v_holder[1] + x
add_var.get_concrete_function(constant_op.constant(2.))
self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
def testSameVariableTwice(self):
v = variables.Variable(1.0)
@polymorphic_function.function
def add(a, b):
return a + b
self.assertAllEqual(add(v, v), 2.0)
def testSameVariableTwiceWithReducedRetracing(self):
v = variables.Variable(2.0)
@polymorphic_function.function(reduce_retracing=True)
def add(a, b):
return a + b
self.assertAllEqual(add(v, v), 4.0)
def testVariableUpdate(self):
v1 = variables.Variable(1.0)
v2 = variables.Variable(2.0)
v3 = variables.Variable(4, dtype=dtypes.int32)
trace_count = [0]
@polymorphic_function.function
def double_variable(x):
trace_count[0] += 1
x.assign_add(x.read_value())
self.assertEqual(trace_count[0], 0)
double_variable(v1)
self.assertEqual(trace_count[0], 1)
self.assertEqual(self.evaluate(v1), 2.0)
double_variable(v2)
# No retracing because v2's data type and shape are the same as v1
self.assertEqual(trace_count[0], 1)
self.assertEqual(self.evaluate(v2), 4.0)
double_variable(v3)
# Retracing because of data type change
self.assertEqual(trace_count[0], 2)
self.assertEqual(self.evaluate(v3), 8)
def testShapeCache(self):
@polymorphic_function.function
def func(x):
return 2 * x
func_a = func.get_concrete_function(
tensor_lib.TensorSpec([None], dtypes.int32))
func_b = func.get_concrete_function(
tensor_lib.TensorSpec([None], dtypes.int32))
self.assertIs(func_a, func_b)
def testCacheWithinSaveContext(self):
@polymorphic_function.function
def func(x):
return 2 * x
func_a = func.get_concrete_function(constant_op.constant(2.))
func_b = func.get_concrete_function(constant_op.constant(2.))
self.assertIs(func_a, func_b)
with save_context.save_context(
save_options.SaveOptions(experimental_variable_policy=save_options
.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)):
func_c = func.get_concrete_function(constant_op.constant(2.))
with save_context.save_context(
save_options.SaveOptions(
experimental_variable_policy=save_options.VariablePolicy.NONE)):
func_d = func.get_concrete_function(constant_op.constant(2.))
self.assertIsNot(func_a, func_c)
self.assertIsNot(func_a, func_d)
def testInitializationInNestedCall(self):
v_holder = []
@polymorphic_function.function
def add_var(x):
if not v_holder:
v = variables.Variable([1., 2.])
v_holder.append(v)
already_initialized = variables.Variable(3.)
with ops.init_scope():
already_initialized.assign(10.)
v_holder.append(already_initialized)
return v_holder[0] + v_holder[1] + x
@polymorphic_function.function
def wrapper(x):
return add_var(x)
self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
v_holder[1].assign(11.)
self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
@test_util.run_gpu_only
def testDeviceAnnotationRespected(self):
a = []
@polymorphic_function.function()
def create_variable():
with ops.init_scope():
initial_value = random_ops.random_uniform(
(2, 2), maxval=1000000, dtype=dtypes.int64)
if not a:
with ops.device('CPU:0'):
a.append(resource_variable_ops.ResourceVariable(initial_value))
return a[0].read_value()
create_variable()
self.assertRegex(a[0].device, 'CPU')
@test_util.run_gpu_only
def testDeviceAnnotationForInitializerRespected(self):
a = []
initial_value = []
def initial_value_fn():
initial_value.append(random_ops.random_uniform((2, 3)))
return initial_value[0]
@polymorphic_function.function()
def create_variable():
with ops.init_scope():
if not a:
a.append(variables.Variable(initial_value_fn))
with ops.device('CPU:0'):
create_variable()
self.assertRegex(a[0].device, 'CPU')
self.assertRegex(initial_value[0].device, 'CPU')
def testDecorate(self):
func = polymorphic_function.function(lambda: 1)
def decorator(f):
return lambda: 1 + f()
func._decorate(decorator)
self.assertEqual(func().numpy(), 2)
@parameterized.parameters(*itertools.product(
(None, (tensor_lib.TensorSpec([]),)), # input_signature
(True, False), # autograph
(None, converter.Feature.ALL), # autograph_options
(None, 'foo.bar'), # implements
(None, True, False), # relax_shapes
(True, False), # compile
(True, False), # override_function
))
def testClone(self, input_signature, autograph, autograph_options, implements,
relax_shapes, compile_, override_function):
original_py_function = lambda x: x
compile_ = False
func = polymorphic_function.function(
func=original_py_function,
input_signature=input_signature,
autograph=autograph,
experimental_implements=implements,
experimental_autograph_options=autograph_options,
reduce_retracing=relax_shapes,
jit_compile=compile_)
if override_function:
cloned_py_function = lambda x: x + 1
else:
cloned_py_function = original_py_function
cloned = func._clone(python_function=cloned_py_function)
self.assertEqual(cloned_py_function, cloned._python_function)
self.assertEqual(func._name, cloned._name)
self.assertEqual(input_signature, cloned.input_signature)
self.assertEqual(autograph, cloned._autograph)
self.assertEqual(func._attributes, cloned._attributes)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned._reduce_retracing)
self.assertEqual(compile_, cloned._jit_compile)
# This test does not run with XLA JIT support linked in so we can only check
# the output of the function if compile is disabled.
if not compile_:
x = array_ops.zeros([])
self.assertEqual(self.evaluate(cloned(x)),
self.evaluate(cloned_py_function(x)))
def testLiftPlaceholderInitializedVariable(self):
with ops.Graph().as_default():
var_list = []
@polymorphic_function.function
def use_variable():
if not var_list:
initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
v = variables.Variable(initial_value)
var_list.append(v)
return var_list[0] + 1.
var_plus_one = use_variable()
with self.session() as session:
init_op = var_list[0].initializer
session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
self.assertEqual(3., session.run(var_plus_one))
def testDecorate_rejectedAfterTrace(self):
func = polymorphic_function.function(lambda: 1)
self.assertEqual(func().numpy(), 1)
msg = 'Functions cannot be decorated after they have been traced.'
with self.assertRaisesRegex(ValueError, msg):
func._decorate(lambda f: f)
def testGetConcreteFunctionGraphLifetime(self):
@polymorphic_function.function
def func():
pass
graph = func.get_concrete_function().graph
del func
# If the graph is deleted, then an exception is raised on reading `captures`
self.assertEmpty(graph.captures)
@parameterized.parameters(*itertools.product(
(None, (tensor_lib.TensorSpec([]),)), # input_signature
(True, False), # autograph
(None, converter.Feature.ALL), # autograph_options
(None, 'foo.bar'), # implements
(None, True, False), # relax_shapes
))
def test_pickle(self, input_signature, autograph, autograph_options,
implements, relax_shapes):
"""@function objects can be pickled and unpickled."""
original_py_function = undecorated_function
func = polymorphic_function.function(
func=original_py_function,
input_signature=input_signature,
autograph=autograph,
experimental_implements=implements,
experimental_autograph_options=autograph_options,
reduce_retracing=relax_shapes,
)
cloned = pickle.loads(pickle.dumps(func))
self.assertEqual(func._name, cloned._name)
self.assertEqual(input_signature, cloned.input_signature)
self.assertEqual(autograph, cloned._autograph)
self.assertEqual(func._attributes, cloned._attributes)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned._reduce_retracing)
x = array_ops.ones([])
self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))
def test_frequent_retracing_warning(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
@polymorphic_function.function
def f(x):
return x
with self.assertLogs(level='WARN') as logs:
f(1)
f(2)
f(3)
f(4)
self.assertEmpty(logs.output)
f(5)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_frequent_retracing_warning_lambda(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
f = polymorphic_function.function(lambda x: x)
with self.assertLogs(level='WARN') as logs:
f(1)
f(2)
f(3)
f(4)
f(5)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_frequent_retracing_warning_method(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
class Foo:
@polymorphic_function.function
def f(self, x):
return x
f = Foo().f
with self.assertLogs(level='WARN') as logs:
f(1)
f(2)
f(3)
f(4)
f(5)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_frequent_retracing_warning_two_independent_tf_functions(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
@polymorphic_function.function
def f(x):
return x
@polymorphic_function.function
def g(x):
return x
with self.assertLogs(level='WARN') as logs:
f(1)
f(2)
f(3)
f(4)
g(1)
g(2)
g(3)
g(4)
g(5)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_frequent_retracing_warning_nested(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
@polymorphic_function.function
def inner(x):
return x + 1
@polymorphic_function.function
def outer1(x):
return inner(x) * 2
@polymorphic_function.function
def outer2(x):
return inner(x) * 3
with self.assertLogs(level='WARN') as logs:
inner(1)
inner(2)
inner(3)
inner(4)
outer1(5)
outer1(6)
outer1(7)
outer1(8)
outer2(9)
outer2(10)
outer2(11)
outer2(12)
self.assertEmpty(logs.output)
outer2(13)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_frequent_retracing_warning_on_reinstantiation(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
with self.assertLogs(level='WARN') as logs:
for i in range(5):
@polymorphic_function.function
def f(x):
return x
f(i)
if i < 4:
self.assertEmpty(logs.output)
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_restored_function_retracing_warning(self):
class Foo(Checkpoint):
@polymorphic_function.function
def __call__(self, x):
return x
f_flexible = Foo()
_ = f_flexible.__call__.get_concrete_function(
tensor_lib.TensorSpec(shape=[None], dtype=dtypes.int32))
tmp_dir = self.create_tempdir()
save(f_flexible, tmp_dir.full_path)
restored_f_flexible = load(tmp_dir.full_path)
f_fixed_shape = Foo()
with self.assertLogs(level='WARN') as logs:
restored_f_flexible(constant_op.constant([1], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3, 4], dtypes.int32))
restored_f_flexible(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
self.assertEmpty(logs.output)
f_fixed_shape(constant_op.constant([1], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3, 4], dtypes.int32))
f_fixed_shape(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
def test_retracing_warning_limits(self):
@polymorphic_function.function
def my_func(x):
return x
with self.assertLogs(level='WARN') as logs:
for i in range(10):
my_func(i)
self.assertLen(logs.output, 2)
def test_experimental_get_tracing_count_function(self):
@polymorphic_function.function
def double(a):
return a + a
double(constant_op.constant(1))
double(constant_op.constant(2))
self.assertAllEqual(double.experimental_get_tracing_count(), 1)
double(constant_op.constant('a'))
self.assertAllEqual(double.experimental_get_tracing_count(), 2)
def test_experimental_get_tracing_count_method(self):
class TestClass():
@polymorphic_function.function
def testDouble(self, a):
return a + a
obj1 = TestClass()
obj1.testDouble(constant_op.constant(1))
obj1.testDouble(constant_op.constant(2))
obj1.testDouble(constant_op.constant(1.1))
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
obj2 = TestClass()
obj2.testDouble(constant_op.constant(1))
obj2.testDouble(constant_op.constant(1.1))
obj2.testDouble(constant_op.constant('a'))
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def test_tensor_shape_casted_to_specific(self):
@polymorphic_function.function(
input_signature=[tensor_lib.TensorSpec([1])]
)
def specific(x):
self.assertEqual(x.shape, [1])
return x
@polymorphic_function.function(
input_signature=[tensor_lib.TensorSpec(None)]
)
def general(x):
return specific(x)
self.assertEqual(general(constant_op.constant([1.0])).numpy(), 1.0)
def test_recursive_tf_function(self):
@polymorphic_function.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
return 1
self.assertEqual(recursive_fn(5).numpy(), 1)
def test_recursive_tf_function_with_gradients(self):
@polymorphic_function.function
def recursive_fn(n, x):
if n > 0:
return n * recursive_fn(n - 1, x)
else:
return x
x = variables.Variable(1.0)
with backprop.GradientTape() as tape:
g = recursive_fn(5, x)
dg_dx = tape.gradient(g, x)
self.assertEqual(dg_dx.numpy(), 120)
def test_recursive_python_function(self):
def recursive_py_fn(n):
if n > 0:
return recursive_py_fn(n - 1)
return 1
@polymorphic_function.function
def recursive_fn(n):
return recursive_py_fn(n)
self.assertEqual(recursive_fn(5).numpy(), 1)
def test_recursive_python_function_with_gradients(self):
def recursive_py_fn(n, x):
if n > 0:
return n * recursive_py_fn(n - 1, x)
return x
@polymorphic_function.function
def recursive_fn(n, x):
return recursive_py_fn(n, x)
x = variables.Variable(1.0)
with backprop.GradientTape() as tape:
g = recursive_fn(5, x)
dg_dx = tape.gradient(g, x)
self.assertEqual(dg_dx.numpy(), 120)
def test_recursive_tf_function_call_each_other(self):
@polymorphic_function.function
def recursive_fn1(n):
if n <= 1:
return 1
return recursive_fn2(n - 1)
@polymorphic_function.function
def recursive_fn2(n):
if n <= 1:
return 2
return recursive_fn1(n - 1)
self.assertEqual(recursive_fn1(5).numpy(), 1)
self.assertEqual(recursive_fn1(6).numpy(), 2)
self.assertEqual(recursive_fn2(5).numpy(), 2)
self.assertEqual(recursive_fn2(6).numpy(), 1)
def test_recursive_tf_function_call_each_other_with_gradients(self):
@polymorphic_function.function
def recursive_fn1(n, x):
if n <= 1:
return x
return n * recursive_fn2(n - 1, x)
@polymorphic_function.function
def recursive_fn2(n, x):
if n <= 1:
return 2 * x
return n * recursive_fn1(n - 1, x)
x = variables.Variable(1.0)
with backprop.GradientTape() as tape:
g1 = recursive_fn1(5, x)
dg1_dx = tape.gradient(g1, x)
self.assertEqual(dg1_dx.numpy(), 120)
with backprop.GradientTape() as tape:
g2 = recursive_fn2(5, x)
dg2_dx = tape.gradient(g2, x)
self.assertEqual(dg2_dx.numpy(), 240)
def test_recursive_tf_function_with_cond(self):
@polymorphic_function.function(autograph=False)
def recursive_fn(n):
return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)
with self.assertRaises(RecursionError):
recursive_fn(constant_op.constant(5))
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
def testNestedCallWatchedVariables(self):
v = variables.Variable(4.)
@polymorphic_function.function
def f():
return v**2.
with backprop.GradientTape() as tape:
f()
self.assertEqual((v,), tape.watched_variables())
@polymorphic_function.function
def g():
return f()
with backprop.GradientTape() as tape:
g()
self.assertEqual((v,), tape.watched_variables())
# f() can rely on the variable being read during its trace. g() checks that
# variables from a function which knows about them are recorded on the
# tape. h() tests that functions forward knowledge of variables to callers.
@polymorphic_function.function
def h():
return g()
with backprop.GradientTape() as tape:
h()
self.assertEqual((v,), tape.watched_variables())
def testReplaceCaptureWithDeferred(self):
x = constant_op.constant(1.0)
y = constant_op.constant(2.0)
z = constant_op.constant(3.0)
@polymorphic_function.function
def fn():
a = x + y
b = a + z
return b
concrete_fn = fn.get_concrete_function()
self.assertAllEqual(concrete_fn(), 6.0)
value = constant_op.constant(4.0)
def closure():
return value
concrete_fn.replace_capture_with_deferred_capture(
concrete_fn.captured_inputs[1],
closure,
spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32),
placeholder=concrete_fn.inputs[1])
self.assertAllEqual(concrete_fn(), 8.0)
value = constant_op.constant(5.0)
self.assertAllEqual(concrete_fn(), 9.0)
def testRaiseReplaceCaptureWithDeferredTypeSpecMismatch(self):
bool_captured_tensor = constant_op.constant(True)
float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
value = constant_op.constant([2.], dtype=dtypes.float32)
@polymorphic_function.function
def fn():
deferred_tensor = ops.get_default_graph().capture_call_time_value(
lambda: value,
tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32))
if bool_captured_tensor:
return deferred_tensor
else:
return deferred_tensor + float_captured_tensor
concrete_fn = fn.get_concrete_function()
self.assertAllEqual(concrete_fn(), [2.])
new_bool_captured_tensor = constant_op.constant(False)
def bool_closure():
return new_bool_captured_tensor
# Test raise if replacing a bool capture with a closure of output type
# float32
new_float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
def float_closure():
return new_float_captured_tensor
with self.assertRaisesRegex(ValueError,
'Attempting to substitute closure with spec*'):
concrete_fn.replace_capture_with_deferred_capture(
bool_captured_tensor,
float_closure,
spec=tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32))
# Test replace without a placeholder
concrete_fn.replace_capture_with_deferred_capture(
bool_captured_tensor,
bool_closure,
spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool))
self.assertAllEqual(concrete_fn(), [5.])
def testConcreteFunctionSetExternalCapture(self):
captured_tensor = constant_op.constant([1.])
value = constant_op.constant([2.])
@polymorphic_function.function
def fn():
deferred_tensor = ops.get_default_graph().capture_call_time_value(
lambda: value,
tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32))
return deferred_tensor + captured_tensor
cf = fn.get_concrete_function()
self.assertLen(cf._captured_inputs, 2)
self.assertEqual(list(map(callable, cf._captured_inputs)), [False, True])
self.assertAllEqual(cf(), [3.])
# Reset capture to a deferred one, reset deferred capture to a capture.
cf.set_external_captures([cf._captured_inputs[1], cf._captured_inputs[0]])
value = constant_op.constant([3.])
self.assertAllEqual(cf(), [4.])
def testGraphReplaceCaptureAndSetExternalCapture(self):
bool_captured_tensor = constant_op.constant(True)
float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
value = constant_op.constant([2.], dtype=dtypes.float32)
@polymorphic_function.function
def fn():
deferred_tensor = ops.get_default_graph().capture_call_time_value(
lambda: value,
tensor_lib.TensorSpec(shape=(1,), dtype=dtypes.float32))
if bool_captured_tensor:
return deferred_tensor
else:
return deferred_tensor + float_captured_tensor
concrete_fn = fn.get_concrete_function()
self.assertAllEqual(concrete_fn(), [2.])
new_bool_captured_tensor = constant_op.constant(False)
def closure():
return new_bool_captured_tensor
concrete_fn.graph.replace_capture_with_deferred_capture(
concrete_fn.captured_inputs[0],
closure,
spec=tensor_lib.TensorSpec(shape=(), dtype=dtypes.bool),
placeholder=concrete_fn.inputs[1])
concrete_fn.set_external_captures([
closure, concrete_fn._captured_inputs[1],
concrete_fn._captured_inputs[2]
])
self.assertAllEqual(concrete_fn(), [5.])
def testDeferredCapture(self):
value = 1.0
@polymorphic_function.function
def lazy_capture(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_lib.TensorSpec(None))
return x + y
self.assertAllEqual(lazy_capture(2.0), 3.0)
# After changing the value of `value` the function call should return a
# different result.
value = 2.0
self.assertAllEqual(lazy_capture(2.0), 4.0)
def testNestedDeferredCapture(self):
value = 1.0
@polymorphic_function.function
def inner(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_lib.TensorSpec(None))
return x + y
@polymorphic_function.function
def outer(x):
return inner(x)
self.assertAllEqual(outer(2.0), 3.0)
# After changing the value of `value` the function call should return a
# different result.
value = 2.0
self.assertAllEqual(outer(2.0), 4.0)
def testNestedDeferredCaptureInTFWhileLoop(self):
value = 1.
@polymorphic_function.function
def inner(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_lib.TensorSpec(None))
return x + y
@polymorphic_function.function
def outer():
dummy = constant_op.constant(True)
sums = constant_op.constant(0.)
while dummy:
directives.set_loop_options(
shape_invariants=[(sums, tensor_shape.TensorShape(None))])
sums += inner(2.)
dummy = constant_op.constant(False)
return sums
self.assertAllEqual(outer(), 3.)
value = constant_op.constant(2.)
self.assertAllEqual(outer(), 4.)
value = constant_op.constant(3.)
self.assertAllEqual(outer(), 5.)
def testDeferredCaptureWithKey(self):
value0 = 1.0
value1 = 2.0
@polymorphic_function.function
def lazy_capture(x):
w = ops.get_default_graph().capture_call_time_value(
lambda: value0, tensor_lib.TensorSpec(None), key=0)
y = ops.get_default_graph().capture_call_time_value(
lambda: value1, tensor_lib.TensorSpec(None), key=1)
def bad_closure():
raise ValueError('Should not run')
z = ops.get_default_graph().capture_call_time_value(
bad_closure, tensor_lib.TensorSpec(None), key=1)
return x + y + w + z
self.assertAllEqual(lazy_capture(2.0), 7.0)
value0 = 2.0
value1 = 3.0
self.assertAllEqual(lazy_capture(2.0), 10.0)
def testDeferredCaptureTypeError(self):
value = constant_op.constant(1.0)
@polymorphic_function.function
def lazy_capture(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_lib.TensorSpec(()))
return x + y
self.assertAllEqual(lazy_capture(2.0), 3.0)
# dtype mismatch
value = constant_op.constant(1)
with self.assertRaisesRegex(TypeError, 'Can not cast Tensor'):
lazy_capture(2.0)
# shape mismatch
value = constant_op.constant([1.0])
with self.assertRaisesRegex(TypeError, 'Can not cast'):
lazy_capture(2.0)
def testDeferredCaptureReturnNestWithCompositeTensor(self):
i_s = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64),
constant_op.constant([2]))
r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
s_t = sparse_tensor.SparseTensor(
values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])
@polymorphic_function.function
def lazy_capture():
y = ops.get_default_graph().capture_call_time_value(
lambda: {'i': i_s, 't': (r_t, s_t)},
{'i': indexed_slices.IndexedSlicesSpec(
dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
return y['i'], y['t']
i, (r, s) = lazy_capture()
self.assertAllEqual(i_s.values, i.values)
self.assertAllEqual(i_s.indices, i.indices)
self.assertAllEqual(i_s.dense_shape, i.dense_shape)
self.assertAllEqual(r_t, r)
self.assertAllEqual(s_t.indices, s.indices)
self.assertAllEqual(s_t.values, s.values)
self.assertAllEqual(s_t.dense_shape, s.dense_shape)
def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64))
@polymorphic_function.function
def lazy_capture():
return ops.get_default_graph().capture_call_time_value(
lambda: value, indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))
# Type matches spec.
lazy_capture()
# Extra dense shape component.
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64),
constant_op.constant([2]))
with self.assertRaises(ValueError):
lazy_capture()
# Index dtype mismatch int32 vs. int64.
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
with self.assertRaises(TypeError):
lazy_capture()
@parameterized.parameters(
(1, int, 2, int, 1),
(1, constant_op.constant, 2, constant_op.constant, 1))
def testRetraceLogicWithSideInputs(self, val_before, type_before, val_after,
type_after, expected_len):
@polymorphic_function.function
def f():
func = lambda: x
return ops.get_default_graph()._experimental_capture_side_input_by_ref( # pylint: disable=protected-access
'lambda: x', func)
x = type_before(val_before)
_ = f()
x = type_after(val_after)
_ = f()
self.assertLen(total_function_cache(f), expected_len)
def testByRefCaptureWithInputSignature(self):
@polymorphic_function.function(input_signature=[])
def f():
func = lambda: x
return ops.get_default_graph()._experimental_capture_side_input_by_ref( # pylint: disable=protected-access
'lambda: x', func)
x = 1
_ = f()
x = 2
_ = f()
self.assertLen(total_function_cache(f), 1)
def testFunctoolsLruCache(self):
self.skipTest(
"b/194845243: inspect.getfullargspec doesn't unwrap Python decorators.")
@polymorphic_function.function
@functools.lru_cache(maxsize=2)
def f(a):
return 2 * a
self.assertAllEqual(f(1), array_ops.constant(2))
def testGraphRemoveFunction(self):
@polymorphic_function.function
def g(x):
return x + 1
@polymorphic_function.function
def f(x):
return g(x)
graph = f.get_concrete_function(constant_op.constant(1)).graph
graph_def = graph.as_graph_def()
func_name = graph_def.library.function[0].signature.name
self.assertLen(graph_def.library.function, 1)
self.assertTrue(graph._is_function(func_name))
graph._remove_function(func_name)
updated_graph_def = graph.as_graph_def()
self.assertEmpty(updated_graph_def.library.function)
self.assertFalse(graph._is_function(func_name))
with self.assertRaisesRegex(ValueError, 'not found'):
graph._remove_function(func_name)
def testInputAndOutputDataclass(self):
@polymorphic_function.function
def f(x):
return x
mt = MaskedTensor(mask=True, value=constant_op.constant([1.0]))
result = f(mt)
self.assertEqual(result.mask, mt.mask)
self.assertAllEqual(result.value, mt.value)
def testInputAndOutputNestedDataclass(self):
@polymorphic_function.function
def f(x):
return x
mt = MaskedTensor(mask=True, value=constant_op.constant([1.0]))
mt2 = MaskedTensor(mask=False, value=constant_op.constant([2.0]))
mtp = MaskedTensorPair(masks=[True, False], value1=mt, value2=mt2)
result = f(mtp)
self.assertEqual(result.masks, mtp.masks)
self.assertEqual(result.value1.mask, mt.mask)
self.assertAllEqual(result.value1.value, mt.value)
self.assertEqual(result.value2.mask, mt2.mask)
self.assertAllEqual(result.value2.value, mt2.value)
def testInputAndCreatNewDataclass(self):
@polymorphic_function.function
def f(x, y):
return MaskedTensor(mask=x.mask, value=y.value)
mt = MaskedTensor(mask=False, value=constant_op.constant([1.0]))
mt2 = MaskedTensor(mask=True, value=constant_op.constant([2.0]))
result = f(mt, mt2)
self.assertEqual(result.mask, mt.mask)
self.assertAllEqual(result.value, mt2.value)
def testDataclassWithUnhashableMetadata(self):
@polymorphic_function.function
def f(x, y):
return MaskedTensorPair(
masks=x.masks + y.masks, value1=x.value1, value2=y.value2
)
mt = MaskedTensor(mask=False, value=constant_op.constant([1.0]))
mt2 = MaskedTensor(mask=True, value=constant_op.constant([2.0]))
mtp = MaskedTensorPair(masks=[True, True], value1=mt, value2=mt2)
mt3 = MaskedTensor(mask=False, value=constant_op.constant([3.0]))
mt4 = MaskedTensor(mask=True, value=constant_op.constant([4.0]))
mtp2 = MaskedTensorPair(masks=[False, False], value1=mt3, value2=mt4)
result = f(mtp, mtp2)
self.assertEqual(result.masks, mtp.masks + mtp2.masks)
self.assertEqual(result.value1.mask, mt.mask)
self.assertAllEqual(result.value1.value, mt.value)
self.assertEqual(result.value2.mask, mt4.mask)
self.assertAllEqual(result.value2.value, mt4.value)
def testDataClassWithSubTraceType(self):
@polymorphic_function.function
def f(x):
return x
mt = MaskedTensor(mask=True, value=constant_op.constant([1.0]))
mt2 = MaskedTensor(mask=True, value=constant_op.constant([2.0]))
f1 = f.get_concrete_function(mt)
f2 = f.get_concrete_function(mt2)
# mt2's TraceType is the same as mt1, so it doesn't need retrace
self.assertIs(f1, f2)
mt3 = MaskedTensor(
mask=False,
value=tensor_lib.TensorSpec(shape=[None, None], dtype=dtypes.int32),
)
f3 = f.get_concrete_function(mt3)
self.assertIsNot(f1, f3)
mt4 = MaskedTensor(
mask=False,
value=constant_op.constant(
[[1], [2]], shape=[2, 1], dtype=dtypes.int32
),
)
f4 = f.get_concrete_function(mt4)
# mt4's TraceType can be matched by mt3's spec, so it doesn't need retrace
self.assertIs(f3, f4)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()