tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import functools
import itertools
import weakref
from absl.testing import parameterized
import numpy
from tensorflow.core.function.capture import capture_container
from tensorflow.core.function.polymorphism import function_cache as function_cache_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager.polymorphic_function import function_type_utils
from tensorflow.python.eager.polymorphic_function import tracing_compilation
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.util import compat
from tensorflow.python.util import nest
try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
def compiled_fn(fn=None, **tracing_options):
  """Wraps `fn` so that calls go through `tracing_compilation`.

  Can be used bare (`@compiled_fn`) or with options
  (`@compiled_fn(reduce_retracing=True)`). When called without `fn`, a
  partial of this decorator carrying the given options is returned.

  Args:
    fn: The Python callable to wrap, or None when only options are supplied.
    **tracing_options: Keyword arguments forwarded to
      `tracing_compilation.TracingOptions`. An `input_signature` entry, if
      present, is consumed here to build the function type.

  Returns:
    A callable that invokes `tracing_compilation.call_function`; its
    `get_concrete_function` attribute invokes
    `tracing_compilation.trace_function`.
  """
  if fn is None:
    return functools.partial(compiled_fn, **tracing_options)

  input_signature = tracing_options.pop('input_signature', None)
  function_type, default_values = function_type_utils.make_function_type(
      fn, input_signature
  )
  tracing_options.update(
      polymorphic_type=function_type, default_values=default_values
  )

  def _make_options():
    # A fresh TracingOptions instance is built for every call/trace.
    return tracing_compilation.TracingOptions(fn, **tracing_options)

  def wrapped(*args, **kwargs):
    bound = function_type.bind_with_defaults(args, kwargs, default_values)
    return tracing_compilation.call_function(
        bound.args, bound.kwargs, _make_options()
    )

  def trace(*args, **kwargs):
    return tracing_compilation.trace_function(args, kwargs, _make_options())

  wrapped.get_concrete_function = trace
  return wrapped
class TracingCompilationTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBackwardNoneGradient(self):
  """Checks the loss gradient when the function also returns an int output."""
  model = variables.Variable(1.0, name='model')
  count = variables.Variable(0)

  @compiled_fn
  def forward_pass(value):
    count.assign_add(1)
    squared_error = math_ops.pow(value - model, 2)
    loss = 0.5 * math_ops.reduce_mean(squared_error)
    # Note: count is an integer, so its output gradient will be None.
    return loss, count

  def reduce_fn(x):
    if not context.executing_eagerly():
      loss, step_count = forward_pass(x)
      return gradients_impl.gradients(loss, model), step_count
    with backprop.GradientTape() as tape:
      loss, step_count = forward_pass(x)
    return tape.gradient(loss, model), step_count

  grads, _ = reduce_fn(constant_op.constant([7.0]))

  self.evaluate(variables.global_variables_initializer())
  self.assertAllEqual(nest.flatten(self.evaluate(grads)), [-6.0])
def testExternalControlDependency(self):
  """A control dependency on an outer-graph op runs when the function runs."""
  with ops.Graph().as_default(), self.test_session():
    counter = variables.Variable(1.0)
    counter.initializer.run()

    increment = counter.assign_add(1.0)

    @compiled_fn
    def depends_on_increment():
      with ops.control_dependencies([increment]):
        return 1.0

    self.evaluate(depends_on_increment())
    self.assertAllEqual(self.evaluate(counter), 2.0)
def testInputShapeFunctionRelaxation(self):
  """Checks shape relaxation of a single tensor arg with reduce_retracing."""
  # Set to True by the traced body when it sees an unknown leading dimension.
  unknown_dim = [False]

  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(reduce_retracing=True, function_cache=function_cache)
  def func(a):
    if a._shape_tuple()[0] is None:
      unknown_dim[0] = True
    return a + 1

  func(constant_op.constant([]))
  self.assertFalse(unknown_dim[0])
  self.assertLen(function_cache, 1)

  func(constant_op.constant([1.0]))
  self.assertTrue(unknown_dim[0])
  self.assertLen(function_cache, 2)

  func(constant_op.constant([1.0, 2.0]))
  self.assertTrue(unknown_dim[0])
  # The relaxed trace must be reused rather than adding a third entry; this
  # mirrors the check in testNestedInputShapeFunctionRelaxation.
  self.assertLen(function_cache, 2)
def testNestedInputShapeFunctionRelaxation(self):
  """Checks shape relaxation for a tensor nested inside a list argument."""
  saw_unknown_dim = [False]

  cache = function_cache_lib.FunctionCache()

  @compiled_fn(reduce_retracing=True, function_cache=cache)
  def func(a_, b_=None):
    del a_  # Only used to check which cache is used.
    self.assertEqual(b_[0]._shape_tuple(), ())
    if b_[1]._shape_tuple()[0] is None:
      saw_unknown_dim[0] = True
    return b_[0] + 1

  scalar = constant_op.constant(1.0)

  def call_with_vector(tag, values):
    func(tag, b_=[scalar, constant_op.constant(values)])

  tag = 'hi'
  call_with_vector(tag, [])
  self.assertFalse(saw_unknown_dim[0])
  self.assertLen(cache, 1)

  call_with_vector(tag, [1.0])
  self.assertTrue(saw_unknown_dim[0])
  self.assertLen(cache, 2)

  call_with_vector(tag, [1.0, 1.0])
  self.assertTrue(saw_unknown_dim[0])
  self.assertLen(cache, 2)

  saw_unknown_dim[0] = False

  # Repeat with a different non-tensor first argument; this should change
  # the cache key.
  tag = 'bye'
  call_with_vector(tag, [])
  self.assertFalse(saw_unknown_dim[0])
  self.assertLen(cache, 3)

  # We relax the type traced previously.
  call_with_vector(tag, [1.0])
  self.assertTrue(saw_unknown_dim[0])
  self.assertLen(cache, 4)
@test_util.run_v2_only
def testGraphEagerIsolation(self):
  """Calls the same variable-creating function eagerly and inside a graph."""

  def f_py():
    self.v = variables.Variable(1.0)
    return self.v.read_value()

  def f():
    options = tracing_compilation.TracingOptions(f_py, 'f')
    return tracing_compilation.call_function(tracing_options=options)

  self.assertAllEqual(f(), 1.0)

  with ops.Graph().as_default():
    self.assertEqual(f().shape, ())
@test_util.run_v2_only
def testCompilationNumpyArraysConvertedToTensors(self):
  """NumPy positional inputs reach the traced function as Tensors."""

  def f(x):
    self.assertIsInstance(x, tensor_lib.Tensor)
    return x

  first_batch = random_ops.random_uniform([2, 2]).numpy()
  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(f, function_cache=cache)
  defined(first_batch)
  self.assertLen(cache, 1)

  second_batch = random_ops.random_uniform([2, 2]).numpy()
  defined(second_batch)
  # A NumPy array with different values but the same shape and dtype
  # shouldn't trigger another function definition.
  self.assertLen(cache, 1)

  np_ones = numpy.ones([], numpy.float32)
  np_zeros = numpy.zeros([], numpy.float32)
  tf_ones = array_ops.ones([])
  tf_zeros = array_ops.zeros([])

  # Test that the numpy array is properly an argument to the graph function.
  self.assertEqual(1.0, defined(np_ones).numpy())
  self.assertLen(cache, 2)
  self.assertEqual(0.0, defined(np_zeros).numpy())
  self.assertEqual(1.0, defined(tf_ones).numpy())
  self.assertEqual(0.0, defined(tf_zeros).numpy())
  self.assertLen(cache, 2)

  # Test that mutable inputs are supported.
  mutable = numpy.ones([], numpy.float32)
  self.assertEqual(1.0, defined(mutable).numpy())
  mutable.fill(0)
  self.assertEqual(0.0, defined(mutable).numpy())

  class MyNdarray(numpy.ndarray):
    pass

  # Test that the subclasses of ndarray are converted too.
  ones_view = np_ones.view(MyNdarray)
  self.assertEqual(1.0, defined(ones_view).numpy())
  zeros_view = np_zeros.view(MyNdarray)
  self.assertEqual(0.0, defined(zeros_view).numpy())

  # We should not have triggered any re-tracing of the python function.
  self.assertLen(cache, 2)
@test_util.run_v2_only
def testNumpyDtypeInputSupported(self):
  """NumPy dtype objects can be passed as plain Python arguments."""

  @compiled_fn
  def f(x, dtype):
    return constant_op.constant(dtype(x))

  # Same call order as listing the four combinations out by hand.
  for np_dtype in (numpy.float32, numpy.int32):
    for value in (1, 2):
      self.assertEqual(f(value, np_dtype).numpy(), np_dtype(value))
