tensorflow/python/framework/op_def_library_test.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.op_def_library."""
from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_library
from tensorflow.python.framework import op_def_library_pybind
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat
@test_util.add_graph_building_optimization_tests
class OpDefLibraryTest(test_util.TensorFlowTestCase):
def Tensor(self, t, name="in"):
  """Applies the "OutT" op with `T=t` and returns the apply_op result.

  Args:
    t: Value forwarded as the `T` attr of the "OutT" op.
    name: Name passed to the created op; defaults to "in".

  Returns:
    Whatever `op_def_library.apply_op` returns for the "OutT" op.
  """
  out_t = op_def_library.apply_op("OutT", T=t, name=name)
  return out_t
def testNoRegisteredOpFails(self):
  """Applying an op name that is not registered raises RuntimeError."""
  expected_message = "Unrecognized Op name unknown"
  with self.assertRaises(RuntimeError) as raised:
    op_def_library.apply_op("unknown")
  self.assertEqual(str(raised.exception), expected_message)
def testSimple(self):
  """Checks dtype and NodeDef of "Simple" ops built from Python values."""
  with ops.Graph().as_default():
    result = op_def_library.apply_op("Simple", a=3)
    self.assertEqual(dtypes.float32, result.dtype)
    expected_first = """
name: 'Simple' op: 'Simple' input: 'Simple/a'
"""
    self.assertProtoEquals(expected_first, result.op.node_def)

    # A second unnamed op is expected to be uniquified as 'Simple_1'.
    result = op_def_library.apply_op("Simple", a=4)
    expected_second = """
name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'
"""
    self.assertProtoEquals(expected_second, result.op.node_def)

    # An explicit name is used for the node and as the input prefix.
    result = op_def_library.apply_op("Simple", a=5, name="named")
    expected_named = """
name: 'named' op: 'Simple' input: 'named/a'
"""
    self.assertProtoEquals(expected_named, result.op.node_def)

    # A nested list is accepted as the value of the single input.
    nested_value = [[1, 2, 3], [4, 5, 6]]
    result = op_def_library.apply_op("Simple", a=nested_value, name="two_d")
    expected_two_d = """
name: 'two_d' op: 'Simple' input: 'two_d/a'
"""
    self.assertProtoEquals(expected_two_d, result.op.node_def)
def testSimpleFailures(self):
  """Checks the TypeError messages produced by bad "Simple" arguments."""
  with ops.Graph().as_default():
    # A Python string cannot be used for the int32 input.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("Simple", a="Bad string")
    bad_string_fragment = (
        "Expected int32 passed to parameter 'a' of op 'Simple', "
        "got 'Bad string' of type 'str' instead.")
    self.assertIn(bad_string_fragment, str(raised.exception))

    # A tensor of the wrong dtype is rejected as well.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("Simple", a=self.Tensor(dtypes.string))
    wrong_dtype_fragment = (
        "Input 'a' of 'Simple' Op has type string "
        "that does not match expected type of int32.")
    self.assertIn(wrong_dtype_fragment, str(raised.exception))

    # Remaining cases only differ in the keyword arguments supplied.
    keyword_cases = (
        (dict(a=6, extra="bogus"),
         "Simple got unexpected keyword arguments: extra"),
        (dict(a=6, extra1="bogus", extra2="also_bogus"),
         "Simple got unexpected keyword arguments: extra1, "
         "extra2"),
        (dict(), "No argument for input a"),
        (dict(wrong=7), "No argument for input a"),
        (dict(a={"label": 1}),
         "Expected int32 passed to parameter 'a' of op 'Simple', "
         "got {'label': 1} of type 'dict' instead."),
    )
    for keywords, fragment in keyword_cases:
      with self.assertRaises(TypeError) as raised:
        op_def_library.apply_op("Simple", **keywords)
      self.assertIn(fragment, str(raised.exception))
def testReservedInput(self):
  """Checks that the `input_` keyword supplies the input 'x/input'."""
  with ops.Graph().as_default():
    created = op_def_library.apply_op("ReservedInput", input_=7, name="x")
    expected = """
name: 'x' op: 'ReservedInput' input: 'x/input'
"""
    self.assertProtoEquals(expected, created.node_def)
def testPolymorphic(self):
  """Checks that "Polymorphic" infers attr T from the value passed as `a`."""
  with ops.Graph().as_default():
    result = op_def_library.apply_op("Polymorphic", a=7, name="p")
    self.assertEqual(dtypes.int32, result.dtype)
    expected_p = """
name: 'p' op: 'Polymorphic' input: 'p/a'
attr { key: 'T' value { type: DT_INT32 } }
"""
    self.assertProtoEquals(expected_p, result.op.node_def)

    result = op_def_library.apply_op("Polymorphic", a="s", name="q")
    self.assertEqual(dtypes.string, result.dtype)
    expected_q = """
name: 'q' op: 'Polymorphic' input: 'q/a'
attr { key: 'T' value { type: DT_STRING } }
"""
    self.assertProtoEquals(expected_q, result.op.node_def)

    string_list = ["s", "t", "u"]
    result = op_def_library.apply_op("Polymorphic", a=string_list, name="r")
    self.assertEqual(dtypes.string, result.dtype)
    expected_r = """
name: 'r' op: 'Polymorphic' input: 'r/a'
attr { key: 'T' value { type: DT_STRING } }
"""
    self.assertProtoEquals(expected_r, result.op.node_def)

    # Passing the inferred attr explicitly is an error.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("Polymorphic", a="s", T=dtypes.string)
    expected_message = (
        "Should not specify value for inferred attr 'T' for "
        "Polymorphic.")
    self.assertEqual(str(raised.exception), expected_message)
def testPolymorphicOut(self):
  """Checks "PolymorphicOut" dtypes and errors for a missing or None T."""
  with ops.Graph().as_default():
    result = op_def_library.apply_op(
        "PolymorphicOut", T=dtypes.int32, name="p")
    self.assertEqual(dtypes.int32, result.dtype)
    expected_p = """
name: 'p' op: 'PolymorphicOut'
attr { key: 'T' value { type: DT_INT32 } }
"""
    self.assertProtoEquals(expected_p, result.op.node_def)

    result = op_def_library.apply_op(
        "PolymorphicOut", T=dtypes.bool, name="q")
    self.assertEqual(dtypes.bool, result.dtype)
    expected_q = """
name: 'q' op: 'PolymorphicOut'
attr { key: 'T' value { type: DT_BOOL } }
"""
    self.assertProtoEquals(expected_q, result.op.node_def)

    # Attr T omitted entirely.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("PolymorphicOut")
    missing_message = "No argument found for attr T for PolymorphicOut"
    self.assertEqual(str(raised.exception), missing_message)

    # Attr T explicitly set to None.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("PolymorphicOut", T=None)
    none_message = "Expected DataType for argument 'T' not None."
    self.assertEqual(str(raised.exception), none_message)
def testPolymorphicDefaultOut(self):
  """Checks that T=None yields a string output and an explicit T is kept."""
  with ops.Graph().as_default():
    defaulted = op_def_library.apply_op(
        "PolymorphicDefaultOut", T=None, name="p")
    self.assertEqual(dtypes.string, defaulted.dtype)
    expected_default = """
