# tensorflow/tensorflow (View on GitHub)
# File: tensorflow/python/framework/dtypes_test.py
#
# Page summary residue from the hosting site: Maintainability F (2 wks);
# Test Coverage.
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.dtypes."""

from absl.testing import parameterized
import numpy as np

# pylint: disable=g-bad-import-order
from tensorflow.python.framework import _dtypes
# pylint: enable=g-bad-import-order

from tensorflow.core.framework import types_pb2
from tensorflow.core.function import trace_type
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


def _is_numeric_dtype_enum(datatype_enum):
  """Returns whether this test module treats `datatype_enum` as numeric.

  DT_INVALID plus the variant and resource handle enums (and their `_REF`
  forms) are reported as non-numeric; every other enum value, including
  DT_STRING, is reported as numeric.

  Args:
    datatype_enum: A `types_pb2.DataType` enum value.

  Returns:
    False if the enum is one of the excluded handle/invalid types, else True.
  """
  excluded_enums = (
      types_pb2.DT_INVALID,
      types_pb2.DT_VARIANT,
      types_pb2.DT_VARIANT_REF,
      types_pb2.DT_RESOURCE,
      types_pb2.DT_RESOURCE_REF,
  )
  is_excluded = datatype_enum in excluded_enums
  return not is_excluded


class TypesTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests `DType` construction, conversion, predicates and value limits."""

  def testAllTypesConstructible(self):
    """Every valid DataType enum round-trips through the `DType` constructor."""
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      self.assertEqual(datatype_enum,
                       dtypes.DType(datatype_enum).as_datatype_enum)

  def testAllTypesConvertibleToDType(self):
    """Every valid DataType enum round-trips through `dtypes.as_dtype`."""
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dt = dtypes.as_dtype(datatype_enum)
      self.assertEqual(datatype_enum, dt.as_datatype_enum)

  def testAllTypesConvertibleToNumpyDtype(self):
    """Numeric dtypes expose a numpy type that numpy accepts and maps back."""
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype
      # Allocating an array checks that numpy accepts the type at all.
      _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
      if dtype.base_dtype != dtypes.bfloat16:
        # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
        self.assertEqual(dtype.base_dtype, dtypes.as_dtype(numpy_dtype))

  def testAllPybind11DTypeConvertibleToDType(self):
    """The pybind11 `_dtypes.DType` compares equal to the Python `DType`."""
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = _dtypes.DType(datatype_enum)
      self.assertEqual(dtypes.as_dtype(datatype_enum), dtype)

  def testInvalid(self):
    """DT_INVALID is rejected by both the constructor and `as_dtype`."""
    with self.assertRaises(TypeError):
      dtypes.DType(types_pb2.DT_INVALID)
    with self.assertRaises(TypeError):
      dtypes.as_dtype(types_pb2.DT_INVALID)

  def testNumpyConversion(self):
    """numpy types, dtypes and objects with a `dtype` attribute convert."""
    self.assertIs(dtypes.float32, dtypes.as_dtype(np.float32))
    self.assertIs(dtypes.float64, dtypes.as_dtype(np.float64))
    self.assertIs(dtypes.int32, dtypes.as_dtype(np.int32))
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.int64))
    self.assertIs(dtypes.uint8, dtypes.as_dtype(np.uint8))
    self.assertIs(dtypes.uint16, dtypes.as_dtype(np.uint16))
    self.assertIs(dtypes.int16, dtypes.as_dtype(np.int16))
    self.assertIs(dtypes.int8, dtypes.as_dtype(np.int8))
    self.assertIs(dtypes.complex64, dtypes.as_dtype(np.complex64))
    self.assertIs(dtypes.complex128, dtypes.as_dtype(np.complex128))
    self.assertIs(dtypes.string, dtypes.as_dtype(np.object_))
    self.assertIs(dtypes.string,
                  dtypes.as_dtype(np.array(["foo", "bar"]).dtype))
    self.assertIs(dtypes.bool, dtypes.as_dtype(np.bool_))
    self.assertIs(dtypes.float8_e5m2, dtypes.as_dtype(dtypes._np_float8_e5m2))
    self.assertIs(dtypes.float8_e4m3fn,
                  dtypes.as_dtype(dtypes._np_float8_e4m3fn))
    # Structured numpy dtypes have no TF equivalent.
    with self.assertRaises(TypeError):
      dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))

    # Objects exposing a `dtype` attribute are converted through it, whether
    # it is a numpy type string or a `np.dtype`.
    class AnObject(object):
      dtype = "f4"

    self.assertIs(dtypes.float32, dtypes.as_dtype(AnObject))

    class AnotherObject(object):
      dtype = np.dtype(np.complex64)

    self.assertIs(dtypes.complex64, dtypes.as_dtype(AnotherObject))

  def testRealDtype(self):
    """Real types are their own `real_dtype`; complex types map to floats."""
    for dtype in [
        dtypes.float32,
        dtypes.float64,
        dtypes.bool,
        dtypes.uint8,
        dtypes.int8,
        dtypes.int16,
        dtypes.int32,
        dtypes.int64,
        dtypes.float8_e5m2,
        dtypes.float8_e4m3fn,
        dtypes.int4,
        dtypes.uint4,
    ]:
      self.assertIs(dtype.real_dtype, dtype)
    self.assertIs(dtypes.complex64.real_dtype, dtypes.float32)
    self.assertIs(dtypes.complex128.real_dtype, dtypes.float64)

  def testStringConversion(self):
    """Type-name strings (including `_ref` names) convert to interned dtypes."""
    self.assertIs(dtypes.float32, dtypes.as_dtype("float32"))
    self.assertIs(dtypes.float64, dtypes.as_dtype("float64"))
    self.assertIs(dtypes.int32, dtypes.as_dtype("int32"))
    self.assertIs(dtypes.uint8, dtypes.as_dtype("uint8"))
    self.assertIs(dtypes.uint16, dtypes.as_dtype("uint16"))
    self.assertIs(dtypes.int16, dtypes.as_dtype("int16"))
    self.assertIs(dtypes.int8, dtypes.as_dtype("int8"))
    self.assertIs(dtypes.string, dtypes.as_dtype("string"))
    self.assertIs(dtypes.complex64, dtypes.as_dtype("complex64"))
    self.assertIs(dtypes.complex128, dtypes.as_dtype("complex128"))
    self.assertIs(dtypes.int64, dtypes.as_dtype("int64"))
    self.assertIs(dtypes.bool, dtypes.as_dtype("bool"))
    self.assertIs(dtypes.qint8, dtypes.as_dtype("qint8"))
    self.assertIs(dtypes.quint8, dtypes.as_dtype("quint8"))
    self.assertIs(dtypes.qint32, dtypes.as_dtype("qint32"))
    self.assertIs(dtypes.bfloat16, dtypes.as_dtype("bfloat16"))
    self.assertIs(dtypes.float8_e5m2, dtypes.as_dtype("float8_e5m2"))
    self.assertIs(dtypes.float8_e4m3fn, dtypes.as_dtype("float8_e4m3fn"))
    self.assertIs(dtypes.int4, dtypes.as_dtype("int4"))
    self.assertIs(dtypes.uint4, dtypes.as_dtype("uint4"))
    self.assertIs(dtypes.float32_ref, dtypes.as_dtype("float32_ref"))
    self.assertIs(dtypes.float64_ref, dtypes.as_dtype("float64_ref"))
    self.assertIs(dtypes.int32_ref, dtypes.as_dtype("int32_ref"))
    self.assertIs(dtypes.uint8_ref, dtypes.as_dtype("uint8_ref"))
    self.assertIs(dtypes.int16_ref, dtypes.as_dtype("int16_ref"))
    self.assertIs(dtypes.int8_ref, dtypes.as_dtype("int8_ref"))
    self.assertIs(dtypes.string_ref, dtypes.as_dtype("string_ref"))
    self.assertIs(dtypes.complex64_ref, dtypes.as_dtype("complex64_ref"))
    self.assertIs(dtypes.complex128_ref, dtypes.as_dtype("complex128_ref"))
    self.assertIs(dtypes.int64_ref, dtypes.as_dtype("int64_ref"))
    self.assertIs(dtypes.bool_ref, dtypes.as_dtype("bool_ref"))
    self.assertIs(dtypes.qint8_ref, dtypes.as_dtype("qint8_ref"))
    self.assertIs(dtypes.quint8_ref, dtypes.as_dtype("quint8_ref"))
    self.assertIs(dtypes.qint32_ref, dtypes.as_dtype("qint32_ref"))
    self.assertIs(dtypes.bfloat16_ref, dtypes.as_dtype("bfloat16_ref"))
    self.assertIs(dtypes.float8_e5m2_ref, dtypes.as_dtype("float8_e5m2_ref"))
    self.assertIs(dtypes.float8_e4m3fn_ref,
                  dtypes.as_dtype("float8_e4m3fn_ref"))
    self.assertIs(dtypes.int4_ref, dtypes.as_dtype("int4_ref"))
    self.assertIs(dtypes.uint4_ref, dtypes.as_dtype("uint4_ref"))
    with self.assertRaises(TypeError):
      dtypes.as_dtype("not_a_type")

  def testDTypesHaveUniqueNames(self):
    """No two valid DataType enums produce dtypes with the same name."""
    dtypez = []
    names = set()
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      dtypez.append(dtype)
      names.add(dtype.name)
    self.assertEqual(len(dtypez), len(names))

  def testIsInteger(self):
    """`is_integer` holds for plain ints (incl. 4-bit), not quantized types."""
    self.assertEqual(dtypes.as_dtype("int8").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int16").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int32").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int64").is_integer, True)
    self.assertEqual(dtypes.as_dtype("uint8").is_integer, True)
    self.assertEqual(dtypes.as_dtype("uint16").is_integer, True)
    self.assertEqual(dtypes.as_dtype("complex64").is_integer, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_integer, False)
    self.assertEqual(dtypes.as_dtype("float").is_integer, False)
    self.assertEqual(dtypes.as_dtype("double").is_integer, False)
    self.assertEqual(dtypes.as_dtype("string").is_integer, False)
    self.assertEqual(dtypes.as_dtype("bool").is_integer, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_integer, False)
    self.assertEqual(dtypes.as_dtype("float8_e5m2").is_integer, False)
    self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_integer, False)
    self.assertEqual(dtypes.as_dtype("int4").is_integer, True)
    self.assertEqual(dtypes.as_dtype("uint4").is_integer, True)
    self.assertEqual(dtypes.as_dtype("qint8").is_integer, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_integer, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_integer, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_integer, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_integer, False)

  def testIsFloating(self):
    """`is_floating` holds for IEEE, bfloat16 and float8 types only."""
    self.assertEqual(dtypes.as_dtype("int8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int32").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int64").is_floating, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("uint16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_floating, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_floating, False)
    self.assertEqual(dtypes.as_dtype("float32").is_floating, True)
    self.assertEqual(dtypes.as_dtype("float64").is_floating, True)
    self.assertEqual(dtypes.as_dtype("string").is_floating, False)
    self.assertEqual(dtypes.as_dtype("bool").is_floating, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_floating, True)
    self.assertEqual(dtypes.as_dtype("float8_e5m2").is_floating, True)
    self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_floating, True)
    self.assertEqual(dtypes.as_dtype("int4").is_floating, False)
    self.assertEqual(dtypes.as_dtype("uint4").is_floating, False)
    self.assertEqual(dtypes.as_dtype("qint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_floating, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_floating, False)

  def testIsComplex(self):
    """`is_complex` holds only for complex64 and complex128."""
    self.assertEqual(dtypes.as_dtype("int8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int64").is_complex, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("uint16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_complex, True)
    self.assertEqual(dtypes.as_dtype("complex128").is_complex, True)
    self.assertEqual(dtypes.as_dtype("float32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("float64").is_complex, False)
    self.assertEqual(dtypes.as_dtype("string").is_complex, False)
    self.assertEqual(dtypes.as_dtype("bool").is_complex, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("float8_e5m2").is_complex, False)
    self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int4").is_complex, False)
    self.assertEqual(dtypes.as_dtype("uint4").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_complex, False)

  def testIsUnsigned(self):
    """`is_unsigned` holds for plain unsigned ints, not quantized quint*."""
    self.assertEqual(dtypes.as_dtype("int8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_unsigned, True)
    self.assertEqual(dtypes.as_dtype("uint16").is_unsigned, True)
    self.assertEqual(dtypes.as_dtype("float32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("float64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("bool").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("string").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("float8_e5m2").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int4").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("uint4").is_unsigned, True)
    self.assertEqual(dtypes.as_dtype("qint8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_unsigned, False)

  def testMinMax(self):
    """`min`/`max` evaluate for all bounded dtypes and match known limits."""
    # Expected (min, max), keyed by the numpy type backing a dtype. Ref dtypes
    # report their base type's numpy type, so they are checked as well (the
    # previous `dtype == dtypes.uint16` guard skipped uint16_ref).
    expected_limits = {
        np.int8: (-128, 127),
        np.int16: (-32768, 32767),
        np.int32: (-2147483648, 2147483647),
        np.int64: (-9223372036854775808, 9223372036854775807),
        np.uint8: (0, 255),
        np.uint16: (0, 65535),
        np.uint32: (0, 4294967295),
        np.uint64: (0, 18446744073709551615),
        dtypes.bfloat16.as_numpy_dtype: (float.fromhex("-0x1.FEp127"),
                                         float.fromhex("0x1.FEp127")),
        dtypes.float8_e5m2.as_numpy_dtype: (-57344.0, 57344.0),
        dtypes.float8_e4m3fn.as_numpy_dtype: (-448.0, 448.0),
        dtypes.int4.as_numpy_dtype: (-8, 7),
        dtypes.uint4.as_numpy_dtype: (0, 15),
    }
    for np_float in (np.float16, np.float32, np.float64):
      float_info = np.finfo(np_float)
      expected_limits[np_float] = (float_info.min, float_info.max)

    # Types for which there is no minimum/maximum (or we cannot compute it,
    # such as for the q* types) are not expected to support the properties.
    unbounded_base_dtypes = (dtypes.bool, dtypes.string, dtypes.complex64,
                             dtypes.complex128)

    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      if dtype.is_quantized or dtype.base_dtype in unbounded_base_dtypes:
        continue

      with self.subTest(dtype=dtype):
        # Accessing both properties must succeed for every remaining dtype,
        # even those without a known expected value below.
        dtype_min, dtype_max = dtype.min, dtype.max
        limits = expected_limits.get(dtype.as_numpy_dtype)
        if limits is not None:
          self.assertEqual(dtype_min, limits[0])
          self.assertEqual(dtype_max, limits[1])

  def testLimitsUndefinedError(self):
    """`limits()` raises ValueError for a type with no defined limits."""
    with self.assertRaises(ValueError):
      dtypes.string.limits()

  def testRepr(self):
    """`repr` of a non-ref dtype evaluates back to an equal `DType`."""
    self.skipTest("b/142725777")
    for enum, name in dtypes._TYPE_TO_STRING.items():
      # Skip enums above 100 (the `_ref` variants use that range).
      if enum > 100:
        continue
      dtype = dtypes.DType(enum)
      self.assertEqual(repr(dtype), "tf." + name)
      # `tf` must be in scope for the `eval` of "tf.<name>" below.
      import tensorflow as tf
      dtype2 = eval(repr(dtype))
      self.assertEqual(type(dtype2), dtypes.DType)
      self.assertEqual(dtype, dtype2)

  def testEqWithNonTFTypes(self):
    """Comparing a dtype with non-dtype objects yields inequality."""
    self.assertNotEqual(dtypes.int32, int)
    self.assertNotEqual(dtypes.float64, 2.1)

  def testPythonLongConversion(self):
    """A numpy array holding 2**32 converts to int64."""
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.array(2**32).dtype))

  def testPythonTypesConversion(self):
    """Python builtin `float` and `bool` convert to float32 and bool."""
    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))

  def testReduce(self):
    """`__reduce__` reconstructs every dtype through `as_dtype(name)`."""
    for enum in dtypes._TYPE_TO_STRING:
      dtype = dtypes.DType(enum)
      ctor, args = dtype.__reduce__()
      self.assertEqual(ctor, dtypes.as_dtype)
      self.assertEqual(args, (dtype.name,))
      reconstructed = ctor(*args)
      self.assertEqual(reconstructed, dtype)

  def testAsDtypeInvalidArgument(self):
    """A tuple of dtypes is not a valid `as_dtype` argument."""
    with self.assertRaises(TypeError):
      dtypes.as_dtype((dtypes.int32, dtypes.float32))

  def testAsDtypeReturnsInternedVersion(self):
    """`as_dtype` maps a freshly built `DType` onto the module singleton."""
    dt = dtypes.DType(types_pb2.DT_VARIANT)
    self.assertIs(dtypes.as_dtype(dt), dtypes.variant)

  def testDTypeSubtypes(self):
    """A dtype is a subtype of itself but not of a different dtype."""
    self.assertTrue(dtypes.string.is_subtype_of(dtypes.string))
    self.assertFalse(dtypes.string.is_subtype_of(dtypes.uint32))
    self.assertTrue(dtypes.uint64.is_subtype_of(dtypes.uint64))

  def testDTypeSupertypes(self):
    """The common supertype exists only among identical dtypes."""
    self.assertEqual(dtypes.string,
                     dtypes.string.most_specific_common_supertype([]))
    self.assertEqual(
        dtypes.string,
        dtypes.string.most_specific_common_supertype([dtypes.string]))
    self.assertIsNone(
        dtypes.string.most_specific_common_supertype([dtypes.uint32]))

  @parameterized.parameters(*dtypes.TF_VALUE_DTYPES)
  def testDTypeSerialization(self, dtype):
    """Each value dtype survives a trace_type serialize/deserialize cycle."""
    self.assertEqual(trace_type.deserialize(trace_type.serialize(dtype)), dtype)


if __name__ == "__main__":
  # Run the test cases in this module with the googletest runner imported
  # from tensorflow.python.platform.
  googletest.main()