@test_util.run_v2_only
def testCompilationNumpyArraysConvertedToTensorsInKwargs(self):
  """NumPy keyword inputs reach the traced function as Tensors."""

  def f(**kwargs):
    x = kwargs.pop('x')
    self.assertIsInstance(x, tensor_lib.Tensor)
    return x

  first_batch = random_ops.random_uniform([2, 2]).numpy()
  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(f, function_cache=cache)
  defined(x=first_batch)
  self.assertLen(cache, 1)

  second_batch = random_ops.random_uniform([2, 2]).numpy()
  defined(x=second_batch)
  # A NumPy array with different values but the same shape and dtype
  # shouldn't trigger another function definition.
  self.assertLen(cache, 1)

  # Test that the numpy array is properly an argument to the graph function.
  np_one = numpy.ones([])
  self.assertEqual(1.0, defined(x=np_one).numpy())
  np_zero = numpy.zeros([])
  self.assertEqual(0.0, defined(x=np_zero).numpy())
  tf_one = array_ops.ones([])
  self.assertEqual(1.0, defined(x=tf_one).numpy())
  tf_zero = array_ops.zeros([])
  self.assertEqual(0.0, defined(x=tf_zero).numpy())
@test_util.run_v2_only
def testFuncListAttr(self):
  """Passes concrete functions of several callable kinds as case branches."""

  @compiled_fn
  def test_function(val):
    # The branches deliberately cover a def, a lambda, a def with a default
    # argument, and two flavours of functools.partial.
    def fn1():
      return array_ops.ones([10])

    fn2 = lambda: array_ops.ones([10]) * 2

    def fn3(x=3):
      return array_ops.ones([10]) * x

    fn4 = functools.partial(fn3, x=4)
    fn5 = functools.partial(fn3, 5)

    branches = [
        compiled_fn(f).get_concrete_function()
        for f in (fn1, fn2, fn3, fn4, fn5)
    ]
    return gen_functional_ops.case(val, [], [dtypes.float32], branches)

  ones = array_ops.ones([10])
  expected_by_index = [
      (0, ones),
      (1, ones * 2),
      (2, ones * 3),
      (3, ones * 4),
      (4, ones * 5),
      (22, ones * 5),  # default branch
  ]
  for branch_index, expected in expected_by_index:
    self.assertAllEqual([expected], test_function(branch_index))
@test_util.enable_control_flow_v2
def testVariableInLoopInFunction(self):
  """Traces a function whose while-loop body calls get_variable."""

  def test_function_py():
    def never_continue(_):
      return False

    def body(_):
      return variable_scope.get_variable('a', shape=())

    return while_loop.while_loop(never_continue, body, [0.0])

  options = tracing_compilation.TracingOptions(
      test_function_py, 'test_function'
  )
  test_function = tracing_compilation.trace_function(tracing_options=options)

  self.assertEqual(test_function().shape, [])
@test_util.run_in_graph_and_eager_modes
def testCompilationForcesResourceVariables(self):
  """A variable created while tracing is a ResourceVariable."""

  def variable_creator():
    self.v = variables.Variable(0.0)
    return self.v.read_value()

  options = tracing_compilation.TracingOptions(
      variable_creator, 'variable_creator'
  )
  defined = tracing_compilation.trace_function(tracing_options=options)
  defined()  # Create the variable.
  self.assertIsInstance(self.v, resource_variable_ops.ResourceVariable)
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testFunctionWithResourcesOnDifferentDevices(self):
  """Compiled and plain results agree when variables live on CPU and GPU."""
  with ops.device('/cpu:0'):
    v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

  with ops.device('/gpu:0'):
    v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

  def sum_gather():
    # CPU result first, then GPU result.
    return tuple(
        math_ops.reduce_sum(array_ops.gather(var, [1, 2]))
        for var in (v_cpu, v_gpu)
    )

  defined = compiled_fn(sum_gather)
  if not context.executing_eagerly():
    self.evaluate(variables.global_variables_initializer())
  expected = self.evaluate(sum_gather())
  actual = self.evaluate(defined())
  self.assertAllEqual(expected, actual)
@test_util.assert_no_new_pyobjects_executing_eagerly()
def testCallOptionsMemory(self):
  """Calls a compiled function after resetting function_call_options."""

  @compiled_fn
  def model(x):
    return x + constant_op.constant(1.0)

  # This happens with a lot of option toggles, e.g. soft device placement
  ctx = context.context()
  ctx.function_call_options = None
  model(constant_op.constant(2.0))
@test_util.run_in_graph_and_eager_modes
def testVariablesPlacedOnOutsideDevice(self):
  """A variable created in a compiled method picks up the caller's device."""

  class _Obj(object):

    def __init__(self):
      self.v = None

    @compiled_fn
    def f(self):
      # Create the variable lazily, on the first call only.
      if self.v is None:
        self.v = variables.Variable(1.0)
      return self.v + 1.0

  holder = _Obj()
  with ops.device('cpu:0'):
    holder.f()
  self.assertIn('CPU', holder.v.device)
def testCacheObjectHashCollisions(self):
  """Distinct objects with the same hash get distinct cache entries."""

  class Foo:

    def __hash__(self):
      return 42

  def func(foo):
    return constant_op.constant([id(foo)])

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(func, function_cache=cache)

  # Each new Foo instance is expected to add one cache entry.
  for expected_size in (1, 2):
    defined(Foo())
    self.assertLen(cache, expected_size)
def testCacheTensorDtypeCollision(self):
  """Tensors differing only in dtype produce separate cache entries."""

  def func(t):
    return t + t

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(func, function_cache=cache)

  complex64_input = constant_op.constant([[1.0]], dtype=dtypes.complex64)
  defined(complex64_input)
  self.assertLen(cache, 1)

  complex128_input = constant_op.constant([[1.0]], dtype=dtypes.complex128)
  defined(complex128_input)
  self.assertLen(cache, 2)
def testCacheTensorShapeCollision(self):
  """Tensors differing only in shape produce separate cache entries."""

  def func(t):
    return t + t

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(func, function_cache=cache)

  rank2_input = constant_op.constant([[1.0]], dtype=dtypes.complex64)
  defined(rank2_input)
  self.assertLen(cache, 1)

  rank1_input = constant_op.constant([1.0], dtype=dtypes.complex64)
  defined(rank1_input)
  self.assertLen(cache, 2)
def testCacheTensorShapeDtypeCollision(self):
  """Tensors differing in shape and dtype produce separate cache entries."""

  def func(t):
    return t + t

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(func, function_cache=cache)

  rank2_complex64 = constant_op.constant([[1.0]], dtype=dtypes.complex64)
  defined(rank2_complex64)
  self.assertLen(cache, 1)

  rank1_complex128 = constant_op.constant([1.0], dtype=dtypes.complex128)
  defined(rank1_complex128)
  self.assertLen(cache, 2)
def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):
  """Checks cache growth under reduce_retracing with graph placeholders."""

  def func(t):
    return t + t

  with context.graph_mode(), self.cached_session():
    cache = function_cache_lib.FunctionCache()
    defined = compiled_fn(func, reduce_retracing=True, function_cache=cache)

    scalar_ph = array_ops.placeholder(dtype=dtypes.float32, shape=[])
    defined(scalar_ph)
    self.assertLen(cache, 1)

    len1_ph = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
    defined(len1_ph)
    self.assertLen(cache, 2)

    len2_ph = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
    defined(len2_ph)
    # Gradual shape relaxation is performed; and the common shape between
    # [1] and [2] is one containing unknown dimensions.
    self.assertLen(cache, 2)

    len3_const = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
    defined(len3_const)
    # Shape (3,) matches the relaxed shape TensorShape([None])
    self.assertLen(cache, 2)
def testPythonFunctionWithDefaultArgs(self):
  """Equivalent positional/keyword spellings share one cache entry."""

  def func(foo, bar=1, baz=2):
    del foo, bar, baz

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(func, function_cache=cache)

  defined(0, baz=20)
  self.assertLen(cache, 1)

  defined(1)  # bar=1, baz=2
  self.assertLen(cache, 2)

  # This matches the previous call.
  defined(foo=1)
  self.assertLen(cache, 2)

  defined(1, 2, 3)
  self.assertLen(cache, 3)

  # Both keyword orderings match the previous call.
  for keywords in ({'bar': 2, 'baz': 3}, {'baz': 3, 'bar': 2}):
    defined(1, **keywords)
    self.assertLen(cache, 3)
@test_util.run_v2_only
def testFunctoolsPartialUnwrappedCorrectly(self):
  """A compiled functools.partial returns the same values as the partial."""

  def full_function(a, b, c=3):
    return a, b, c

  partial = functools.partial(full_function, 1, c=4)
  expected = partial(2)

  defined = compiled_fn(partial)
  func_a, func_b, func_c = defined(2)
  a, b, c = expected
  self.assertEqual(func_a.numpy(), a)
  self.assertEqual(func_b.numpy(), b)
  self.assertEqual(func_c.numpy(), c)