name: 'p' op: 'PolymorphicDefaultOut'
attr { key: 'T' value { type: DT_STRING } }
"""
    self.assertProtoEquals(expected_default, defaulted.op.node_def)

    explicit = op_def_library.apply_op(
        "PolymorphicDefaultOut", T=dtypes.bool, name="q")
    self.assertEqual(dtypes.bool, explicit.dtype)
    expected_explicit = """
name: 'q' op: 'PolymorphicDefaultOut'
attr { key: 'T' value { type: DT_BOOL } }
"""
    self.assertProtoEquals(expected_explicit, explicit.op.node_def)
def testBinary(self):
  """Checks "Binary" with matching input types and TypeError on mismatch."""
  with ops.Graph().as_default():
    result = op_def_library.apply_op("Binary", a=8, b=9, name="b")
    self.assertEqual(dtypes.int32, result.dtype)
    expected_ints = """
name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
attr { key: 'T' value { type: DT_INT32 } }
"""
    self.assertProtoEquals(expected_ints, result.op.node_def)

    result = op_def_library.apply_op(
        "Binary", a="left", b="right", name="c")
    self.assertEqual(dtypes.string, result.dtype)
    expected_strings = """
name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
attr { key: 'T' value { type: DT_STRING } }
"""
    self.assertProtoEquals(expected_strings, result.op.node_def)

    # Mixed Python values.
    with self.assertRaises(TypeError):
      op_def_library.apply_op("Binary", a="left", b=12)

    # Mixed tensor dtypes.
    with self.assertRaises(TypeError):
      op_def_library.apply_op(
          "Binary",
          a=self.Tensor(dtypes.string),
          b=self.Tensor(dtypes.int32))
def testRestrict(self):
  """Checks "Restrict" with string and bool inputs and an int32 rejection."""
  with ops.Graph().as_default():
    result = op_def_library.apply_op("Restrict", a="foo", name="g")
    self.assertEqual(dtypes.string, result.dtype)
    expected_g = """
name: 'g' op: 'Restrict' input: 'g/a'
attr { key: 'T' value { type: DT_STRING } }
"""
    self.assertProtoEquals(expected_g, result.op.node_def)

    result = op_def_library.apply_op("Restrict", a=True, name="h")
    self.assertEqual(dtypes.bool, result.dtype)
    expected_h = """
name: 'h' op: 'Restrict' input: 'h/a'
attr { key: 'T' value { type: DT_BOOL } }
"""
    self.assertProtoEquals(expected_h, result.op.node_def)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("Restrict", a=17)
    expected_message = (
        "Value passed to parameter 'a' has DataType int32 "
        "not in list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_message)
def testTypeList(self):
  """Checks "TypeList" with list inputs and errors for invalid `a` values."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("TypeList", a=["foo"], name="z")
    expected_z = """
name: 'z' op: 'TypeList' input: 'z/a_0'
attr { key: 'T' value { list { type: DT_STRING } } }
"""
    self.assertProtoEquals(expected_z, node.node_def)

    node = op_def_library.apply_op("TypeList", a=[True, 12], name="y")
    expected_y = """
name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'
attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }
"""
    self.assertProtoEquals(expected_y, node.node_def)

    node = op_def_library.apply_op("TypeList", a=[], name="empty")
    expected_empty = """
name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }
"""
    self.assertProtoEquals(expected_empty, node.node_def)

    # A non-list value for `a`.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("TypeList", a=17)
    not_a_list_prefix = (
        "Expected list for 'a' "
        "argument to 'TypeList' Op, not ")
    self.assertStartsWith(str(raised.exception), not_a_list_prefix)

    # A list containing None next to a tensor.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "TypeList", a=[self.Tensor(dtypes.int32), None])
    not_convertible_prefix = (
        "Tensors in list passed to 'a' of 'TypeList' Op "
        "have types [int32, <NOT CONVERTIBLE TO TENSOR>]")
    self.assertStartsWith(str(raised.exception), not_convertible_prefix)
def testTypeListTwice(self):
  """Checks "TypeListTwice" with matching lists and a mismatch error."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op(
        "TypeListTwice", a=["foo", True], b=["bar", False], name="z")
    expected_z = """
name: 'z' op: 'TypeListTwice'
input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'
attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
"""
    self.assertProtoEquals(expected_z, node.node_def)

    node = op_def_library.apply_op(
        "TypeListTwice", a=[], b=[], name="empty")
    expected_empty = """
name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }
"""
    self.assertProtoEquals(expected_empty, node.node_def)

    # The type list of `b` differs from the one of `a`.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "TypeListTwice", a=["foo", True], b=["bar", 6])
    expected_message = (
        "Input 'b' of 'TypeListTwice' Op has type list of "
        "string, int32 that does not match type list "
        "string, bool of argument 'a'.")
    self.assertEqual(str(raised.exception), expected_message)
def testOutTypeList(self):
  """Checks "OutTypeList" outputs for one, two and zero entries in attr T."""
  with ops.Graph().as_default():
    (single,) = op_def_library.apply_op(
        "OutTypeList", T=[dtypes.float32], name="x")
    self.assertEqual(dtypes.float32, single.dtype)
    expected_x = """
name: 'x' op: 'OutTypeList'
attr { key: 'T' value { list { type: DT_FLOAT } } }
"""
    self.assertProtoEquals(expected_x, single.op.node_def)

    first, second = op_def_library.apply_op(
        "OutTypeList", T=[dtypes.int32, dtypes.bool], name="w")
    self.assertEqual(dtypes.int32, first.dtype)
    self.assertEqual(dtypes.bool, second.dtype)
    expected_w = """
name: 'w' op: 'OutTypeList'
attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
"""
    self.assertProtoEquals(expected_w, first.op.node_def)

    # An empty type list gives an empty list of outputs.
    no_outputs = op_def_library.apply_op("OutTypeList", T=[], name="empty")
    self.assertEqual([], no_outputs)

    # A bare DType is not accepted where a list is required.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("OutTypeList", T=dtypes.int32)
    expected_message = (
        "Expected list for attr T, obtained "
        "DType instead.")
    self.assertEqual(str(raised.exception), expected_message)
def testTypeListRestrict(self):
  """Checks "TypeListRestrict" with allowed types and an int32 rejection."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op(
        "TypeListRestrict", a=["foo", False], name="v")
    expected_v = """
name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'
attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
"""
    self.assertProtoEquals(expected_v, node.node_def)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("TypeListRestrict", a=[True, 12])
    expected_message = (
        "Value passed to parameter 'a' has DataType int32 "
        "not in list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_message)
def testOutTypeListRestrict(self):
  """Checks "OutTypeListRestrict" outputs and rejection of int32 in `t`."""
  with ops.Graph().as_default():
    first, second = op_def_library.apply_op(
        "OutTypeListRestrict", t=[dtypes.bool, dtypes.string], name="u")
    self.assertEqual(dtypes.bool, first.dtype)
    self.assertEqual(dtypes.string, second.dtype)
    expected_u = """
name: 'u' op: 'OutTypeListRestrict'
attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
"""
    self.assertProtoEquals(expected_u, first.op.node_def)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "OutTypeListRestrict", t=[dtypes.string, dtypes.int32])
    expected_message = (
        "Value passed to parameter 't' has DataType int32 "
        "not in list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_message)
def testAttr(self):
  """Checks the int attr of "Attr" for valid, invalid and missing values."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("Attr", a=12, name="t")
    expected_t = """
name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }
"""
    self.assertProtoEquals(expected_t, node.node_def)

    # A Dimension is accepted and stored as its integer value.
    node = op_def_library.apply_op(
        "Attr", a=tensor_shape.Dimension(13), name="u")
    expected_u = """
name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }
"""
    self.assertProtoEquals(expected_u, node.node_def)

    # Values that are not ints, each with its expected error text.
    rejected = (
        ("bad", "Expected int for argument 'a' not 'bad'."),
        ([12], "Expected int for argument 'a' not [12]."),
        (None, "Expected int for argument 'a' not None."),
    )
    for bad_value, message in rejected:
      with self.assertRaises(TypeError) as raised:
        op_def_library.apply_op("Attr", a=bad_value)
      self.assertEqual(str(raised.exception), message)

    # The attr is required.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("Attr")
    missing_message = (
        "No argument found for attr a for "
        "Attr")
    self.assertEqual(str(raised.exception), missing_message)
def testAttrFloat(self):
  """Checks the float attr of "AttrFloat" with float, int and str values."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("AttrFloat", a=1.2, name="t")
    expected_t = """
name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }
"""
    self.assertProtoEquals(expected_t, node.node_def)

    # An int value ends up in the float field.
    node = op_def_library.apply_op("AttrFloat", a=12, name="u")
    expected_u = """
name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }
"""
    self.assertProtoEquals(expected_u, node.node_def)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("AttrFloat", a="bad")
    expected_message = "Expected float for argument 'a' not 'bad'."
    self.assertEqual(str(raised.exception), expected_message)
def testAttrFunc(self):
  """Checks the func attr of "FuncAttr" with a Defun and with a non-func."""
  with ops.Graph().as_default():

    @function.Defun(dtypes.float32, func_name="MyFn")
    def fn(x):
      return 2 + x

    node = op_def_library.apply_op("FuncAttr", f=fn, name="t")
    expected_t = """
name: 't' op: 'FuncAttr' attr { key: 'f'
value { func { name: 'MyFn' } } }
"""
    self.assertProtoEquals(expected_t, node.node_def)

    # An int cannot be converted to a func attr.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("FuncAttr", f=3)
    expected_message = "Don't know how to convert 3 to a func for argument f"
    self.assertEqual(str(raised.exception), expected_message)
def testAttrFuncWithFuncWithAttrs(self):
  """Checks that a concrete function's name and attrs reach the func attr."""
  with ops.Graph().as_default():
    signature = (tensor.TensorSpec(None, dtypes.float32),)

    @def_function.function(
        input_signature=signature,
        autograph=False,
        experimental_attributes={"_implements": 15})
    def fn(x):
      return 2 + x

    concrete_fn = fn.get_concrete_function()

    node = op_def_library.apply_op("FuncAttr", f=concrete_fn, name="t")
    func_attr = node.node_def.attr["f"].func
    self.assertEqual(15, func_attr.attr["_implements"].i)
    self.assertEqual(
        compat.as_str(concrete_fn.name), node.node_def.attr["f"].func.name)
def testAttrFuncList(self):
  """Checks the list(func) attr of "FuncListAttr" and a non-func element."""
  with ops.Graph().as_default():

    @function.Defun(dtypes.float32, func_name="MyFn")
    def fn1(x):
      return 2 + x

    @function.Defun(dtypes.int32, dtypes.float32, func_name="MyFn2")
    def fn2(x, y):
      return 2 + x, y * 3

    @function.Defun(dtypes.int32, func_name="MyFn3")
    def fn3(y):
      return 2 + y

    all_funcs = [fn1, fn2, fn3]
    node = op_def_library.apply_op("FuncListAttr", f=all_funcs, name="t")
    expected_t = """
name: 't' op: 'FuncListAttr'
attr { key: 'f' value { list { func { name: 'MyFn' }
func { name: 'MyFn2' }
func { name: 'MyFn3' } } } }
"""
    self.assertProtoEquals(expected_t, node.node_def)

    # An int in the middle of the list cannot be converted to a func.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("FuncListAttr", f=[fn1, 3, fn2])
    expected_message = "Don't know how to convert 3 to a func for argument f"
    self.assertEqual(str(raised.exception), expected_message)
def testAttrBool(self):
  """Checks the bool attr of "AttrBool" and rejection of non-bool values."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("AttrBool", a=True, name="t")
    expected_true = """
name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }
"""
    self.assertProtoEquals(expected_true, node.node_def)

    node = op_def_library.apply_op("AttrBool", a=False, name="u")
    expected_false = """
name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }
"""
    self.assertProtoEquals(expected_false, node.node_def)

    # Non-bool values, each paired with its expected error text.
    rejected = (
        (0, "Expected bool for argument 'a' not 0."),
        (1, "Expected bool for argument 'a' not 1."),
        ([], "Expected bool for argument 'a' not []."),
    )
    for bad_value, message in rejected:
      with self.assertRaises(TypeError) as raised:
        op_def_library.apply_op("AttrBool", a=bad_value)
      self.assertEqual(str(raised.exception), message)
def testAttrBoolList(self):
  """Checks the list(bool) attr of "AttrBoolList" and a non-bool element."""
  with ops.Graph().as_default():
    flags = [True, False, True]
    node = op_def_library.apply_op("AttrBoolList", a=flags, name="t")
    expected_t = """
name: 't' op: 'AttrBoolList'
attr { key: 'a' value { list { b: true b: false b:true } } }
"""
    self.assertProtoEquals(expected_t, node.node_def)

    node = op_def_library.apply_op("AttrBoolList", a=[], name="u")
    expected_u = """
name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_u, node.node_def)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("AttrBoolList", a=[0])
    expected_message = "Expected bool for argument 'a' not 0."
    self.assertEqual(str(raised.exception), expected_message)
def testAttrMin(self):
  """Checks "AttrMin" with a valid value and a value below the minimum."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("AttrMin", a=12, name="s")
    expected_s = """
name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }
"""
    self.assertProtoEquals(expected_s, node.node_def)

    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("AttrMin", a=2)
    expected_message = (
        "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")
    self.assertEqual(str(raised.exception), expected_message)
def testAttrListMin(self):
  """Checks "AttrListMin" with a long-enough list and a too-short list."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("AttrListMin", a=[1, 2], name="r")
    expected_r = """
name: 'r' op: 'AttrListMin'
attr { key: 'a' value { list { i: 1 i: 2 } } }
"""
    self.assertProtoEquals(expected_r, node.node_def)

    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("AttrListMin", a=[17])
    expected_message = (
        "Attr 'a' of 'AttrListMin' Op "
        "passed list of length 1 less than minimum 2.")
    self.assertEqual(str(raised.exception), expected_message)
def testAttrEnum(self):
  """Checks "AttrEnum" with an allowed string and a disallowed string."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("AttrEnum", a="oranges", name="e")
    expected_e = """