def testInputSignatureWithMatchingInputs(self):
  """Inputs matching the input_signature reuse a single cached trace."""

  def foo(a):
    self.assertEqual(a.shape, (2,))
    return a

  function_cache = function_cache_lib.FunctionCache()
  signature = [tensor_lib.TensorSpec(shape=(2,), dtype=dtypes.float32)]
  defined = compiled_fn(
      foo, input_signature=signature, function_cache=function_cache
  )
  a = array_ops.ones([2])
  self.assertAllEqual(a, defined(a))
  self.assertLen(function_cache, 1)
  self.assertAllEqual(a, defined.get_concrete_function()(a))
  self.assertAllEqual(a, defined.get_concrete_function(a)(a))
  self.assertAllEqual(
      a,
      defined.get_concrete_function(
          tensor_lib.TensorSpec((2,), dtype=dtypes.float32)
      )(a),
  )
  self.assertLen(function_cache, 1)

  def bar(a):
    self.assertEqual(a._shape_tuple(), (2, None))
    return a

  # Use a dedicated cache for `bar`. Previously `bar` was compiled without a
  # cache while the assertions below kept inspecting `foo`'s cache, so they
  # could not fail regardless of how `bar` was traced.
  bar_cache = function_cache_lib.FunctionCache()
  signature = [tensor_lib.TensorSpec((2, None), dtypes.float32)]
  defined = compiled_fn(
      bar, input_signature=signature, function_cache=bar_cache
  )
  a = array_ops.ones([2, 1])
  out = defined(a)
  self.assertLen(bar_cache, 1)
  self.assertAllEqual(out, a)

  # Changing the second dimension shouldn't create a new function.
  b = array_ops.ones([2, 3])
  out = defined(b)
  self.assertLen(bar_cache, 1)
  self.assertAllEqual(out, b)
def testInputSignatureWithDictInPositionalArgs(self):
  """A positional dict is cached separately from equivalent keyword args."""
  cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=cache)
  def f(*_args, **_kwargs):
    return None

  # The same keyword call twice: one trace.
  for _ in range(2):
    f(1, x=2)
    self.assertLen(cache, 1)

  # Passing the dict positionally adds a second entry.
  f(1, {'x': 2})
  self.assertLen(cache, 2)
def testInputSignatureWithCompatibleInputs(self):
  """Nested lists and NumPy arrays are accepted by a rank-2 signature."""
  rank2_spec = tensor_lib.TensorSpec(
      shape=(None, None), dtype=dtypes.float32
  )

  @compiled_fn(input_signature=[rank2_spec])
  def func(a):
    self.assertEqual([None, None], a.shape.as_list())
    return array_ops.shape(a)

  column = [[0], [1.0], [1]]
  self.assertAllEqual([3, 1], func(column))
  square = numpy.array([[1, 1], [2, 2]])
  self.assertAllEqual([2, 2], func(square))

  with self.assertRaises(TypeError):
    func([0.0, 1.0, 2.0])  # Wrong shape.

  with self.assertRaises(TypeError):
    func([['wrong dtype']])
@test_util.run_v2_only
def testNestedInputSignatures(self):
  """A nested list signature accepts matching tensors and Python lists."""

  def expected_foo(a, b):
    return [a, b]

  cache = function_cache_lib.FunctionCache()
  pair_spec = [tensor_lib.TensorSpec((2, None), dtypes.float32)] * 2
  vector_spec = tensor_lib.TensorSpec((1,), dtypes.float32)

  @compiled_fn(input_signature=[pair_spec, vector_spec], function_cache=cache)
  def foo(a, b):
    self.assertEqual(a[0]._shape_tuple(), (2, None))
    self.assertEqual(a[1]._shape_tuple(), (2, None))
    self.assertEqual(b._shape_tuple(), (1,))
    return [a, b]

  def check(out, expected, first, second, third):
    # Exactly one trace must exist, and outputs must echo the inputs.
    self.assertLen(cache, 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], first)
    self.assertAllEqual(out[0][1], second)
    self.assertAllEqual(out[1], third)

  a = array_ops.ones([2, 1])
  b = array_ops.ones([1])
  expected = expected_foo([a, a], b)
  out = foo([a, a], b)
  check(out, expected, a, a, b)

  # Changing the unspecified dimensions shouldn't create a new function.
  a = array_ops.ones([2, 3])
  b = array_ops.ones([2, 5])
  c = array_ops.ones([1])
  expected = expected_foo([a, b], c)
  out = foo([a, b], c)
  check(out, expected, a, b, c)

  # Passing compatible inputs should work.
  a = a.numpy().tolist()
  b = b.numpy().tolist()
  c = c.numpy().tolist()
  out = foo([a, b], c)
  check(out, expected, a, b, c)
@test_util.run_v2_only
def testNestedInputSignaturesWithDict(self):
  """A dict-valued signature accepts matching tensors and Python lists."""

  def expected_bar(a):
    return a

  dict_spec = {
      'a': tensor_lib.TensorSpec((2, None), dtypes.float32),
      'b': tensor_lib.TensorSpec((2, None), dtypes.float32),
      'c': tensor_lib.TensorSpec((1,), dtypes.float32),
  }

  @compiled_fn(input_signature=[dict_spec])
  def bar(a):
    self.assertEqual(a['a']._shape_tuple(), (2, None))
    self.assertEqual(a['b']._shape_tuple(), (2, None))
    self.assertEqual(a['c']._shape_tuple(), (1,))
    return a

  def check(out, expected):
    nest.assert_same_structure(out, expected)
    for key in ('a', 'b', 'c'):
      self.assertAllEqual(out[key], expected[key])

  a = array_ops.ones([2, 3])
  b = array_ops.ones([1])
  inputs = {'a': a, 'b': a, 'c': b}
  expected = expected_bar(inputs)
  check(bar(inputs), expected)

  # Passing compatible inputs should work.
  a = a.numpy().tolist()
  b = b.numpy().tolist()
  inputs = {'a': a, 'b': a, 'c': b}
  check(bar(inputs), expected)
def testInputSignatureMustBeSequenceOfTensorSpecs(self):
  """A dict passed as the outermost input_signature raises TypeError."""

  def foo(a, b):
    del a, b

  # Signatures must be either lists or tuples on their outermost levels.
  dict_signature = {'t1': tensor_lib.TensorSpec([], dtypes.float32)}
  expected_message = 'input_signature must be either a tuple or a list.*'
  with self.assertRaisesRegex(TypeError, expected_message):
    compiled_fn(foo, input_signature=dict_signature)
def testInputsIncompatibleWithNestedSignatureRaisesError(self):
  """Lists whose lengths differ from the nested signature raise TypeError."""

  def foo(a, b):
    return [a, b]

  pair_spec = [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2
  second_pair_spec = [tensor_lib.TensorSpec((1,), dtypes.float32)] * 2
  defined = compiled_fn(foo, input_signature=[pair_spec, second_pair_spec])
  a = array_ops.ones([1])

  mismatched_inputs = (
      ([a, a, a], [a]),
      ([a], [a, a, a]),
  )
  for first, second in mismatched_inputs:
    with self.assertRaises(TypeError):
      defined(first, second)

  # Inputs with the right structure are accepted.
  defined([a, a], [a, a])
@test_util.run_v2_only
def testUnderspecifiedInputSignature(self):
  """An argument not covered by the signature cannot be overridden."""

  @compiled_fn(input_signature=[tensor_lib.TensorSpec([], dtypes.float32)])
  def foo(a, training=True):
    return a if training else -1.0 * a

  x = constant_op.constant(1.0)
  with self.assertRaises(ValueError):
    foo(x, training=False)

  # Relying on the default value of `training` works.
  self.assertAllEqual(x.numpy(), foo(x).numpy())
@test_util.run_v2_only
def testInputSignatureWithPartialFunction(self):
  """A functools.partial compiled with a signature matches the partial."""

  def full_function(a, b, c=3.0):
    return a, b, c

  partial = functools.partial(full_function, 1, c=4)
  expected = partial(2.0)

  scalar_spec = tensor_lib.TensorSpec([], dtypes.float32)
  defined = compiled_fn(partial, input_signature=[scalar_spec])
  x = constant_op.constant(2.0)
  func_a, func_b, func_c = defined(x)
  a, b, c = expected
  self.assertEqual(func_a.numpy(), a)
  self.assertEqual(func_b.numpy(), b)
  self.assertEqual(func_c.numpy(), c)
@test_util.run_v2_only
def testInputSignatureWithKeywordPositionalArgs(self):
  """All positional/keyword spellings of a call share one trace."""
  cache = function_cache_lib.FunctionCache()

  @compiled_fn(
      input_signature=[
          tensor_lib.TensorSpec([], dtypes.float32),
          tensor_lib.TensorSpec([], dtypes.int64),
      ],
      function_cache=cache,
  )
  def foo(flt, integer):
    return flt, integer

  def check(outputs):
    out_flt, out_int = outputs
    self.assertLen(cache, 1)
    self.assertEqual(out_flt.numpy(), 1.0)
    self.assertEqual(out_int.numpy(), 2)

  flt = constant_op.constant(1.0)
  integer = constant_op.constant(2, dtypes.int64)

  check(foo(flt, integer))
  check(foo(flt=flt, integer=integer))
  check(foo(integer=integer, flt=flt))
  check(foo(flt, integer=integer))
@test_util.run_v2_only
def testInputSignatureWithKeywordArgs(self):
  """A function taking **kwargs can be traced from its input_signature."""

  def foo(a, b, **kwargs):
    del kwargs
    return a, b

  signature = [
      tensor_lib.TensorSpec([], dtypes.float32),
      tensor_lib.TensorSpec([], dtypes.int32),
  ]
  concrete = compiled_fn(
      foo, input_signature=signature
  ).get_concrete_function()
  result = concrete(constant_op.constant(5.0), constant_op.constant(5))
  self.assertAllEqual(result, [5.0, 5])
def testInputSignatureWithCompositeTensors(self):
  """A RaggedTensorSpec signature accepts matching ragged inputs only."""

  def f(rt):
    self.assertEqual(rt.values.shape.as_list(), [None])
    self.assertEqual(rt.row_splits.shape.as_list(), [4])
    return rt

  ragged_spec = ragged_tensor.RaggedTensorSpec(
      shape=[3, None], dtype=dtypes.int32
  )
  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(f, input_signature=[ragged_spec], function_cache=cache)

  def check_round_trip(rt):
    out = defined(rt)
    self.assertLen(cache, 1)
    self.assertAllEqual(out.values, rt.values)
    self.assertAllEqual(out.row_splits, rt.row_splits)

  check_round_trip(ragged_factory_ops.constant([[1], [], [2, 3, 4]]))

  # Changing the row lengths shouldn't create a new function.
  check_round_trip(ragged_factory_ops.constant([[1, 2], [3, 4], [5]]))

  # Different number of rows
  four_rows = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
  with self.assertRaises(TypeError):
    defined(four_rows)

  # Different dtype
  float_rows = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
  with self.assertRaises(TypeError):
    defined(float_rows)

  # Different rank
  rank3_rows = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
  with self.assertRaises(ValueError):
    defined(rank3_rows)
@test_util.run_v2_only
def testInputSignatureWithKeywordOnlyArgs(self):
  """Keyword-only args with defaults work with input_signature and saving."""
  import os  # pylint:disable=g-import-not-at-top

  def f(a, b, c=3, *, d=4):
    self.assertIsInstance(a, tensor_lib.Tensor)
    self.assertIsInstance(b, tensor_lib.Tensor)
    self.assertIsInstance(c, int)
    self.assertIsInstance(d, (int, tensor_lib.Tensor))
    return a + b + c + d

  signature = [
      tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32),
      tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32),
  ]
  defined = compiled_fn(f, input_signature=signature)
  self.assertEqual(defined(1, 2).numpy(), 10)

  defined = compiled_fn(functools.partial(f, c=4), input_signature=signature)
  self.assertEqual(defined(1, 2).numpy(), 11)

  defined = compiled_fn(functools.partial(f, d=5), input_signature=signature)
  self.assertEqual(defined(1, 2).numpy(), 11)

  defined = compiled_fn(
      functools.partial(f, d=array_ops.constant(5)), input_signature=signature
  )
  self.assertEqual(defined(1, 2).numpy(), 11)

  mod = module.Module()
  # Export into the per-test temp directory instead of a fixed path under
  # /tmp, so runs do not collide or leave artifacts behind.
  export_dir = os.path.join(self.get_temp_dir(), 'kwonlyf')
  save(mod, export_dir, defined.get_concrete_function(*signature))
  loaded = load(export_dir)
  result = loaded.signatures['serving_default'](
      a=array_ops.constant(1),
      b=array_ops.constant(2),
      d=array_ops.constant(5),
  )
  self.assertEqual(result['output_0'].numpy(), 11)