name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }
"""
    self.assertProtoEquals(expected_e, node.node_def)

    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("AttrEnum", a="invalid")
    expected_message = (
        'Attr \'a\' of \'AttrEnum\' Op '
        'passed string \'invalid\' not in: '
        '"apples", "oranges".')
    self.assertEqual(str(raised.exception), expected_message)
def testAttrEnumList(self):
  """Checks "AttrEnumList" with allowed strings and one disallowed entry."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op(
        "AttrEnumList", a=["oranges", "apples"], name="f")
    expected_f = """
name: 'f' op: 'AttrEnumList'
attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }
"""
    self.assertProtoEquals(expected_f, node.node_def)

    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op(
          "AttrEnumList", a=["apples", "invalid", "oranges"])
    expected_message = (
        'Attr \'a\' of \'AttrEnumList\' Op '
        'passed string \'invalid\' not '
        'in: "apples", "oranges".')
    self.assertEqual(str(raised.exception), expected_message)
def testAttrShape(self):
  """Checks the shape attr of "AttrShape" for several ways to give a shape."""
  with ops.Graph().as_default():
    # Shape given as a list.
    node = op_def_library.apply_op("AttrShape", a=[5], name="s1")
    expected_s1 = """
name: 's1' op: 'AttrShape'
attr { key: 'a' value { shape { dim { size: 5 } } } }
"""
    self.assertProtoEquals(expected_s1, node.node_def)

    # Shape given as a tuple.
    node = op_def_library.apply_op("AttrShape", a=(4, 3, 2), name="s2")
    expected_s2 = """
name: 's2' op: 'AttrShape'
attr { key: 'a' value {
shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }
"""
    self.assertProtoEquals(expected_s2, node.node_def)

    # Shape given as a TensorShape.
    node = op_def_library.apply_op(
        "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")
    expected_s3 = """
name: 's3' op: 'AttrShape'
attr { key: 'a' value {
shape { dim { size: 3 } dim { size: 2 } } } }
"""
    self.assertProtoEquals(expected_s3, node.node_def)

    # Shape given as an empty list.
    node = op_def_library.apply_op("AttrShape", a=[], name="s4")
    expected_s4 = """
name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }
"""
    self.assertProtoEquals(expected_s4, node.node_def)

    # Shape given as a TensorShapeProto.
    shape_proto = tensor_shape_pb2.TensorShapeProto()
    for dim_size in (6, 3):
      shape_proto.dim.add().size = dim_size
    node = op_def_library.apply_op("AttrShape", a=shape_proto, name="s5")
    expected_s5 = """
name: 's5' op: 'AttrShape'
attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }
"""
    self.assertProtoEquals(expected_s5, node.node_def)

    # TODO(josh11b): Re-enable this test once we stop promoting scalars to
    # shapes.
    # with self.assertRaises(TypeError) as cm:
    #   op_def_library.apply_op("AttrShape", a=5)
    # self.assertEqual(str(cm.exception),
    #                  "Don't know how to convert 5 to a TensorShapeProto for"
    #                  " argument 'a'")

    with self.assertRaises(TypeError):
      op_def_library.apply_op("AttrShape", a="ABC")
def testAttrShapeList(self):
  """Checks the list(shape) attr of "AttrShapeList", including an empty list."""
  with ops.Graph().as_default():
    shapes = [[3, 2], [6, 5, 4]]
    node = op_def_library.apply_op("AttrShapeList", a=shapes, name="sl")
    expected_sl = """
name: 'sl' op: 'AttrShapeList'
attr { key: 'a' value { list {
shape { dim { size: 3 } dim { size: 2 } }
shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }
"""
    self.assertProtoEquals(expected_sl, node.node_def)

    node = op_def_library.apply_op("AttrShapeList", a=[], name="esl")
    expected_esl = """
name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_esl, node.node_def)
def testAttrPartialShape(self):
  """Checks "AttrPartialShape", where unknown dims are expected as -1."""
  with ops.Graph().as_default():
    # Fully known shape given as a list.
    node = op_def_library.apply_op("AttrPartialShape", a=[5], name="s1")
    expected_s1 = """
name: 's1' op: 'AttrPartialShape'
attr { key: 'a' value { shape { dim { size: 5 } } } }
"""
    self.assertProtoEquals(expected_s1, node.node_def)

    # Tuple with a None dimension.
    node = op_def_library.apply_op(
        "AttrPartialShape", a=(4, None, 2), name="s2")
    expected_s2 = """
name: 's2' op: 'AttrPartialShape'
attr { key: 'a' value {
shape { dim { size: 4 } dim { size: -1 } dim { size: 2 } } } }
"""
    self.assertProtoEquals(expected_s2, node.node_def)

    # TensorShape with a None dimension.
    partial_shape = tensor_shape.TensorShape([3, None])
    node = op_def_library.apply_op(
        "AttrPartialShape", a=partial_shape, name="s3")
    expected_s3 = """
name: 's3' op: 'AttrPartialShape'
attr { key: 'a' value {
shape { dim { size: 3 } dim { size: -1 } } } }
"""
    self.assertProtoEquals(expected_s3, node.node_def)

    # Empty list.
    node = op_def_library.apply_op("AttrPartialShape", a=[], name="s4")
    expected_s4 = """
name: 's4' op: 'AttrPartialShape'
attr { key: 'a' value { shape { } } }
"""
    self.assertProtoEquals(expected_s4, node.node_def)

    # TensorShapeProto whose first dimension has size -1.
    shape_proto = tensor_shape_pb2.TensorShapeProto()
    for dim_size in (-1, 3):
      shape_proto.dim.add().size = dim_size
    node = op_def_library.apply_op(
        "AttrPartialShape", a=shape_proto, name="s5")
    expected_s5 = """
name: 's5' op: 'AttrPartialShape'
attr { key: 'a' value {
shape { dim { size: -1 } dim { size: 3 } } } }
"""
    self.assertProtoEquals(expected_s5, node.node_def)

    # TODO(ebrevdo): Re-enable once we stop promoting scalars to shapes.
    # with self.assertRaises(TypeError) as cm:
    #   op_def_library.apply_op("AttrPartialShape", a=5)
    # self.assertEqual(str(cm.exception),
    #                  "Don't know how to convert 5 to a TensorShapeProto for"
    #                  " argument 'a'")

    with self.assertRaises(TypeError):
      op_def_library.apply_op("AttrPartialShape", a="ABC")
def testAttrPartialShapeList(self):
  """Checks "AttrPartialShapeList" with a None dim and with an empty list."""
  with ops.Graph().as_default():
    shapes = [[3, 2], [6, None, 4]]
    node = op_def_library.apply_op(
        "AttrPartialShapeList", a=shapes, name="sl")
    expected_sl = """
name: 'sl' op: 'AttrPartialShapeList'
attr { key: 'a' value { list {
shape { dim { size: 3 } dim { size: 2 } }
shape { dim { size: 6 } dim { size: -1 } dim { size: 4 } } } } }
"""
    self.assertProtoEquals(expected_sl, node.node_def)

    node = op_def_library.apply_op(
        "AttrPartialShapeList", a=[], name="esl")
    expected_esl = """
name: 'esl' op: 'AttrPartialShapeList' attr {
key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_esl, node.node_def)
def testAttrDefault(self):
  """Checks that a=None gives 'banana' and an explicit string is kept."""
  with ops.Graph().as_default():
    defaulted = op_def_library.apply_op("AttrDefault", a=None, name="d")
    expected_d = """
name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }
"""
    self.assertProtoEquals(expected_d, defaulted.node_def)

    explicit = op_def_library.apply_op("AttrDefault", a="kiwi", name="c")
    expected_c = """
name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }
"""
    self.assertProtoEquals(expected_c, explicit.node_def)
def testAttrListDefault(self):
  """Checks "AttrListDefault" for a=None, a one-item list and an empty list."""
  with ops.Graph().as_default():
    # None is expected to produce the list [5, 15].
    node = op_def_library.apply_op("AttrListDefault", a=None, name="b")
    expected_b = """
name: 'b' op: 'AttrListDefault'
attr { key: 'a' value { list { i: 5 i: 15 } } }
"""
    self.assertProtoEquals(expected_b, node.node_def)

    node = op_def_library.apply_op("AttrListDefault", a=[3], name="a")
    expected_a = """
name: 'a' op: 'AttrListDefault'
attr { key: 'a' value { list { i: 3 } } }
"""
    self.assertProtoEquals(expected_a, node.node_def)

    # An explicit empty list is expected to stay empty.
    node = op_def_library.apply_op("AttrListDefault", a=[], name="empty")
    expected_empty = """
name: 'empty' op: 'AttrListDefault'
attr { key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_empty, node.node_def)
def testAttrEmptyListDefault(self):
  """Checks "AttrEmptyListDefault" for a=None, a one-item and an empty list."""
  with ops.Graph().as_default():
    # None is expected to produce an empty list.
    node = op_def_library.apply_op(
        "AttrEmptyListDefault", a=None, name="b")
    expected_b = """
name: 'b' op: 'AttrEmptyListDefault'
attr { key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_b, node.node_def)

    # The int 3 is expected in the float field of the list.
    node = op_def_library.apply_op("AttrEmptyListDefault", a=[3], name="a")
    expected_a = """
name: 'a' op: 'AttrEmptyListDefault'
attr { key: 'a' value { list { f: 3 } } }
"""
    self.assertProtoEquals(expected_a, node.node_def)

    node = op_def_library.apply_op(
        "AttrEmptyListDefault", a=[], name="empty")
    expected_empty = """
name: 'empty' op: 'AttrEmptyListDefault'
attr { key: 'a' value { list { } } }
"""
    self.assertProtoEquals(expected_empty, node.node_def)
def testReservedAttr(self):
  """Checks that the `range_` keyword sets the attr stored under 'range'."""
  with ops.Graph().as_default():
    created = op_def_library.apply_op("ReservedAttr", range_=7, name="x")
    expected = """
name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
"""
    self.assertProtoEquals(expected, created.node_def)
def testDefaultAttrType(self):
  """Checks attr T of "AttrTypeDefault" for an empty and a float list."""
  with ops.Graph().as_default():
    # An empty list has no obvious type; DT_INT32 is expected for T.
    node = op_def_library.apply_op("AttrTypeDefault", a=[], name="n")
    expected_n = """
name: 'n' op: 'AttrTypeDefault' input: 'n/a'
attr { key: 'T' value { type: DT_INT32 } }
"""
    self.assertProtoEquals(expected_n, node.node_def)

    # A float list has an inferable type that differs from the default.
    node = op_def_library.apply_op("AttrTypeDefault", a=[1.0], name="f")
    expected_f = """
name: 'f' op: 'AttrTypeDefault' input: 'f/a'
attr { key: 'T' value { type: DT_FLOAT } }
"""
    self.assertProtoEquals(expected_f, node.node_def)
def testDefaultListAttrType(self):
  """Checks attrs T and N of "AttrListTypeDefault" for float list inputs."""
  with ops.Graph().as_default():
    # The inputs have an inferable type that differs from the default.
    node = op_def_library.apply_op(
        "AttrListTypeDefault", a=[1.0], b=[2.0], name="n")
    expected_n = """
name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 1 } }
"""
    self.assertProtoEquals(expected_n, node.node_def)
def testNIntsIn(self):
  """Checks "NIntsIn" with int lists and its type, length and list errors."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("NIntsIn", a=[1, 2], name="n")
    expected_n = """
name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_n, node.node_def)

    node = op_def_library.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")
    expected_o = """
name: 'o' op: 'NIntsIn'
input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
attr { key: 'N' value { i: 5 } }
"""
    self.assertProtoEquals(expected_o, node.node_def)

    # All-string lists: the same message for Python values and for tensors.
    all_strings_message = (
        "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
        "[string, string] that do not match expected type int32.")
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NIntsIn", a=["foo", "bar"])
    self.assertEqual(str(raised.exception), all_strings_message)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "NIntsIn",
          a=[self.Tensor(dtypes.string),
             self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), all_strings_message)

    # A list shorter than the minimum length.
    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("NIntsIn", a=[99])
    too_short_message = (
        "List argument 'a' to 'NIntsIn' Op "
        "with length 1 shorter than "
        "minimum length 2.")
    self.assertEqual(str(raised.exception), too_short_message)

    # Mixed int/string lists, again as Python values and as tensors.
    mixed_message = (
        "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
        "[int32, string] that do not match expected type int32.")
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NIntsIn", a=[38, "bar"])
    self.assertEqual(str(raised.exception), mixed_message)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "NIntsIn",
          a=[self.Tensor(dtypes.int32),
             self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), mixed_message)

    # A non-list value for `a`.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NIntsIn", a=17)
    not_a_list_prefix = (
        "Expected list for 'a' argument "
        "to 'NIntsIn' Op, not ")
    self.assertStartsWith(str(raised.exception), not_a_list_prefix)
def testNPolymorphicIn(self):
  """Checks "NPolymorphicIn" with homogeneous lists and its error cases."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op("NPolymorphicIn", a=[1, 2], name="n")
    expected_n = """
name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_n, node.node_def)

    node = op_def_library.apply_op(
        "NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")
    expected_o = """
name: 'o' op: 'NPolymorphicIn'
input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 5 } }
"""
    self.assertProtoEquals(expected_o, node.node_def)

    node = op_def_library.apply_op(
        "NPolymorphicIn", a=["foo", "bar"], name="p")
    expected_p = """
name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'
attr { key: 'T' value { type: DT_STRING } }
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_p, node.node_def)

    # A Python int next to a float32 tensor: DT_FLOAT is expected for T.
    node = op_def_library.apply_op(
        "NPolymorphicIn",
        a=[1, self.Tensor(dtypes.float32, name="x")],
        name="q")
    expected_q = """
name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_q, node.node_def)

    # A float32 tensor next to a float32_ref tensor.
    node = op_def_library.apply_op(
        "NPolymorphicIn",
        a=[
            self.Tensor(dtypes.float32, name="y"),
            self.Tensor(dtypes.float32_ref, name="z")
        ],
        name="r")
    expected_r = """
name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_r, node.node_def)

    # A list shorter than the minimum length.
    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("NPolymorphicIn", a=[99])
    too_short_message = (
        "List argument 'a' to 'NPolymorphicIn' Op with length 1 "
        "shorter than minimum length 2.")
    self.assertEqual(str(raised.exception), too_short_message)

    # int32 followed by string, as Python values and with a tensor.
    int_then_string_message = (
        "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
        "have types [int32, string] that don't all match.")
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NPolymorphicIn", a=[38, "bar"])
    self.assertEqual(str(raised.exception), int_then_string_message)

    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "NPolymorphicIn", a=[38, self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), int_then_string_message)

    # None cannot be converted to a tensor.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NPolymorphicIn", a=[38, None])
    not_convertible_message = (
        "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
        "have types [int32, <NOT CONVERTIBLE TO TENSOR>] that "
        "don't all match.")
    self.assertEqual(str(raised.exception), not_convertible_message)

    # string followed by an int32 tensor.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op(
          "NPolymorphicIn", a=["abcd", self.Tensor(dtypes.int32)])
    string_then_int_message = (
        "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
        "have types [string, int32] that don't all match.")
    self.assertEqual(str(raised.exception), string_then_int_message)

    # A non-list value for `a`.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NPolymorphicIn", a=17)
    not_a_list_prefix = (
        "Expected list for 'a' argument "
        "to 'NPolymorphicIn' Op, not ")
    self.assertStartsWith(str(raised.exception), not_a_list_prefix)
def testNPolymorphicRestrictIn(self):
  """Checks "NPolymorphicRestrictIn" with string/bool lists and int32."""
  with ops.Graph().as_default():
    node = op_def_library.apply_op(
        "NPolymorphicRestrictIn", a=["foo", "bar"], name="p")
    expected_p = """