def testInputSignatureWithKeywordOnlyArgsNoDefaults(self):
  """A keyword-only parameter without a default is rejected."""
  signature = [
      tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32),
      tensor_lib.TensorSpec(shape=[], dtype=dtypes.int32),
  ]
  expected_message = (
      'Since input_signature is defined, keyword-only parameter `b` must'
      ' have a default value'
  )

  def test_func(a, *, b):
    return a + b

  test_func_lambda = lambda a, *, b: a + b

  # Both a def and a lambda must trigger the same error.
  for candidate in (test_func, test_func_lambda):
    with self.assertRaisesRegex(TypeError, expected_message):
      compiled_fn(candidate, input_signature=signature)
def testTensorKeywordArguments(self):
  """Tensor args share a trace across positional/keyword spellings."""

  def foo(a, b):
    del a
    return b

  cache = function_cache_lib.FunctionCache()
  defined = compiled_fn(foo, function_cache=cache)
  a = constant_op.constant(2.0)
  b = constant_op.constant([1.0, 2.0])

  vector_results = []
  vector_results.append(defined(a, b))
  self.assertLen(cache, 1)
  vector_results.append(defined(a=a, b=b))
  self.assertLen(cache, 1)
  vector_results.append(defined(b=b, a=a))
  self.assertLen(cache, 1)
  vector_results.append(defined(a, b=b))
  self.assertLen(cache, 1)

  # The next call corresponds to a new input signature, hence
  # we expect another function to be defined.
  scalar_results = []
  scalar_results.append(defined(b, a))
  self.assertLen(cache, 2)
  scalar_results.append(defined(a=b, b=a))
  self.assertLen(cache, 2)
  scalar_results.append(defined(b=a, a=b))
  self.assertLen(cache, 2)

  for result in vector_results:
    self.assertAllEqual(result, [1.0, 2.0])
  for result in scalar_results:
    self.assertAllEqual(result, 2.0)
def testFunctionWithInvalidAttribute(self):
  """An unsupported attribute in TracingOptions raises ValueError."""

  def add(x, y):
    return math_ops.add(x, y)

  expected_message = (
      'Tracing compilation does not support `experimental_1` as an'
      ' attribute.'
  )
  with self.assertRaisesRegex(ValueError, expected_message):
    options = tracing_compilation.TracingOptions(
        add, 'add', attributes={'experimental_1': 'value1'}
    )
    tracing_compilation.trace_function((1, 2), tracing_options=options)
def testRegisterFunction(self):
  """Pre-registered functions are reused when later called in the graph."""

  @compiled_fn(name='add', function_cache=function_cache_lib.FunctionCache())
  def add(x, y):
    return math_ops.add(x, y)

  def matmul(x, y):
    return math_ops.matmul(x, y)

  defun_matmul = compiled_fn(
      matmul, name='matmul', function_cache=function_cache_lib.FunctionCache()
  )

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      concrete_func_matmul = defun_matmul.get_concrete_function(t, t)
      concrete_func_matmul.add_to_graph()
      concrete_func_matmul.add_gradient_functions_to_graph()

      concrete_func_add = add.get_concrete_function(t, t)
      concrete_func_add.add_to_graph()
      concrete_func_add.add_gradient_functions_to_graph()

      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self.assertLen(graph._functions, 6)
      # two sets of functions, each of them are (inference, forward, backward)
      functions = list(graph._functions.values())
      captured_function_names = [
          f.cached_definition.signature.name for f in functions
      ]
      expected_func_name_regex = [
          '.*inference.*matmul.*',
          '.*forward.*matmul.*',
          '.*inference.*backward.*matmul.*',
          '.*inference.*add.*',
          '.*forward.*add.*',
          '.*inference.*backward.*add.*',
      ]
      # Both lists have six entries (checked by assertLen above), so pairing
      # them with zip visits exactly the same elements as indexing would.
      for found, expected in zip(
          captured_function_names, expected_func_name_regex
      ):
        self.assertRegex(found, expected)

      # Check the forward and backward function has the correct attributes.
      # Indices are (forward, backward) within each registered triple.
      for forward_index, backward_index in ((1, 2), (4, 5)):
        forward_fn = functions[forward_index]
        backward_fn = functions[backward_index]
        self.assertEqual(
            forward_fn.cached_definition.attr['backward_function_name'].s,
            backward_fn.name,
        )
        self.assertEqual(
            backward_fn.cached_definition.attr['forward_function_name'].s,
            forward_fn.name,
        )

      sq = defun_matmul(t, t)
      double = add(t, t)
      self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
      self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
      # Make sure the pre registered function is used, and no other function
      # is added.
      self.assertLen(graph._functions, 6)
      functions = list(graph._functions.values())
      for captured_name, function in zip(captured_function_names, functions):
        self.assertEqual(
            captured_name, function.cached_definition.signature.name
        )
@test_util.run_v2_only
def testRegisterConcreteFunction(self):
  """Registering a composite concrete function also registers its callee."""
  unknown_float = tensor_lib.TensorSpec(None, dtypes.float32)

  @compiled_fn(
      name='py_add', function_cache=function_cache_lib.FunctionCache()
  )
  def py_add(x, y):
    return math_ops.add(x, y)

  py_add(array_ops.ones([]), array_ops.ones([]))
  add = py_add.get_concrete_function(unknown_float, unknown_float)

  @compiled_fn(
      name='py_composite', function_cache=function_cache_lib.FunctionCache()
  )
  def py_composite(x, y):
    return x, add(x, y)

  py_composite(array_ops.ones([]), array_ops.ones([]))
  composite = py_composite.get_concrete_function(unknown_float, unknown_float)

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      composite.add_to_graph()
      composite.add_gradient_functions_to_graph()

      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self.assertLen(graph._functions, 6)
      # two sets of functions, each of them are (inference, forward, backward)
      registered = list(graph._functions.values())
      registered_names = [
          fn.cached_definition.signature.name for fn in registered
      ]
      expected_name_patterns = [
          '.*inference.*py_composite.*',
          '.*inference.*py_add.*',
          '.*forward.*py_composite.*',
          '.*forward.*py_add.*',
          '.*inference.*backward.*py_composite.*',
          '.*inference.*backward.*py_add.*',
      ]
      for pattern, name in zip(expected_name_patterns, registered_names):
        self.assertRegex(name, pattern)

      composite_t, composite_double = composite(t, t)
      double = add(t, t)
      doubled_values = [[2, 4], [6, 8]]
      self.assertAllEqual(doubled_values, self.evaluate(double))
      self.assertAllEqual(doubled_values, self.evaluate(composite_double))
      self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
      # Make sure the pre registered function is used, and no other function
      # is added.
      self.assertLen(graph._functions, 6)
def testEagerCaptures(self):
  """Checks the op type that each eagerly created capture is lowered to."""
  with context.eager_mode():
    # The sizes used below are asserted to sit on either side of the
    # threshold so the test stays meaningful if the constant changes.
    large_tensor = array_ops.ones(shape=(256,))
    self.assertGreater(256, capture_container._EAGER_CONST_THRESHOLD)

    small_tensor = array_ops.ones(shape=(4,))
    self.assertLessEqual(4, capture_container._EAGER_CONST_THRESHOLD)

    v = resource_variable_ops.ResourceVariable(0.0)

  expected_capture_ops = (
      (large_tensor, 'Placeholder'),
      (small_tensor, 'Const'),
      (v, 'Placeholder'),
  )
  for captured, op_type in expected_capture_ops:

    @compiled_fn
    def test_fn():
      return captured + 1  # pylint: disable=cell-var-from-loop

    traced_graph = test_fn.get_concrete_function().graph
    captures = traced_graph.internal_captures
    self.assertLen(captures, 1)
    self.assertEqual(captures[0].op.type, op_type)
def testRegisterFunctionWithInputSignature(self):
  """Registering a signature-bound function twice keeps three functions."""

  def matmul(x, y):
    return math_ops.matmul(x, y)

  signature = [
      tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
      tensor_lib.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
  ]
  defun_matmul = compiled_fn(
      matmul,
      input_signature=signature,
      function_cache=function_cache_lib.FunctionCache(),
  )

  def register_and_list_functions(concrete):
    """Registers `concrete` with gradients; returns the graph's functions."""
    concrete.add_to_graph()
    concrete.add_gradient_functions_to_graph()
    # pylint: disable=protected-access
    return ops.get_default_graph()._functions

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      from_inputs = defun_matmul.get_concrete_function(matrix, matrix)
      self.assertLen(register_and_list_functions(from_inputs), 3)

      # Test register function with cache, note inputs are ignored.
      from_signature = defun_matmul.get_concrete_function()
      self.assertLen(register_and_list_functions(from_signature), 3)
def testRegisterFunctionWithCache(self):
  """Registering for two same-typed inputs leaves three graph functions."""

  def matmul(x, y):
    return math_ops.matmul(x, y)

  matmul_cache = function_cache_lib.FunctionCache()
  defun_matmul = compiled_fn(matmul, function_cache=matmul_cache)

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      first = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      second = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])

      concrete_for_first = defun_matmul.get_concrete_function(first, first)
      concrete_for_first.add_to_graph()
      concrete_for_first.add_gradient_functions_to_graph()

      concrete_for_second = defun_matmul.get_concrete_function(second, second)
      concrete_for_second.add_to_graph()
      concrete_for_second.add_gradient_functions_to_graph()

      # Only one function is registered since both inputs have the same type.
      # pylint: disable=protected-access
      self.assertLen(ops.get_default_graph()._functions, 3)
@test_util.run_v2_only
def testCallingFunctionWithDifferentVariables(self):
  """Calls concrete functions with variables passed in as arguments."""

  @compiled_fn
  def foo(v):
    v.assign_add(1.0)
    return v.read_value()

  first_var = resource_variable_ops.ResourceVariable(0.0)
  concrete_foo = foo.get_concrete_function(first_var)
  # The variable is taken as an input and not as a capture.
  self.assertLen(concrete_foo.inputs, 1)
  self.assertEmpty(concrete_foo.captured_inputs)
  # Each call increments the variable that was passed in.
  for expected in (1.0, 2.0):
    self.assertEqual(float(concrete_foo(first_var)), expected)

  second_var = resource_variable_ops.ResourceVariable(0.0)

  @compiled_fn
  def bar(v):
    del v
    return constant_op.constant(1.0)

  concrete_bar = bar.get_concrete_function(first_var)
  # The function traced with one variable is callable with another one.
  for variable in (first_var, second_var):
    self.assertEqual(float(concrete_bar(variable)), 1.0)
def testCallingFunctionWithNonTensorsFails(self):
  """A concrete function traced for a tensor rejects a plain string."""

  @compiled_fn
  def foo(x):
    return x

  traced_input = constant_op.constant(1.0)
  concrete_foo = foo.get_concrete_function(traced_input)
  self.assertRaises((TypeError, ValueError), concrete_foo, 'Not a Tensor.')
@parameterized.parameters([
    # One (CPU decorator, GPU decorator) pair. The list used to hold this
    # tuple twice, which only ran the identical test case a second time.
    (
        compiled_fn(
            attributes={
                'api_implements': 'random_boost',
                'api_preferred_device': 'CPU',
            }
        ),
        compiled_fn(
            attributes={
                'api_implements': 'random_boost',
                'api_preferred_device': 'GPU',
            }
        ),
    ),
])
@test_util.run_v2_only
def testSwapImplementationWithGrapplerPlugin(
    self, cpu_decorator, gpu_decorator
):
  """Checks which `random_boost` implementation runs in a graph session.

  Args:
    cpu_decorator: decorator producing the implementation whose
      `api_preferred_device` attribute is 'CPU' (adds 2.0).
    gpu_decorator: decorator producing the implementation whose
      `api_preferred_device` attribute is 'GPU' (adds 4.0).
  """
  # Set min_graph_nodes to -1 since the graph in this test is too small and
  # would be ignored by grappler otherwise.
  rewrites = rewriter_config_pb2.RewriterConfig()
  rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
  rewrites.min_graph_nodes = -1
  graph_options = config_pb2.GraphOptions(
      rewrite_options=rewrites, build_cost_model=1
  )
  config_proto = config_pb2.ConfigProto(graph_options=graph_options)

  with context.graph_mode(), self.cached_session(
      config=config_proto, graph=ops.Graph(), use_gpu=True
  ):

    @cpu_decorator
    def cpu_boost(x):
      return math_ops.add(x, 2.0)

    @gpu_decorator
    def gpu_boost(x):
      return math_ops.add(x, 4.0)

    x = constant_op.constant(1.0)

    # Register the CPU implementation so it is available as an alternative
    # when the GPU implementation is called below.
    concrete_func = cpu_boost.get_concrete_function(x)
    concrete_func.add_to_graph()
    concrete_func.add_gradient_functions_to_graph()
    y = gpu_boost(x)
    y_value = self.evaluate(y)

    if test.is_gpu_available():
      self.assertEqual(y_value, 5.0)
    else:
      # Grappler falls back to the CPU implementation even though the GPU
      # function was the one called.
      self.assertEqual(y_value, 3.0)
@test_util.disable_tfrt(
    "b/174712583: TFRT doesn't support behavior "
    'equivalent to implementation_selector for function'
)
def testSwapImplementationInEager(self):
  """Expects the CPU implementation of 'foo' to run under a CPU device."""
  if not context.executing_eagerly():
    self.skipTest('eager only')

  # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
  # that subtest runs before this one, leaving the flag set would make this
  # subtest fail, so the flag is explicitly reset to False here.
  optimizer_options = {
      'min_graph_nodes': -1,
      'implementation_selector': True,
      'disable_meta_optimizer': False,
  }
  context.context().set_optimizer_experimental_options(optimizer_options)

  cpu_attributes = {
      'api_implements': 'foo',
      'api_preferred_device': 'CPU',
  }
  gpu_attributes = {
      'api_implements': 'foo',
      'api_preferred_device': 'GPU',
  }

  @compiled_fn(attributes=cpu_attributes)
  def on_cpu(x):
    return x + 2

  @compiled_fn(attributes=gpu_attributes)
  def on_gpu(x):
    return x + 4

  @compiled_fn
  def run_on_cpu(t):
    concrete_func = on_cpu.get_concrete_function(t)
    concrete_func.add_to_graph()
    concrete_func.add_gradient_functions_to_graph()
    with ops.device('CPU:0'):
      return on_gpu(t)

  # Expect to run the on_cpu branch, regardless whether gpu is available.
  result = run_on_cpu(constant_op.constant(1))
  self.assertEqual(result.numpy(), 3)