name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'
attr { key: 'T' value { type: DT_STRING } }
attr { key: 'N' value { i: 2 } }
"""
    self.assertProtoEquals(expected_p, node.node_def)

    node = op_def_library.apply_op(
        "NPolymorphicRestrictIn", a=[False, True, False], name="b")
    expected_b = """
name: 'b' op: 'NPolymorphicRestrictIn'
input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'
attr { key: 'T' value { type: DT_BOOL } }
attr { key: 'N' value { i: 3 } }
"""
    self.assertProtoEquals(expected_b, node.node_def)

    # int32 is not among the allowed types.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NPolymorphicRestrictIn", a=[1, 2])
    expected_message = (
        "Value passed to parameter 'a' has DataType int32 not in "
        "list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_message)
def testNInTwice(self):
  """Exercises 'NInTwice': two list inputs whose lengths must agree."""
  with ops.Graph().as_default():
    # Matching lengths of two: the expected NodeDef records N=2.
    two_each = op_def_library.apply_op(
        "NInTwice", a=[1, 2], b=["one", "two"], name="n")
    self.assertProtoEquals("""
name: 'n' op: 'NInTwice'
input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
attr { key: 'N' value { i: 2 } }
""", two_each.node_def)

    # Both lists empty: the expected NodeDef has no inputs and N=0.
    none_each = op_def_library.apply_op("NInTwice", a=[], b=[], name="o")
    self.assertProtoEquals("""
name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }
""", none_each.node_def)

    # Mismatched lengths are expected to raise ValueError.
    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
    actual_message = str(raised.exception)
    self.assertEqual(actual_message,
                     "List argument 'b' to 'NInTwice' Op "
                     "with length 1 must match "
                     "length 3 of argument 'a'.")
def testNInPolymorphicTwice(self):
  """Exercises 'NInPolymorphicTwice': two lists sharing length and type."""
  with ops.Graph().as_default():
    # Two int lists of length two: expected NodeDef has T=DT_INT32, N=2.
    expected_node = """
name: 'n' op: 'NInPolymorphicTwice'
input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 2 } }
"""
    created = op_def_library.apply_op(
        "NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")
    self.assertProtoEquals(expected_node, created.node_def)

    # Length mismatch between 'a' and 'b' is expected to be a ValueError.
    with self.assertRaises(ValueError) as length_error:
      op_def_library.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
    self.assertEqual(str(length_error.exception),
                     "List argument 'b' to 'NInPolymorphicTwice' Op "
                     "with length 1 "
                     "must match length 3 of argument 'a'.")

    # Python values in 'b' whose type disagrees with 'a' -> TypeError.
    with self.assertRaises(TypeError) as value_type_error:
      op_def_library.apply_op(
          "NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
    self.assertEqual(str(value_type_error.exception),
                     "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
                     "Op have types [string, string] that do not match type "
                     "int32 inferred from earlier arguments.")

    # Same disagreement, this time with values produced by self.Tensor.
    with self.assertRaises(TypeError) as tensor_type_error:
      op_def_library.apply_op(
          "NInPolymorphicTwice",
          a=[self.Tensor(dtypes.int32)],
          b=[self.Tensor(dtypes.string)])
    self.assertEqual(str(tensor_type_error.exception),
                     "Tensors in list passed to 'b' of "
                     "'NInPolymorphicTwice' Op have types [string] that do "
                     "not match type int32 inferred from earlier arguments.")
def testNInTwoTypeVariables(self):
  """Exercises 'NInTwoTypeVariables': lists with independent types S, T."""
  with ops.Graph().as_default():
    # Ints for 'a' and bools for 'b': expected S=DT_INT32, T=DT_BOOL.
    int_bool = op_def_library.apply_op(
        "NInTwoTypeVariables", a=[1, 2], b=[True, False], name="n")
    self.assertProtoEquals("""
name: 'n' op: 'NInTwoTypeVariables'
input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
attr { key: 'S' value { type: DT_INT32 } }
attr { key: 'T' value { type: DT_BOOL } }
attr { key: 'N' value { i: 2 } }
""", int_bool.node_def)

    # Ints for both lists: expected S and T are both DT_INT32.
    int_int = op_def_library.apply_op(
        "NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")
    self.assertProtoEquals("""
name: 'o' op: 'NInTwoTypeVariables'
input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
attr { key: 'S' value { type: DT_INT32 } }
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 1 } }
""".replace("i: 1", "i: 2"), int_int.node_def) if False else (
    self.assertProtoEquals("""
name: 'o' op: 'NInTwoTypeVariables'
input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
attr { key: 'S' value { type: DT_INT32 } }
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 2 } }
""", int_int.node_def))

    # Values from self.Tensor named 'q' and 'r' appear directly as inputs.
    q_input = self.Tensor(dtypes.int32, name="q")
    r_input = self.Tensor(dtypes.string, name="r")
    from_tensors = op_def_library.apply_op(
        "NInTwoTypeVariables", a=[q_input], b=[r_input], name="p")
    self.assertProtoEquals("""
name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
attr { key: 'S' value { type: DT_INT32 } }
attr { key: 'T' value { type: DT_STRING } }
attr { key: 'N' value { i: 1 } }
""", from_tensors.node_def)

    # The two lists must still have the same length.
    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
    self.assertEqual(str(raised.exception),
                     "List argument 'b' to 'NInTwoTypeVariables' Op "
                     "with length 1 "
                     "must match length 3 of argument 'a'.")
def testInPolymorphicTwice(self):
  """Exercises 'InPolymorphicTwice': one type T, separate lengths N and M."""
  with ops.Graph().as_default():
    # One element in 'a', three in 'b'.
    expected_n = """
name: 'n' op: 'InPolymorphicTwice'
input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 1 } }
attr { key: 'M' value { i: 3 } }
"""
    node_n = op_def_library.apply_op(
        "InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")
    self.assertProtoEquals(expected_n, node_n.node_def)

    # 'b' empty: the type is still expected to come from 'a'.
    expected_o = """
name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 1 } }
attr { key: 'M' value { i: 0 } }
"""
    node_o = op_def_library.apply_op(
        "InPolymorphicTwice", a=[8], b=[], name="o")
    self.assertProtoEquals(expected_o, node_o.node_def)

    # 'a' empty with int values in 'b'.
    expected_p = """