def testCompilationFunctionSeparateGraphs(self):
  """Tracks both cache sizes while calling from two separate graphs."""
  with context.graph_mode():
    add_cache = function_cache_lib.FunctionCache()

    @compiled_fn(function_cache=add_cache)
    def add(x):
      return x + 5

    maybe_add_cache = function_cache_lib.FunctionCache()

    @compiled_fn(function_cache=maybe_add_cache)
    def maybe_add(x, should_add):
      if should_add:
        return add(x)
      else:
        return x

    def assert_cache_sizes(maybe_add_size, add_size):
      self.assertLen(maybe_add_cache, maybe_add_size)
      self.assertLen(add_cache, add_size)

    with ops.Graph().as_default():
      x = constant_op.constant(11)
      maybe_add(x, True)
      assert_cache_sizes(1, 1)

      # A different `should_add` value adds an entry for `maybe_add` only.
      maybe_add(x, False)
      assert_cache_sizes(2, 1)

    with ops.Graph().as_default():
      # The same call from a fresh graph adds an entry to both caches.
      x = constant_op.constant(11)
      maybe_add(x, True)
      assert_cache_sizes(3, 2)
def testCacheKeyOverlappingShapes(self):
  """Shapes [12, 1] and [1, 21] each get their own cache entry."""
  shapes = ([12, 1], [1, 21])

  defined_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=defined_cache)
  def defined(t):
    return t

  for expected_size, shape in enumerate(shapes, start=1):
    defined(array_ops.zeros(shape))
    self.assertLen(defined_cache, expected_size)

  defined_again_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=defined_again_cache)
  def defined_again(t):
    return defined(t)

  # Same check, this time through get_concrete_function on a wrapper.
  for expected_size, shape in enumerate(shapes, start=1):
    defined_again.get_concrete_function(array_ops.zeros(shape))
    self.assertLen(defined_again_cache, expected_size)
def testCacheTensorSpecIdenticalToTensor(self):
  """A tensor and the spec derived from it yield the same concrete function."""

  @compiled_fn(function_cache=function_cache_lib.FunctionCache())
  def defined(t):
    return t

  z = array_ops.zeros([2, 2])
  z_spec = tensor_lib.TensorSpec.from_tensor(z)

  from_spec = defined.get_concrete_function(z_spec)
  from_tensor = defined.get_concrete_function(z)
  self.assertIs(from_spec, from_tensor)
def testCacheKeyNestedLists(self):
  """Differently nested lists of the same tensors get separate entries."""
  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def defined(l):
    return l

  a = constant_op.constant(1.0)
  b = constant_op.constant(2.0)
  c = constant_op.constant(3.0)

  nestings = ([[a], b, c], [[a, b], c])
  for expected_size, nested in enumerate(nestings, start=1):
    defined(nested)
    self.assertLen(function_cache, expected_size)
def testCacheKeyAttrsClass(self):
  """attrs instances: equal nesting reuses an entry, new nesting adds one."""
  if attr is None:
    self.skipTest('attr module is unavailable.')

  @attr.s
  class TestClass:
    a = attr.ib()
    b = attr.ib()

  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def defined(l):
    return l

  def scalar_then_pair():
    return TestClass(
        constant_op.constant(1.0),
        [constant_op.constant(2.0), constant_op.constant(3.0)],
    )

  def pair_then_scalar():
    return TestClass(
        [constant_op.constant(1.0), constant_op.constant(2.0)],
        constant_op.constant(3.0),
    )

  # (factory for the argument, expected cache size after the call)
  steps = (
      (scalar_then_pair, 1),
      (scalar_then_pair, 1),
      (pair_then_scalar, 2),
  )
  for make_argument, expected_size in steps:
    defined(make_argument())
    self.assertLen(function_cache, expected_size)
def testDistinctVariablesNoRetracing(self):
  """Reordering three distinct variables does not add a cache entry."""
  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def defined(a, b, c):
    return a + b + c

  x, y, z = (resource_variable_ops.ResourceVariable(0.0) for _ in range(3))

  # We generate cache keys based on unique combinations of resource ids.
  # Re-arranging the arguments should not cause a cache miss because the
  # three inputs are still distinct.
  for arguments in ((x, y, z), (z, y, x)):
    defined(*arguments)
    self.assertLen(function_cache, 1)
def testRetracingOnDifferentVaribleCombinationPatterns(self):
  """Cache size follows the aliasing pattern of the variable arguments."""
  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def defined(a, b, c):
    return a + b + c

  x = resource_variable_ops.ResourceVariable(0.0)
  y = resource_variable_ops.ResourceVariable(0.0)
  z = resource_variable_ops.ResourceVariable(0.0)

  # (arguments, expected cache size after the call)
  steps = (
      ((x, y, z), 1),
      # Retracing because the first two arguments are the same.
      ((x, x, z), 2),
      # Replacing x with y does not cause a cache miss because the
      # combination stays the same as (x, x, z).
      ((y, y, z), 2),
      # A different combination pattern causes a cache miss.
      ((z, y, y), 3),
      ((z, y, y), 3),
  )
  for arguments, expected_size in steps:
    defined(*arguments)
    self.assertLen(function_cache, expected_size)
@test_util.run_v2_only
def testDeepcopyVariableNoRetracing(self):
  """Swapping a variable for its deep copy does not add a cache entry."""
  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def defined(a, b, c):
    return a + b + c

  x, y, z = (resource_variable_ops.ResourceVariable(0.0) for _ in range(3))
  defined(x, y, z)
  self.assertLen(function_cache, 1)

  defined(copy.deepcopy(x), y, z)
  self.assertLen(function_cache, 1)
@test_util.disable_tfrt('b/173429686')
@test_util.run_v2_only
def testExecutorType(self):
  """Runs a function under valid and invalid function executor types."""

  @compiled_fn
  def add_five(x):
    return x + 5

  def add_five_to_zero():
    return add_five(constant_op.constant(0, dtype=dtypes.int32))

  self.assertEqual(5, add_five_to_zero().numpy())

  # An unknown executor name is reported with a NotFoundError naming it.
  with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
    with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
      add_five_to_zero()

  for executor_type in ('', 'DEFAULT', None):
    with context.function_executor_type(executor_type):
      self.assertAllEqual(5, add_five_to_zero().numpy())
@test_util.assert_no_garbage_created
def testReferenceCycles(self):
  """A compiled function is freed as soon as its last reference is dropped."""
  fn = compiled_fn(lambda x: 2.0 * x)

  fn(constant_op.constant(4.0))

  weak_fn = weakref.ref(fn)
  del fn
  # The weak reference must be dead by now, which means the object has been
  # deleted. That holds as long as the function itself is not part of a
  # reference cycle.
  self.assertIsNone(weak_fn())
@test_util.run_in_graph_and_eager_modes
def testShapeCaching(self):
  """`array_ops.shape` inside a compiled function reflects each input."""

  @compiled_fn
  def func(x):
    return array_ops.shape(x)

  @compiled_fn(
      input_signature=[tensor_lib.TensorSpec([None, None], dtypes.float32)]
  )
  def calls_func(x):
    return func(x)

  for size in (1, 2):
    direct = func(array_ops.zeros([size, size]))
    self.assertAllEqual([size, size], self.evaluate(direct))

  # Going through a caller traced with a fully unknown 2-D signature.
  nested = calls_func(array_ops.zeros([3, 3]))
  self.assertAllEqual([3, 3], self.evaluate(nested))
def testLimitedRetracing(self):
  """Repeated calls with six distinct inputs trace fewer than 13 times."""
  trace_count = [0]
  function_cache = function_cache_lib.FunctionCache()

  @compiled_fn(function_cache=function_cache)
  def func(x):
    trace_count[0] += 1
    return x

  input_values = (
      3.0,
      4.0,
      [[1.0, 2.0]],
      [[]],
      [[3.0, 4.0], [5.0, 6.0]],
      [[3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
  )
  for _ in range(50):
    for value in input_values:
      func(constant_op.constant(value))

  # Tracing more than twice per input doesn't make sense.
  self.assertLess(trace_count[0], 13)
class CompilationCollectionTest(test.TestCase):
  """Tests for graph collection access from inside compiled functions."""

  def testCollectionValueAccess(self):
    """Reads plain values from graph collections inside a compiled function."""
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        outer_values = (('x', 2), ('y', 5))
        for collection_name, value in outer_values:
          ops.add_to_collection(collection_name, value)

        @compiled_fn
        def fn():
          x_const = constant_op.constant(ops.get_collection('x')[0])
          y_const = constant_op.constant(ops.get_collection('y')[0])
          z = math_ops.add(x_const, y_const)
          ops.add_to_collection('z', 7)
          return z

        total = self.evaluate(fn())
        self.assertEqual(7, int(total))
        self.assertEqual(ops.get_collection('x'), [2])
        self.assertEqual(ops.get_collection('y'), [5])
        # The value added to 'z' inside `fn` is not visible out here.
        self.assertEqual(ops.get_collection('z'), [])

  def testCollectionVariableValueAccess(self):
    """Reads a variable's value inside a compiled function in graph mode."""
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        v = resource_variable_ops.ResourceVariable(1.0)

        @compiled_fn
        def f():
          return v.read_value()

        self.evaluate(variables.global_variables_initializer())
        read_back = self.evaluate(f())
        self.assertEqual(1.0, float(read_back))
        global_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertLen(global_variables, 1)
class MultiDeviceCompilationTest(test.TestCase, parameterized.TestCase):
  """Tests for compiled functions whose ops are placed on CPU and GPU."""

  @test_util.run_gpu_only
  def testMultiDeviceOutput(self):
    """Tests that functions can produce outputs on multiple devices."""

    @compiled_fn
    def func(a, b, transpose_a):
      with ops.device('/device:CPU:0'):
        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
      with ops.device('/device:GPU:0'):
        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
      return m1, m2

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    cpu_product, gpu_product = func(matrix, matrix, transpose_a=True)

    expected_product = [[10, 14], [14, 20]]
    for product, device in ((cpu_product, 'CPU'), (gpu_product, 'GPU')):
      self.assertAllEqual(product.numpy(), expected_product)
      self.assertRegex(product.backing_device, device)

  @test_util.run_gpu_only
  def testEmptyBody(self):
    """A function that only swaps its inputs keeps each value's device."""

    @compiled_fn
    def func(a, b):
      return b, a

    with ops.device('/device:CPU:0'):
      on_cpu = array_ops.identity(3.0)
    with ops.device('/device:GPU:0'):
      on_gpu = array_ops.identity(5.0)

    first, second = func(on_cpu, on_gpu)
    self.assertAllEqual(first.numpy(), 5.0)
    self.assertRegex(first.backing_device, 'GPU')
    self.assertAllEqual(second.numpy(), 3.0)
    self.assertRegex(second.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceInt32(self):
    """Tests that multi-device functions can take and output INT32s.

    When an INT32 device tensor is fed into a function, it is copied to CPU
    by the eager runtime. The function sees all INT32 inputs on CPU.

    We set allocator attribute 'on_host' for INT32 outputs. They can be
    partitioned into the GPU component function, but will be allocated on
    CPU nevertheless.

    There is experimental support for `ints_on_device` in
    FunctionLibraryRuntime now. We can try that.
    """
    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @compiled_fn
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation will happen on GPU but m2 will be copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    # (arguments, expected m1, expected m2); the second row flips the ints.
    cases = (
        ((int_cpu, resource, int_gpu), 22, 39),
        ((int_gpu, resource, int_cpu), 38, 23),
    )
    for arguments, expected_m1, expected_m2 in cases:
      m1, m2 = func(*arguments)
      self.assertAllEqual(m1.numpy(), expected_m1)
      self.assertRegex(m1.backing_device, 'CPU')
      self.assertAllEqual(m2.numpy(), expected_m2)
      self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceColocateWith(self):
    """Tests that function's outputs respect colocation constraints."""

    @compiled_fn
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    # Every (device of a, device of b) combination, including equal ones.
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = array_ops.identity(1.0)
      with ops.device(dev2):
        b = array_ops.identity(10.0)

      doubled, tripled = func(a, b)
      self.assertEqual(doubled.numpy(), 2.0)
      self.assertRegex(doubled.backing_device, dev1)
      self.assertEqual(tripled.numpy(), 30.0)
      self.assertRegex(tripled.backing_device, dev2)

  @test_util.run_gpu_only
  def testMultiDeviceResources(self):
    """Mixes argument and captured variables that live on CPU and GPU."""
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @compiled_fn
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    # The second row calls with flipped inputs. Check that we look at the
    # resource's device and reinstantiate the function when the inputs'
    # devices change.
    cases = (
        ((c1, g1), 10.0, 21.0),
        ((g1, c1), 15.0, 14.0),
    )
    for arguments, expected_r1, expected_r2 in cases:
      r1, r2 = func(*arguments)
      self.assertEqual(r1.numpy(), expected_r1)
      self.assertRegex(r1.backing_device, 'CPU')
      self.assertEqual(r2.numpy(), expected_r2)
      self.assertRegex(r2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testOutputResources(self):
    """Returns resource handles from a function next to computed values."""
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @compiled_fn
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32
      )
      self.assertEqual(tensor.numpy(), expected_value)

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated
    # and the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)

  @test_util.run_gpu_only
  def testPassResourceThroughNestedFunctionCall(self):
    """Test passing GPU resource to noinline function call placed on CPU.

    PartitionedCallOp must not enforce any particular device assignment for
    the resource output. Inner function marked as `_nospecialize`, so Grappler
    would not prune unused function output.
    """
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    inner_attributes = {'_noinline': True, '_nospecialize': True}

    @compiled_fn(attributes=inner_attributes)
    def inner(resource1):
      return resource1 * 2, resource1.handle

    @compiled_fn
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, _ = inner(resource1)
      return r1

    doubled = outer(g1)

    self.assertEqual(doubled.numpy(), 6.0)
    self.assertRegex(doubled.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testReturnResourceFromNestedFunctionCall(self):
    """Test returning GPU resource from noinline function call placed on CPU.

    When inferring output devices for the return value, do not set a device
    for returns of DT_RESOURCE data type based on the device assignment of
    the node that produced that resource. As an example function call placed
    on CPU can return resources on GPU.
    """
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @compiled_fn(attributes={'_noinline': True})
    def inner(resource1):
      resource1.assign_add(2.0)
      return resource1 * 2, resource1.handle

    @compiled_fn
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, res1 = inner(resource1)
      return r1, res1

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32
      )
      self.assertEqual(tensor.numpy(), expected_value)

    doubled, returned_handle = outer(g1)

    self.assertEqual(doubled.numpy(), 10.0)
    self.assertRegex(doubled.backing_device, 'CPU')

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(returned_handle, 5.0)

  @test_util.run_gpu_only
  def testComplexInputOutputDevicePattern(self):
    """Tests input/output mapping logic in partitioning."""
    # Naming: r = resource variable, c = constant tensor; the second letter
    # is the device the value is created on (c = CPU, g = GPU).
    with ops.device('/device:CPU:0'):
      rc0 = resource_variable_ops.ResourceVariable(2.0)
      rc1 = resource_variable_ops.ResourceVariable(3.0)
      cc0 = array_ops.identity(5.0)
      cc1 = array_ops.identity(7.0)
    with ops.device('/device:GPU:0'):
      rg0 = resource_variable_ops.ResourceVariable(11.0)
      rg1 = resource_variable_ops.ResourceVariable(13.0)
      cg0 = array_ops.identity(17.0)
      cg1 = array_ops.identity(19.0)

    # Make sure tensors are on expected devices.
    for cpu_tensor in (cc0, cc1):
      self.assertRegex(cpu_tensor.backing_device, 'CPU:0')
    for gpu_tensor in (cg0, cg1):
      self.assertRegex(gpu_tensor.backing_device, 'GPU:0')

    @compiled_fn
    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
      with ops.device('/device:CPU:0'):
        m1 = rc0 * cg0
      with ops.device('/device:GPU:0'):
        m2 = rg0 * cc0

      with ops.device('/device:CPU:0'):
        r1 = 1000.0 * m2 + rc1 * cg1
      with ops.device('/device:GPU:0'):
        r2 = 1000.0 * m1 + rg1 * cc1

      return r1, r2, m2, m1

    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)

    for output, device in ((m1, 'CPU'), (r1, 'CPU'), (m2, 'GPU'), (r2, 'GPU')):
      self.assertRegex(output.backing_device, device)

    self.assertEqual(m1.numpy(), 34.0)
    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
    self.assertEqual(m2.numpy(), 55.0)
    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

  @test_util.run_gpu_only
  def testArgumentPruning(self):
    """Tests functions taking unnecessary arguments."""
    with ops.device('/device:CPU:0'):
      c1 = constant_op.constant(5.0)
      c2 = constant_op.constant(7.0)

    with ops.device('/device:GPU:0'):
      g1 = constant_op.constant(11.0)
      g2 = constant_op.constant(13.0)
      g3 = constant_op.constant(17.0)

    @compiled_fn
    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
      # arguments g1 and g2 are unused and can be pruned by grappler.
      return c1 * g3 * c2

    product = func(g1, g2, c1, g3, c2)
    self.assertEqual(product.numpy(), 5.0 * 7.0 * 17.0)
class CompilationArgumentNamingTest(test.TestCase, parameterized.TestCase):
  """Tests for recognizable export signatures from concrete functions."""

  def _input_op_names(self, concrete_fn):
    """Returns the op name behind every input of `concrete_fn`."""
    return [inp.op.name for inp in concrete_fn.inputs]

  def _user_specified_names(self, concrete_fn):
    """Returns the `_user_specified_name` attr of every input's op."""
    return [
        inp.op.get_attr('_user_specified_name') for inp in concrete_fn.inputs
    ]

  @test_util.run_v2_only
  def testBasic(self):
    """Inputs are named after the Python parameters and callable by name."""

    @compiled_fn
    def fn(a, b):
      return a + b, a * b

    # Call the function to make def_function happy
    fn(array_ops.ones([]), array_ops.ones([]))

    fn_op = fn.get_concrete_function(
        tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32),
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32),
    )
    self.assertEqual(['a', 'b'], self._input_op_names(fn_op))
    self.assertEqual([b'a', b'b'], self._user_specified_names(fn_op))
    self.assertLen(fn_op.graph.structured_outputs, 2)

    positional_result = fn_op(
        constant_op.constant(1.0), constant_op.constant(2.0)
    )
    self.assertAllClose([3.0, 2.0], positional_result)
    keyword_result = fn_op(
        a=constant_op.constant(1.0), b=constant_op.constant(2.0)
    )
    self.assertAllClose([3.0, 2.0], keyword_result)

  def testVariable(self):
    """A variable argument is still named after its Python parameter."""

    @compiled_fn
    def fn(a, b):
      return a + b, a * b

    # Call the function to make def_function happy
    fn(array_ops.ones([]), array_ops.ones([]))

    fn_op = fn.get_concrete_function(
        tensor_lib.TensorSpec(shape=(None,), dtype=dtypes.float32),
        variables.Variable(1.0),
    )
    self.assertEqual(['a', 'b'], self._input_op_names(fn_op))
    self.assertEqual([b'a', b'b'], self._user_specified_names(fn_op))
    self.assertLen(fn_op.graph.structured_outputs, 2)

  def testDictReturned(self):
    """Checks input names for keyword, nested and custom-named specs."""

    @compiled_fn
    def fn(x, z=(1.0, 2.0), y=3.0):
      z1, z2 = z
      return {'alpha': x + y + z1, 'beta': x * y + z2}

    def float_spec(shape, name=None):
      return tensor_lib.TensorSpec(
          shape=shape, dtype=dtypes.float32, name=name
      )

    # Call the function to make def_function happy
    fn(array_ops.ones([]))

    fn_op = fn.get_concrete_function(
        x=float_spec((None,)),
        y=float_spec(()),
    )
    self.assertEqual(['x', 'y'], self._input_op_names(fn_op))
    self.assertEqual([b'x', b'y'], self._user_specified_names(fn_op))
    self.assertEqual(
        {'alpha', 'beta'}, set(fn_op.graph.structured_outputs.keys())
    )

    # `x` is a Python constant here, so only `z` and `y` become inputs.
    fn_op2 = fn.get_concrete_function(
        z=(
            float_spec((None,), name='z_first'),
            float_spec((), name='z_second'),
        ),
        y=float_spec((), name='custom'),
        x=4.0,
    )
    self.assertEqual(
        ['z_first', 'z_second', 'custom'], self._input_op_names(fn_op2)
    )
    self.assertEqual(
        [b'z_first', b'z_second', b'custom'],
        self._user_specified_names(fn_op2),
    )

    fn_op3 = fn.get_concrete_function(
        float_spec((), name='custom'),
        z=(
            float_spec((None,), name='z1'),
            float_spec((), name='z2'),
        ),
        y=float_spec(()),
    )
    self.assertEqual(
        ['custom', 'z1', 'z2', 'y'], self._input_op_names(fn_op3)
    )
    self.assertEqual(
        [b'custom', b'z1', b'z2', b'y'], self._user_specified_names(fn_op3)
    )

  def testMethod(self):
    """A compiled bound method names its input after the `x` parameter."""

    class HasMethod(object):

      def method(self, x):
        return x

    has_method = HasMethod()
    compiled_method = compiled_fn(has_method.method)

    class_op = compiled_method.get_concrete_function(
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32)
    )
    self.assertEqual(['x'], self._input_op_names(class_op))
    self.assertEqual([b'x'], self._user_specified_names(class_op))

    method_op = compiled_method.get_concrete_function(
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32)
    )
    self.assertEqual(['x'], self._input_op_names(method_op))
    self.assertEqual([b'x'], self._user_specified_names(method_op))

    # TODO(allenl): It should be possible to override names when exporting. Do
    # TensorSpec names need to go in cache keys? Or maybe get_concrete_function
    # should always retrace?
    self.skipTest('Not working')
    # Nothing below runs: skipTest above raises before reaching it.
    method_op = has_method.method.get_concrete_function(
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='y')
    )
    self.assertEqual(['y'], self._input_op_names(method_op))
    self.assertEqual([b'y'], self._user_specified_names(method_op))

  def testMethodSignature(self):
    """The name given in the input signature overrides the parameter name."""

    class HasMethod(object):

      def method(self, x):
        hash(self)  # No weak proxies passed as `self`
        return x

    has_method = HasMethod()
    signature = (
        tensor_lib.TensorSpec(shape=None, dtype=dtypes.float64, name='y'),
    )
    compiled_method = compiled_fn(has_method.method, input_signature=signature)

    method_op = compiled_method.get_concrete_function()
    self.assertEqual(['y'], self._input_op_names(method_op))
    self.assertEqual([b'y'], self._user_specified_names(method_op))

    # Asking a second time gives the same naming.
    method_op2 = compiled_method.get_concrete_function()
    self.assertEqual(['y'], self._input_op_names(method_op2))
    self.assertEqual([b'y'], self._user_specified_names(method_op2))

  def testVariadic(self):
    """Checks input names for *args and **kwargs specs."""

    @compiled_fn
    def variadic_fn(x, *args, **kwargs):
      return x + math_ops.add_n(list(args) + list(kwargs.values()))

    def float_spec(shape, name=None):
      return tensor_lib.TensorSpec(
          shape=shape, dtype=dtypes.float32, name=name
      )

    # Call the function to make def_function happy
    variadic_fn(array_ops.ones([]), array_ops.ones([]))

    variadic_op = variadic_fn.get_concrete_function(
        float_spec(()),
        float_spec(None, name='y'),
        float_spec(()),
        float_spec((), name='second_variadic'),
        z=float_spec(()),
        zz=float_spec((), name='cust'),
    )
    self.assertEqual(
        ['x', 'y', 'args_1', 'second_variadic', 'z', 'cust'],
        self._input_op_names(variadic_op),
    )
    self.assertEqual(
        [b'x', b'y', b'args_1', b'second_variadic', b'z', b'cust'],
        self._user_specified_names(variadic_op),
    )

  def testVariadicInputSignature(self):
    """Checks input names for *args when an input signature is given."""
    signature = (
        tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32),
        tensor_lib.TensorSpec(shape=None, dtype=dtypes.float32, name='y'),
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32),
        tensor_lib.TensorSpec(shape=(), dtype=dtypes.float32, name='z'),
    )

    @compiled_fn(input_signature=signature, name='variadic_fn')
    def variadic_fn(x, *args):
      return x + math_ops.add_n(list(args))

    # Call the function to make def_function happy
    variadic_fn(
        array_ops.ones([]),
        array_ops.ones([]),
        array_ops.ones([]),
        array_ops.ones([]),
    )

    variadic_op = variadic_fn.get_concrete_function()
    self.assertIn(b'variadic_fn', variadic_op.name)
    self.assertEqual(
        ['x', 'y', 'args_1', 'z'], self._input_op_names(variadic_op)
    )
    self.assertEqual(
        [b'x', b'y', b'args_1', b'z'],
        self._user_specified_names(variadic_op),
    )