name: 'p' op: 'InPolymorphicTwice' input: 'p/b_0' input: 'p/b_1'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 2 } }
"""
    node_p = op_def_library.apply_op(
        "InPolymorphicTwice", a=[], b=[3, 4], name="p")
    self.assertProtoEquals(expected_p, node_p.node_def)

    # 'a' empty with float values in 'b': expected T is DT_FLOAT.
    expected_q = """
name: 'q' op: 'InPolymorphicTwice' input: 'q/b_0' input: 'q/b_1'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 2 } }
"""
    node_q = op_def_library.apply_op(
        "InPolymorphicTwice", a=[], b=[3.0, 4.0], name="q")
    self.assertProtoEquals(expected_q, node_q.node_def)

    # Empty input lists: assume default type for T.
    expected_r = """
name: 'r' op: 'InPolymorphicTwice'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 0 } }
"""
    node_r = op_def_library.apply_op(
        "InPolymorphicTwice", a=[], b=[], name="r")
    self.assertProtoEquals(expected_r, node_r.node_def)

    # Python values in 'b' that disagree with the type of 'a' -> TypeError.
    with self.assertRaises(TypeError) as value_type_error:
      op_def_library.apply_op(
          "InPolymorphicTwice", a=[1, 2], b=["one", "two"])
    self.assertEqual(
        str(value_type_error.exception),
        "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
        "have types [string, string] that do not match type int32 "
        "inferred from earlier arguments.")

    # Same disagreement, this time with values produced by self.Tensor.
    with self.assertRaises(TypeError) as tensor_type_error:
      op_def_library.apply_op(
          "InPolymorphicTwice",
          a=[self.Tensor(dtypes.int32)],
          b=[self.Tensor(dtypes.string)])
    self.assertEqual(str(tensor_type_error.exception),
                     "Tensors in list passed to 'b' of 'InPolymorphicTwice' "
                     "Op have types [string] that do not match type int32 "
                     "inferred from earlier arguments.")
def testNIntsOut(self):
  """Exercises 'NIntsOut': N int32 outputs, with a minimum of two."""
  with ops.Graph().as_default():
    # N=2 unpacks into exactly two int32 outputs.
    first, second = op_def_library.apply_op("NIntsOut", N=2, name="n")
    for output in (first, second):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
""", first.op.node_def)

    # N=5 unpacks into exactly five int32 outputs.
    first, second, third, fourth, fifth = op_def_library.apply_op(
        "NIntsOut", N=5, name="o")
    for output in (first, second, third, fourth, fifth):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
""", fifth.op.node_def)

    # N below the minimum is expected to raise ValueError.
    with self.assertRaises(ValueError) as too_small:
      op_def_library.apply_op("NIntsOut", N=1)
    self.assertEqual(
        str(too_small.exception),
        "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")

    # A list instead of an int for N is expected to raise TypeError.
    with self.assertRaises(TypeError) as not_an_int:
      op_def_library.apply_op("NIntsOut", N=[3])
    self.assertEqual(str(not_an_int.exception),
                     "Expected int for argument 'N' not [3].")
def testNIntsOutDefault(self):
  """Exercises 'NIntsOutDefault' with N omitted (None) and N given."""
  with ops.Graph().as_default():
    # N=None: the test expects three outputs and N recorded as 3.
    first, second, third = op_def_library.apply_op(
        "NIntsOutDefault", N=None, name="z")
    for output in (first, second, third):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
""", first.op.node_def)

    # Explicit N=2: two outputs and N recorded as 2.
    first, second = op_def_library.apply_op("NIntsOutDefault", N=2, name="y")
    for output in (first, second):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
""", second.op.node_def)
def testNPolymorphicOut(self):
  """Exercises 'NPolymorphicOut': N outputs of an explicitly given type T."""
  with ops.Graph().as_default():
    # Two int32 outputs.
    first, second = op_def_library.apply_op(
        "NPolymorphicOut", N=2, T=dtypes.int32, name="n")
    for output in (first, second):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'n' op: 'NPolymorphicOut'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 2 } }
""", first.op.node_def)

    # Three string outputs.
    first, second, third = op_def_library.apply_op(
        "NPolymorphicOut", T=dtypes.string, N=3, name="o")
    for output in (first, second, third):
      self.assertEqual(dtypes.string, output.dtype)
    self.assertProtoEquals("""
name: 'o' op: 'NPolymorphicOut'
attr { key: 'T' value { type: DT_STRING } }
attr { key: 'N' value { i: 3 } }
""", third.op.node_def)

    # N below the minimum is expected to raise ValueError.
    with self.assertRaises(ValueError) as too_small:
      op_def_library.apply_op("NPolymorphicOut", N=1, T=dtypes.string)
    self.assertEqual(str(too_small.exception),
                     "Attr 'N' of 'NPolymorphicOut' Op "
                     "passed 1 less than minimum 2.")

    # A list instead of a single DataType for T is expected to raise
    # TypeError.
    with self.assertRaises(TypeError) as not_a_dtype:
      op_def_library.apply_op("NPolymorphicOut", N=3, T=[dtypes.string])
    self.assertEqual(
        str(not_a_dtype.exception),
        "Expected DataType for argument 'T' not [tf.string].")
def testNPolymorphicOutDefault(self):
  """Exercises 'NPolymorphicOutDefault' with N and/or T left as None."""
  with ops.Graph().as_default():
    # Neither attr given: the test expects two bool outputs.
    first, second = op_def_library.apply_op(
        "NPolymorphicOutDefault", N=None, T=None, name="r")
    for output in (first, second):
      self.assertEqual(dtypes.bool, output.dtype)
    self.assertProtoEquals("""
name: 'r' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_BOOL } }
attr { key: 'N' value { i: 2 } }
""", first.op.node_def)

    # Only N given: three outputs, type still expected to be bool.
    first, second, third = op_def_library.apply_op(
        "NPolymorphicOutDefault", N=3, T=None, name="s")
    for output in (first, second, third):
      self.assertEqual(dtypes.bool, output.dtype)
    self.assertProtoEquals("""
name: 's' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_BOOL } }
attr { key: 'N' value { i: 3 } }
""", first.op.node_def)

    # Only T given: int32 type, count still expected to be two.
    first, second = op_def_library.apply_op(
        "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t")
    for output in (first, second):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 't' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 2 } }
""", first.op.node_def)

    # Both attrs given explicitly.
    first, second, third = op_def_library.apply_op(
        "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u")
    for output in (first, second, third):
      self.assertEqual(dtypes.int32, output.dtype)
    self.assertProtoEquals("""
name: 'u' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 3 } }
""", first.op.node_def)
def testNPolymorphicRestrictOut(self):
  """Exercises 'NPolymorphicRestrictOut' with allowed and disallowed T."""
  with ops.Graph().as_default():
    # bool is accepted: three bool outputs.
    first, second, third = op_def_library.apply_op(
        "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u")
    for output in (first, second, third):
      self.assertEqual(dtypes.bool, output.dtype)
    self.assertProtoEquals("""
name: 'u' op: 'NPolymorphicRestrictOut'
attr { key: 'T' value { type: DT_BOOL } }
attr { key: 'N' value { i: 3 } }
""", first.op.node_def)

    # int32 is expected to be rejected with a TypeError listing the
    # allowed types.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32)
    actual_message = str(raised.exception)
    self.assertEqual(actual_message,
                     "Value passed to parameter 'T' has DataType int32 "
                     "not in list of allowed values: string, bool")
def testRef(self):
  """Exercises ops that produce and consume reference-typed tensors."""
  with ops.Graph().as_default():
    # 'RefOut' yields a ref dtype while the attr records the base type.
    bool_ref = op_def_library.apply_op("RefOut", T=dtypes.bool, name="o")
    self.assertEqual(dtypes.bool_ref, bool_ref.dtype)
    self.assertProtoEquals("""
name: 'o' op: 'RefOut'
attr { key: 'T' value { type: DT_BOOL } }
""", bool_ref.op.node_def)

    # Feeding the ref into 'RefIn' is expected to add a "_class" attr
    # pointing at the producer.
    ref_consumer = op_def_library.apply_op("RefIn", a=bool_ref, name="i")
    self.assertProtoEquals("""
name: 'i' op: 'RefIn' input: 'o'
attr { key: 'T' value { type: DT_BOOL } }
attr { key: "_class" value { list { s: "loc:@o" } } }
""", ref_consumer.node_def)

    # Can pass ref to non-ref input.
    int_ref = op_def_library.apply_op("RefOut", T=dtypes.int32, name="r")
    simple_out = op_def_library.apply_op("Simple", a=int_ref, name="s")
    self.assertProtoEquals("""
name: 's' op: 'Simple' input: 'r'
""", simple_out.op.node_def)

    # Can't pass non-ref to ref input.
    with self.assertRaises(TypeError) as raised:
      op_def_library.apply_op("RefIn", a=2)
    self.assertEqual(
        str(raised.exception),
        "'RefIn' Op requires that input 'a' be a mutable tensor "
        "(e.g.: a tf.Variable)")

    ref_t = op_def_library.apply_op("RefOut", T=dtypes.int32, name="t")
    ref_u = op_def_library.apply_op("RefOut", T=dtypes.int32, name="u")
    two_refs = op_def_library.apply_op(
        "TwoRefsIn", a=ref_t, b=ref_u, name="v")
    # NOTE(mrry): The order of colocation constraints is an implementation
    # detail.
    self.assertProtoEquals("""
name: 'v' op: 'TwoRefsIn' input: 't' input: 'u'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: "_class" value { list { s: "loc:@t" s: "loc:@u" } } }
""", two_refs.node_def)
def testSpecifyDevice(self):
  """Checks that a device scope is applied to every node that is created."""
  device_name = "/job:ADevice"
  test_graph = ops.Graph()
  with test_graph.as_default():
    with test_graph.device(device_name):
      op_def_library.apply_op("Simple", a=3)
    # Inspect the entire graph rather than just the returned op, so that
    # the Const node built for the Python input is verified as well.
    all_nodes = test_graph.as_graph_def().node
    self.assertEqual(len(all_nodes), 2)
    for node_def in all_nodes:
      self.assertDeviceEqual(node_def.device, device_name)
def testStructuredOutputSingleList(self):
  """Checks 'SimpleStruct' returns a list of length n_a."""
  list_lengths = (0, 1, 3)
  with ops.Graph().as_default():
    for n_a in list_lengths:
      outputs = op_def_library.apply_op("SimpleStruct", n_a=n_a)
      self.assertIsInstance(outputs, list)
      self.assertEqual(n_a, len(outputs))
def testStructuredOutputListAndSingle(self):
  """Checks 'MixedStruct' returns an int32 list plus one float32 tensor."""
  list_lengths = (0, 1, 3)
  with ops.Graph().as_default():
    for n_a in list_lengths:
      int_list, float_out = op_def_library.apply_op("MixedStruct", n_a=n_a)
      # First output: a list of n_a int32 values.
      self.assertIsInstance(int_list, list)
      self.assertEqual(n_a, len(int_list))
      element_is_int32 = [item.dtype == dtypes.int32 for item in int_list]
      self.assertTrue(all(element_is_int32))
      # Second output: a single float32 tensor.
      self.assertIsInstance(float_out, tensor.Tensor)
      self.assertEqual(dtypes.float32, float_out.dtype)
def testStructuredOutputMultipleLists(self):
  """Checks 'ComplexStruct' returns three lists shaped by n_a, n_b, t_c."""
  list_lengths = (0, 1, 3)
  type_lists = ([], [dtypes.int32], [dtypes.int32, dtypes.float32])
  with ops.Graph().as_default():
    # Try every combination of the two lengths and the type list.
    for n_a in list_lengths:
      for n_b in list_lengths:
        for t_c in type_lists:
          int32_list, int64_list, typed_list = op_def_library.apply_op(
              "ComplexStruct", n_a=n_a, n_b=n_b, t_c=t_c)
          # First output: n_a int32 values.
          self.assertEqual(n_a, len(int32_list))
          self.assertTrue(
              all(item.dtype == dtypes.int32 for item in int32_list))
          # Second output: n_b int64 values.
          self.assertEqual(n_b, len(int64_list))
          self.assertTrue(
              all(item.dtype == dtypes.int64 for item in int64_list))
          # Third output: dtypes match t_c element by element.
          typed_dtypes = [item.dtype for item in typed_list]
          self.assertEqual(t_c, typed_dtypes)
@test_util.add_graph_building_optimization_tests
class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
  """Tests for which graph the ops created by apply_op end up in."""

  def testNoGraph(self):
    """Without an explicit graph scope the default graph is used."""
    result = op_def_library.apply_op("Simple", a=3)
    self.assertEqual(result.graph, ops.get_default_graph())

  def testDefaultGraph(self):
    """Inside `as_default()` the op lands in that graph."""
    target_graph = ops.Graph()
    with target_graph.as_default():
      result = op_def_library.apply_op("Simple", a=3)
      self.assertEqual(result.graph, target_graph)

  def testDifferentGraphFails(self):
    """Inputs built in two separate graphs cannot be combined."""
    with ops.Graph().as_default():
      from_first_graph = op_def_library.apply_op("Simple", a=3)
    with ops.Graph().as_default():
      from_second_graph = op_def_library.apply_op("Simple", a=4)
    with self.assertRaises(ValueError) as raised:
      op_def_library.apply_op(
          "Binary", a=from_first_graph, b=from_second_graph)
    error_text = str(raised.exception)
    self.assertIn("must be from the same graph", error_text)
class OpDefLibraryPybindTest(test_util.TensorFlowTestCase):
  """Tests for the pybind entry point `process_inputs`."""

  def testPybind(self):
    """Checks the four values returned for an 'AddV2' with float inputs."""
    x = constant_op.constant(32, dtype=dtypes.float32)
    y = constant_op.constant(32, dtype=dtypes.float32)
    keyword_inputs = {"x": x, "y": y}
    processed = op_def_library_pybind.process_inputs(
        "AddV2", 1, keyword_inputs)
    attrs, inputs, input_types, output_structure = processed

    # The only attr expected is T, set to the inputs' float type.
    expected_t = text_format.Parse(
        "type: DT_FLOAT", attr_value_pb2.AttrValue())
    self.assertEqual(attrs, {"T": expected_t})
    self.assertEqual(inputs, [x, y])
    self.assertEqual(input_types, [dtypes.float32, dtypes.float32])
    self.assertEqual(output_structure, [None])
# Script entry point: hand control to the googletest runner only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
  googletest.main()