class DevicePlacementTest(test.TestCase, parameterized.TestCase):
  """Tests for device scopes inside and around compiled functions."""

  @test_util.run_in_graph_and_eager_modes
  def testMultipleDeviceCheck(self):
    """An op under a 'cpu' scope reports CPU:0 when called under 'cpu:0'."""

    def f():
      with ops.device('cpu'):
        return test_ops.device_placement_op()

    func = compiled_fn(f)
    with ops.device('cpu:0'):
      placement = self.evaluate(func())
      self.assertIn(compat.as_bytes('CPU:0'), placement)

  @test_util.run_in_graph_and_eager_modes
  def testDeviceAnnotationsRespected(self):
    """Inner device scopes win; the unscoped op follows the call site."""

    def multi_device_fn():
      with ops.device('/cpu:0'):
        s0 = test_ops.device_placement_op()
      with ops.device('/cpu:1'):
        s1 = test_ops.device_placement_op()
      with ops.device('/cpu:2'):
        s2 = test_ops.device_placement_op()
      s3 = test_ops.device_placement_op()
      return s0, s1, s2, s3

    function_cache = function_cache_lib.FunctionCache()
    defined = compiled_fn(multi_device_fn, function_cache=function_cache)

    def assert_pinned_outputs(placements):
      # The first three ops are pinned by the scopes inside the function.
      self.assertIn(compat.as_bytes('CPU:0'), placements[0])
      self.assertIn(compat.as_bytes('CPU:1'), placements[1])
      self.assertIn(compat.as_bytes('CPU:2'), placements[2])

    outputs = self.evaluate(defined())
    self.assertLen(function_cache, 1)
    assert_pinned_outputs(outputs)

    with ops.device('/cpu:3'):
      outputs = self.evaluate(defined())

    # All function definitions are agnostic to call site devices.
    self.assertLen(function_cache, 1)
    assert_pinned_outputs(outputs)
    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])

    with ops.device('/cpu:0'):
      outputs = self.evaluate(defined())

    self.assertLen(function_cache, 1)
    assert_pinned_outputs(outputs)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
def setUpModule():
  """Enables eager execution and splits the first CPU into four devices."""
  ops.enable_eager_execution()

  # Set 4 virtual CPUs on the first physical CPU.
  first_cpu = config.list_physical_devices('CPU')[0]
  virtual_cpus = [context.LogicalDeviceConfiguration() for _ in range(4)]
  config.set_logical_device_configuration(first_cpu, virtual_cpus)
# Hand control to the test runner when this file is executed as a script.
if __name__ == '__main__':
  test.main()