tensorflow/tensorflow

View on GitHub
tensorflow/python/kernel_tests/check_ops_test.py

Summary

Maintainability
F
2 mos
Test Coverage
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.check_ops."""

import time

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


# pylint:disable=g-error-prone-assert-raises
class AssertV2Asserts(test.TestCase):

  def test_passes_when_it_should(self):
    # This is a v2 test and need to run eagerly
    with context.eager_mode():
      c1 = constant_op.constant(-1, name="minus_one", dtype=dtypes.int32)
      c2 = constant_op.constant(2, name="two", dtype=dtypes.int32)
      c3 = constant_op.constant([3., 3.], name="three", dtype=dtypes.float32)
      c4 = constant_op.constant([3., 3.5], name="three_and_a_half",
                                dtype=dtypes.float32)
      scalar = c1
      non_scalar = c3
      integer = c1
      non_integer = c3
      positive = c2
      negative = c1
      cases = [
          (check_ops.assert_equal_v2, (c1, c1), (c1, c2)),
          (check_ops.assert_less_v2, (c1, c2), (c1, c1)),
          (check_ops.assert_near_v2, (c3, c3), (c3, c4)),
          (check_ops.assert_greater_v2, (c2, c1), (c1, c1)),
          (check_ops.assert_negative_v2, (negative,), (positive,)),
          (check_ops.assert_positive_v2, (positive,), (negative,)),
          (check_ops.assert_less_equal_v2, (c1, c1), (c2, c1)),
          (check_ops.assert_none_equal_v2, (c1, c2), (c3, c4)),
          (check_ops.assert_non_negative_v2, (positive,), (negative,)),
          (check_ops.assert_non_positive_v2, (negative,), (positive,)),
          (check_ops.assert_greater_equal_v2, (c1, c1), (c1, c2)),
          (check_ops.assert_type_v2, (c1, dtypes.int32), (c1, dtypes.float32),
           TypeError),
          (check_ops.assert_integer_v2, (integer,), (non_integer,),
           TypeError),
          (check_ops.assert_scalar_v2, (scalar,), (non_scalar,),
           ValueError),
          (check_ops.assert_rank_v2, (c1, 0), (c3, 2), ValueError),
          (check_ops.assert_rank_in_v2, (c1, [0, 1]), (c1, [1, 2]),
           ValueError),
          (check_ops.assert_rank_at_least_v2, (non_scalar, 1), (scalar, 1),
           ValueError),
      ]

      for case in cases:
        fn = case[0]
        passing_args = case[1]
        failing_args = case[2]
        error = errors.InvalidArgumentError if len(case) < 4 else case[3]

        print("Testing %s passing properly." % fn)

        fn(*passing_args)

        print("Testing %s failing properly." % fn)

        @def_function.function
        def failing_fn():
          fn(*failing_args, message="fail")  # pylint: disable=cell-var-from-loop

        with self.assertRaisesRegex(error, "fail"):
          failing_fn()

        del failing_fn


class AssertProperIterableTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_single_tensor_raises(self):
    tensor = constant_op.constant(1)
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(tensor)

  @test_util.run_in_graph_and_eager_modes
  def test_single_sparse_tensor_raises(self):
    ten = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(ten)

  @test_util.run_in_graph_and_eager_modes
  def test_single_ndarray_raises(self):
    array = np.array([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(array)

  @test_util.run_in_graph_and_eager_modes
  def test_single_string_raises(self):
    mystr = "hello"
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(mystr)

  @test_util.run_in_graph_and_eager_modes
  def test_non_iterable_object_raises(self):
    non_iterable = 1234
    with self.assertRaisesRegex(TypeError, "to be iterable"):
      check_ops.assert_proper_iterable(non_iterable)

  @test_util.run_in_graph_and_eager_modes
  def test_list_does_not_raise(self):
    list_of_stuff = [
        constant_op.constant([11, 22]), constant_op.constant([1, 2])
    ]
    check_ops.assert_proper_iterable(list_of_stuff)

  @test_util.run_in_graph_and_eager_modes
  def test_generator_does_not_raise(self):
    generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant(
        [1, 2]))
    check_ops.assert_proper_iterable(generator_of_stuff)


class AssertEqualTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies([check_ops.assert_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_scalar_comparison(self):
    const_true = constant_op.constant(True, name="true")
    const_false = constant_op.constant(False, name="false")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(const_true, const_false, message="fail")

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      small = constant_op.constant([1, 2], name="small")
      x = check_ops.assert_equal(small, small)
      assert x is None

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    # Static check
    static_small = constant_op.constant([1, 2], name="small")
    static_big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_greater_dynamic(self):
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies(
          [check_ops.assert_equal(big, small, message="fail")]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("fail.*big.*small"):
        out.eval(feed_dict={small: [1, 2], big: [3, 4]})

  def test_error_message_eager(self):
    expected_error_msg_full = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
 \[1 1\]
 \[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 6 elements of x:
\[2 2 3 3 6 6\]
First 6 elements of y:
\[20  2  3 30 60  6\]"""
    expected_error_msg_default = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
 \[1 1\]
 \[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 3 elements of x:
\[2 2 3\]
First 3 elements of y:
\[20  2  3\]"""
    expected_error_msg_short = r"""big does not equal small
Condition x == y did not hold.
Indices of first 2 different values:
\[\[0 0\]
 \[1 1\]\]
Corresponding x values:
\[2 3\]
Corresponding y values:
\[20 30\]
First 2 elements of x:
\[2 2\]
First 2 elements of y:
\[20  2\]"""
    with context.eager_mode():
      big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
      small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_full):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=10)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_default):
        check_ops.assert_equal(big, small, message="big does not equal small")
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_short):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=2)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    # Static check
    static_small = constant_op.constant([3, 1], name="small")
    static_big = constant_op.constant([4, 2], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_less_dynamic(self):
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies([check_ops.assert_equal(small, big)]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("small.*big"):
        out.eval(feed_dict={small: [3, 1], big: [4, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([[1, 2], [1, 2]], name="small")
    small_2 = constant_op.constant([1, 2], name="small_2")
    with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    small_2 = constant_op.constant([1, 1], name="small_2")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_and_broadcastable_shapes(self):
    cond = constant_op.constant([True, False], name="small")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(cond, False, message="fail")

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_noop_when_both_identical(self):
    larry = constant_op.constant([])
    check_op = check_ops.assert_equal(larry, larry)
    if context.executing_eagerly():
      self.assertIs(check_op, None)
    else:
      self.assertEqual(check_op.type, "NoOp")


class AssertNoneEqualTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([10, 20], name="small")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([3, 1], name="small")
    with self.assertRaisesOpError("x != y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3], name="big")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([10, 10], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_none_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_none_equal(t1, t2)
      assert x is None

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_none_equal(1, 1, message="Custom error message")

  def test_error_message_eager(self):
    # Note that the following three strings are regexes
    expected_error_msg_full = r"""\[ *0\. +1\. +2\. +3\. +4\. +5\.\]"""
    expected_error_msg_default = r"""\[ *0\. +1\. +2\.\]"""
    expected_error_msg_short = r"""\[ *0\. +1\.\]"""
    with context.eager_mode():
      t = constant_op.constant(
          np.array(range(6)), shape=[2, 3], dtype=np.float32)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=10)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=-1)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_default):
        check_ops.assert_none_equal(t, t, message="This is the error message.")
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_short):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=2)


class AssertAllCloseTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1., name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x", dtype=np.float64)
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x", dtype=np.float64)
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1.1, name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., rtol=0.5,
                               message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_atol(self):
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0.1, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0.5, rtol=0.,
                               message="failure message")]):
      out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_near(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_atol_violated(self):
    x = constant_op.constant(10., name="x")
    y = constant_op.constant(10.2, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, atol=0.1,
                                 message="failure message")]):
        out = array_ops.identity(x)
        self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_default_rtol_violated(self):
    x = constant_op.constant(0.1, name="x")
    y = constant_op.constant(0.0, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, message="failure message")]):
        out = array_ops.identity(x)
        self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1., 2.])
      t2 = constant_op.constant([1., 2.])
      x = check_ops.assert_near(t1, t2)
      assert x is None

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_complex(self):
    x = constant_op.constant(1. + 0.1j, name="x")
    y = constant_op.constant(1.1 + 0.1j, name="y")
    with ops.control_dependencies([
        check_ops.assert_near(
            x, y, atol=0., rtol=0.5, message="failure message")
    ]):
      out = array_ops.identity(x)
      self.evaluate(out)


class AssertLessTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "failure message.*\n*.* x < y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_less(
              small, small, message="failure message")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < y did not hold"):
      with ops.control_dependencies([check_ops.assert_less(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([4, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 2], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies([check_ops.assert_less(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_less(t1, t2)
      assert x is None

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_less(1, 1, message="Custom error message")


class AssertLessEqualTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies(
        [check_ops.assert_less_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_less_equal(
              big, small, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 1], name="big")
    with ops.control_dependencies([check_ops.assert_less_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([1, 1, 1], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError),
        (r"Incompatible shapes: \[2\] vs. \[3\]|"
         r"Dimensions must be equal, but are 2 and 3")):
      with ops.control_dependencies(
          [check_ops.assert_less_equal(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_less_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_less_equal(1, 0, message="Custom error message")


class AssertGreaterTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_greater(
              small, small, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x > y did not hold"):
      with ops.control_dependencies([check_ops.assert_greater(small, big)]):
        out = array_ops.identity(big)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([4, 2], name="big")
    with ops.control_dependencies([check_ops.assert_greater(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_greater(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_greater_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 2], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError),
        (r"Incompatible shapes: \[2\] vs. \[3\]|"
         r"Dimensions must be equal, but are 2 and 3")):
      with ops.control_dependencies([check_ops.assert_greater(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_greater(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_greater(0, 1, message="Custom error message")


class AssertGreaterEqualTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(
              small, big, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 1], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="big")
    big = constant_op.constant([3, 1], name="small")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError),
        (r"Incompatible shapes: \[2\] vs. \[3\]|"
         r"Dimensions must be equal, but are 2 and 3")):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_greater_equal(0, 1, message="Custom error message")


class AssertNegativeTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_negative(self):
    frank = constant_op.constant([-1, -2], name="frank")
    with ops.control_dependencies([check_ops.assert_negative(frank)]):
      out = array_ops.identity(frank)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    doug = constant_op.constant([1, 2], name="doug")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_negative(
              doug, message="fail")]):
        out = array_ops.identity(doug)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    claire = constant_op.constant([0], name="claire")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < 0 did not hold"):
      with ops.control_dependencies([check_ops.assert_negative(claire)]):
        out = array_ops.identity(claire)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # A tensor is negative when it satisfies:
    #   For every element x_i in x, x_i < 0
    # and an empty tensor has no elements, so this is trivially satisfied.
    # This is standard set theory.
    empty = constant_op.constant([], name="empty")
    with ops.control_dependencies([check_ops.assert_negative(empty)]):
      out = array_ops.identity(empty)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Custom error message"):
        check_ops.assert_negative(1, message="Custom error message")


# pylint:disable=g-error-prone-assert-raises
class AssertPositiveTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    freddie = constant_op.constant([-1, -2], name="freddie")
    with self.assertRaisesOpError("fail"):
      with ops.control_dependencies(
          [check_ops.assert_positive(
              freddie, message="fail")]):
        out = array_ops.identity(freddie)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_positive(self):
    remmy = constant_op.constant([1, 2], name="remmy")
    with ops.control_dependencies([check_ops.assert_positive(remmy)]):
      out = array_ops.identity(remmy)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    meechum = constant_op.constant([0], name="meechum")
    with self.assertRaisesOpError("x > 0 did not hold"):
      with ops.control_dependencies([check_ops.assert_positive(meechum)]):
        out = array_ops.identity(meechum)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # A tensor is positive when it satisfies:
    #   For every element x_i in x, x_i > 0
    # and an empty tensor has no elements, so this is trivially satisfied.
    # This is standard set theory.
    empty = constant_op.constant([], name="empty")
    with ops.control_dependencies([check_ops.assert_positive(empty)]):
      out = array_ops.identity(empty)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Custom error message"):
        check_ops.assert_positive(-1, message="Custom error message")


class EnsureShapeTest(test.TestCase):

  # Static shape inference
  @test_util.run_deprecated_v1
  def testStaticShape(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    ensure_shape_op = check_ops.ensure_shape(placeholder, (3, 3, 3))
    self.assertEqual(ensure_shape_op.get_shape(), (3, 3, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_MergesShapes(self):
    placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    ensure_shape_op = check_ops.ensure_shape(placeholder, (5, 4, None))
    self.assertEqual(ensure_shape_op.get_shape(), (5, 4, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenRankIncompatible(self):
    placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(placeholder, (2, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenDimIncompatible(self):
    placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(placeholder, (2, 2, 4))

  @test_util.run_deprecated_v1
  def testStaticShape_CanSetUnknownShape(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    derived = placeholder / 3
    ensure_shape_op = check_ops.ensure_shape(derived, None)
    self.assertEqual(ensure_shape_op.get_shape(), None)

  # Dynamic shape check
  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesError(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    derived = math_ops.divide(placeholder, 3, name="MyDivide")
    derived = check_ops.ensure_shape(derived, (3, 3, 3))
    feed_val = [[1], [2]]
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          r"Shape of tensor MyDivide \[2,1\] is not compatible with "
          r"expected shape \[3,3,3\]."):
        sess.run(derived, feed_dict={placeholder: feed_val})

  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    derived = placeholder / 3
    derived = check_ops.ensure_shape(derived, (None, None, 3))
    feed_val = [[1], [2]]
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
          r"expected shape \[\?,\?,3\]."):
        sess.run(derived, feed_dict={placeholder: feed_val})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    derived = placeholder / 3
    derived = check_ops.ensure_shape(derived, (2, 1))
    feed_val = [[1], [2]]
    with self.cached_session() as sess:
      sess.run(derived, feed_dict={placeholder: feed_val})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape_WithUnknownDims(self):
    placeholder = array_ops.placeholder(dtypes.int32)
    derived = placeholder / 3
    derived = check_ops.ensure_shape(derived, (None, None))
    feed_val = [[1], [2]]
    with self.cached_session() as sess:
      sess.run(derived, feed_dict={placeholder: feed_val})

  @test_util.run_deprecated_v1
  def testGradient(self):
    placeholder = array_ops.placeholder(dtypes.float32)
    derived = check_ops.ensure_shape(placeholder, (None, None))
    gradient = gradients.gradients(derived, placeholder)

    feed_val = [[4.0], [-1.0]]
    with self.cached_session() as sess:
      gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})

    expected = [[1.0], [1.0]]
    self.assertAllEqual(gradient_values, expected)


class EnsureShapeBenchmark(test.Benchmark):

  def _grappler_all_off_config(self):
    config = config_pb2.ConfigProto()
    off = rewriter_config_pb2.RewriterConfig.OFF
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.disable_model_pruning = 1
    config.graph_options.rewrite_options.constant_folding = off
    config.graph_options.rewrite_options.layout_optimizer = off
    config.graph_options.rewrite_options.arithmetic_optimization = off
    config.graph_options.rewrite_options.dependency_optimization = off
    return config

  def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
    config = self._grappler_all_off_config()
    with session.Session(config=config) as sess:
      deltas = []
      # Warm up the session
      for _ in range(5):
        sess.run(op, feed_dict=feed_dict)
      for _ in range(num_iters):
        start = time.time()
        sess.run(op, feed_dict=feed_dict)
        end = time.time()
        deltas.append(end - start)
      mean_time = np.median(deltas)
      mean_us = mean_time * 1e6
      # mean_us = (end - start) * 1e6 / num_iters
      self.report_benchmark(
          name=name,
          wall_time=mean_us,
          extras=kwargs,
      )

  def benchmark_const_op(self):
    # In this case, we expect that the overhead of a `session.run` call
    # far outweighs the time taken to execute the op...
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    self._run(array_ops.identity(input_op), name="SingleConstOp")

  def benchmark_single_ensure_op(self):
    # In this case, we expect that the overhead of a `session.run` call
    # far outweighs the time taken to execute the op...
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    ensure_shape_op = check_ops.ensure_shape(input_op, shape)
    self._run(ensure_shape_op, name="SingleEnsureShapeOp")

  def _apply_n_times(self, op, target, n=1000):
    for _ in range(n):
      target = op(target)
    return target

  def benchmark_n_ops(self):
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ops = self._apply_n_times(array_ops.identity, input_op)
    self._run(n_ops, name="NIdentityOps_1000")

  def benchmark_n_ensure_ops(self):
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ensure_ops = self._apply_n_times(
        lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
        input_op)
    self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")


class AssertRankTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    tensor = constant_op.constant(1, name="my_tensor")
    desired_rank = 1
    with self.assertRaisesRegex(ValueError, "fail.*must have rank 1"):
      with ops.control_dependencies(
          [check_ops.assert_rank(
              tensor, desired_rank, message="fail")]):
        self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      with ops.control_dependencies(
          [check_ops.assert_rank(
              tensor, desired_rank, message="fail")]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(tensor).eval(feed_dict={tensor: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    tensor = constant_op.constant(1, name="my_tensor")
    desired_rank = 0
    with ops.control_dependencies(
        [check_ops.assert_rank(tensor, desired_rank)]):
      self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        array_ops.identity(tensor).eval(feed_dict={tensor: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 0
    with self.assertRaisesRegex(ValueError, "rank"):
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 1
    with ops.control_dependencies(
        [check_ops.assert_rank(tensor, desired_rank)]):
      self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 2
    with self.assertRaisesRegex(ValueError, "rank"):
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 2
      with ops.control_dependencies(
          [check_ops.assert_rank(tensor, desired_rank)]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "Rank must be a scalar"):
      check_ops.assert_rank(tensor, np.array([], dtype=np.int32))

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      tensor = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor")
      with self.assertRaisesOpError("Rank must be a scalar"):
        with ops.control_dependencies(
            [check_ops.assert_rank(tensor, rank_tensor)]):
          array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type tf.int32"):
      check_ops.assert_rank(tensor, .5)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      tensor = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegex(TypeError, "must be of type tf.int32"):
        with ops.control_dependencies(
            [check_ops.assert_rank(tensor, rank_tensor)]):
          array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5})


class AssertRankInTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
    tensor_rank0 = constant_op.constant(42, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "fail.*must have rank.*in.*1.*2"):
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
        self.evaluate(array_ops.identity(tensor_rank0))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
    with self.cached_session():
      tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    tensor_rank0 = constant_op.constant(42, name="my_tensor")
    for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank0, desired_ranks)]):
        self.evaluate(array_ops.identity(tensor_rank0))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
        with ops.control_dependencies([
            check_ops.assert_rank_in(tensor_rank0, desired_ranks)]):
          array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    tensor_rank1 = constant_op.constant([42, 43], name="my_tensor")
    for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank1, desired_ranks)]):
        self.evaluate(array_ops.identity(tensor_rank1))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
        with ops.control_dependencies([
            check_ops.assert_rank_in(tensor_rank1, desired_ranks)]):
          array_ops.identity(tensor_rank1).eval(feed_dict={
              tensor_rank1: (42.0, 43.0)
          })

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
    tensor_rank1 = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
        self.evaluate(array_ops.identity(tensor_rank1))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
    with self.cached_session():
      tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([
          check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(tensor_rank1).eval(feed_dict={
              tensor_rank1: (42.0, 43.0)
          })

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    tensor = constant_op.constant((42, 43), name="my_tensor")
    desired_ranks = (
        np.array(1, dtype=np.int32),
        np.array((2, 1), dtype=np.int32))
    with self.assertRaisesRegex(ValueError, "Rank must be a scalar"):
      check_ops.assert_rank_in(tensor, desired_ranks)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      tensor = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      desired_ranks = (
          array_ops.placeholder(dtypes.int32, name="rank0_tensor"),
          array_ops.placeholder(dtypes.int32, name="rank1_tensor"))
      with self.assertRaisesOpError("Rank must be a scalar"):
        with ops.control_dependencies(
            (check_ops.assert_rank_in(tensor, desired_ranks),)):
          array_ops.identity(tensor).eval(feed_dict={
              desired_ranks[0]: 1,
              desired_ranks[1]: [2, 1],
          })

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    tensor = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type tf.int32"):
      check_ops.assert_rank_in(tensor, (1, .5,))

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      tensor = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegex(TypeError, "must be of type tf.int32"):
        with ops.control_dependencies(
            [check_ops.assert_rank_in(tensor, (1, rank_tensor))]):
          array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5})


class AssertRankAtLeastTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    tensor = constant_op.constant(1, name="my_tensor")
    desired_rank = 1
    with self.assertRaisesRegex(ValueError, "rank at least 1"):
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(tensor).eval(feed_dict={tensor: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    tensor = constant_op.constant(1, name="my_tensor")
    desired_rank = 0
    with ops.control_dependencies(
        [check_ops.assert_rank_at_least(tensor, desired_rank)]):
      self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        array_ops.identity(tensor).eval(feed_dict={tensor: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 0
    with ops.control_dependencies(
        [check_ops.assert_rank_at_least(tensor, desired_rank)]):
      self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 1
    with ops.control_dependencies(
        [check_ops.assert_rank_at_least(tensor, desired_rank)]):
      self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    tensor = constant_op.constant([1, 2], name="my_tensor")
    desired_rank = 2
    with self.assertRaisesRegex(ValueError, "rank at least 2"):
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        self.evaluate(array_ops.identity(tensor))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 2
      with ops.control_dependencies(
          [check_ops.assert_rank_at_least(tensor, desired_rank)]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})


class AssertNonNegativeTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    zoe = constant_op.constant([-1, -2], name="zoe")
    with self.assertRaisesOpError("x >= 0 did not hold"):
      with ops.control_dependencies([check_ops.assert_non_negative(zoe)]):
        out = array_ops.identity(zoe)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_positive(self):
    lucas = constant_op.constant([0, 2], name="lucas")
    with ops.control_dependencies([check_ops.assert_non_negative(lucas)]):
      out = array_ops.identity(lucas)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # A tensor is non-negative when it satisfies:
    #   For every element x_i in x, x_i >= 0
    # and an empty tensor has no elements, so this is trivially satisfied.
    # This is standard set theory.
    empty = constant_op.constant([], name="empty")
    with ops.control_dependencies([check_ops.assert_non_negative(empty)]):
      out = array_ops.identity(empty)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Custom error message"):
        check_ops.assert_non_negative(-1, message="Custom error message")


class AssertNonPositiveTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_negative(self):
    tom = constant_op.constant([0, -2], name="tom")
    with ops.control_dependencies([check_ops.assert_non_positive(tom)]):
      out = array_ops.identity(tom)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    rachel = constant_op.constant([0, 2], name="rachel")
    with self.assertRaisesOpError("x <= 0 did not hold"):
      with ops.control_dependencies([check_ops.assert_non_positive(rachel)]):
        out = array_ops.identity(rachel)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # A tensor is non-positive when it satisfies:
    #   For every element x_i in x, x_i <= 0
    # and an empty tensor has no elements, so this is trivially satisfied.
    # This is standard set theory.
    empty = constant_op.constant([], name="empty")
    with ops.control_dependencies([check_ops.assert_non_positive(empty)]):
      out = array_ops.identity(empty)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "Custom error message"):
        check_ops.assert_non_positive(1, message="Custom error message")


class AssertIntegerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_integer(self):
    integers = constant_op.constant([1, 2], name="integers")
    with ops.control_dependencies([check_ops.assert_integer(integers)]):
      out = array_ops.identity(integers)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_float(self):
    floats = constant_op.constant([1.0, 2.0], name="floats")
    with self.assertRaisesRegex(TypeError, "Expected.*integer"):
      check_ops.assert_integer(floats)


class AssertTypeTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_correct_type(self):
    integers = constant_op.constant([1, 2], dtype=dtypes.int64)
    with ops.control_dependencies([
        check_ops.assert_type(integers, dtypes.int64)]):
      out = array_ops.identity(integers)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_doesnt_raise_when_correct_type(self):
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))

    with ops.control_dependencies(
        [check_ops.assert_type(sparse_float, dtypes.float32)]):
      out = sparse_tensor.SparseTensor(sparse_float.indices,
                                       array_ops.identity(sparse_float.values),
                                       sparse_float.dense_shape)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raggedtensor_doesnt_raise_when_correct_type(self):
    x = ragged_factory_ops.constant([[1., 2.], [3.]])
    with ops.control_dependencies(
        [check_ops.assert_type(x, dtypes.float32)]):
      y = array_ops.identity(x)
    self.assertAllEqual(x, y)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_wrong_type(self):
    floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16)
    with self.assertRaisesRegex(TypeError, "must be of type tf.float32; "
                                "got tf.float16"):
      check_ops.assert_type(floats, dtypes.float32)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_raises_when_wrong_type(self):
    sparse_float16 = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float16),
        constant_op.constant([500], dtypes.int64))
    with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
      check_ops.assert_type(sparse_float16, dtypes.float32)

  @test_util.run_in_graph_and_eager_modes
  def test_raggedtensor_raises_when_wrong_type(self):
    x = ragged_factory_ops.constant([[1, 2], [3]])
    with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
      check_ops.assert_type(x, dtypes.float32)

  def test_raise_when_tf_type_is_not_dtype(self):
    # Test case for GitHub issue:
    # https://github.com/tensorflow/tensorflow/issues/45975
    value = constant_op.constant(0.0)
    with self.assertRaisesRegexp(TypeError,
                                 "Cannot convert.*to a TensorFlow DType"):
      check_ops.assert_type(value, (dtypes.float32,))


class AssertShapesTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_mismatch(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    shapes = [
        (x, ("N", "Q")),
        (y, ("N", "D")),
    ]
    regex = (r"Specified by tensor .* dimension 0.  "
             r"Tensor .* dimension 0 must have size 3.  "
             r"Received size 2")
    self.raises_static_error(shapes=shapes, regex=regex)

  def test_raise_dynamic_shape_mismatch(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, [None, 2], name="x")
      y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
      shapes = [
          (x, ("N", "Q")),
          (y, ("N", "D")),
      ]
      regex = (r"\[Specified by tensor x.* dimension 0\] "
               r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
      feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
      self.raises_dynamic_error(shapes=shapes, regex=regex, feed_dict=feed_dict)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_explicit_mismatch(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    shapes = [
        (x, (3, "Q")),
        (y, (3, "D")),
    ]
    regex = (r"Specified explicitly.  "
             r"Tensor .* dimension 0 must have size 3.  "
             r"Received size 2")
    self.raises_static_error(shapes=shapes, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_rank_one_size_one_equivalence(self):
    rank_one_size_one = array_ops.ones([1], name="rank_one_size_one")
    rank_zero = array_ops.constant(5, name="rank_zero")
    check_ops.assert_shapes([
        (rank_one_size_one, ()),
        (rank_zero, ()),
    ])
    check_ops.assert_shapes([
        (rank_one_size_one, (1,)),
        (rank_zero, (1,)),
    ])

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_rank_1_size_not_1_mismatch_scalar(self):
    x = array_ops.constant([2, 2], name="x")
    shapes = [
        (x, ()),
    ]
    regex = (r"Specified explicitly.  "
             r"Tensor .* dimension 0 must have size 1.  "
             r"Received size 2")
    self.raises_static_error(shapes=shapes, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_scalar_mismatch_rank_1_size_not_1(self):
    x = array_ops.constant(2, name="x")
    shapes = [
        (x, (2,)),
    ]
    regex = (r"Specified explicitly.  "
             r"Tensor .* dimension 0 must have size 2.  "
             r"Received size 1")
    self.raises_static_error(shapes=shapes, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_scalar_implies_size_one(self):
    scalar = array_ops.constant(5, name="rank_zero")
    x = array_ops.ones([2, 2], name="x")
    shapes = [(scalar, ("a",)), (x, ("a", 2))]
    regex = (r"Specified by tensor .* dimension 0.  "
             r"Tensor .* dimension 0 must have size 1.  "
             r"Received size 2")
    self.raises_static_error(shapes=shapes, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_not_iterable(self):
    x = array_ops.constant([1, 2], name="x")
    shapes = [(x, 2)]
    regex = (r"Tensor .*.  "
             r"Specified shape must be an iterable.  "
             r"An iterable has the attribute `__iter__` or `__getitem__`.  "
             r"Received specified shape: 2")
    self.raises_static_error(shapes=shapes, regex=regex)

  def test_raise_dynamic_shape_explicit_mismatch(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa")
      y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
      shapes = [
          (x, (3, "Q")),
          (y, (3, "D")),
      ]
      regex = (r"\[Specified explicitly\] "
               r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
      feed_dict = {x: np.ones([3, 2]), y: np.ones([2, 3])}
      self.raises_dynamic_error(shapes=shapes, regex=regex, feed_dict=feed_dict)

  @test_util.run_in_graph_and_eager_modes
  def test_no_op_when_specified_as_unknown(self):
    x = array_ops.constant([1, 1], name="x")
    assertion = check_ops.assert_shapes([(x, None)])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_static_incorrect_rank(self):
    rank_two_shapes = [
        (1, 1),
        (1, 3),
        ("a", "b"),
        (None, None),
    ]
    rank_three_shapes = [
        (1, 1, 1),
        ("a", "b", "c"),
        (None, None, None),
        (1, "b", None),
    ]

    def raises_static_rank_error(shapes, x, correct_rank, actual_rank):
      for shape in shapes:
        regex = (r"Tensor .* must have rank %d.  Received rank %d" %
                 (correct_rank, actual_rank))
        self.raises_static_error(shapes=[(x, shape)], regex=regex)

    raises_static_rank_error(
        rank_two_shapes, array_ops.ones([1]), correct_rank=2, actual_rank=1)
    raises_static_rank_error(
        rank_three_shapes,
        array_ops.ones([1, 1]),
        correct_rank=3,
        actual_rank=2)
    raises_static_rank_error(
        rank_three_shapes, array_ops.constant(1), correct_rank=3, actual_rank=0)

  def test_raises_dynamic_incorrect_rank(self):
    x_value = 5
    rank_two_shapes = [(1, 1), (1, 3), ("a", "b"), (None, None)]
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, None)

      for shape in rank_two_shapes:
        regex = r"Tensor .* must have rank\] \[2\]"
        self.raises_dynamic_error(
            shapes=[(x, shape)], regex=regex, feed_dict={x: x_value})

  @test_util.run_in_graph_and_eager_modes
  def test_correctly_matching(self):
    u = array_ops.constant(1, name="u")
    v = array_ops.ones([1, 2], name="v")
    w = array_ops.ones([3], name="w")
    x = array_ops.ones([1, 2, 3], name="x")
    y = array_ops.ones([3, 1, 2], name="y")
    z = array_ops.ones([2, 3, 1], name="z")
    assertion = check_ops.assert_shapes([
        (x, ("a", "b", "c")),
        (y, ("c", "a", "b")),
        (z, ("b", "c", "a")),
        (v, ("a", "b")),
        (w, ("c",)),
        (u, "a")
    ])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(x)
    self.evaluate(out)
    assertion = check_ops.assert_shapes([
        (x, (1, "b", "c")),
        (y, ("c", "a", 2)),
        (z, ("b", 3, "a")),
        (v, ("a", 2)),
        (w, (3,)),
        (u, ())
    ])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_variable_length_symbols(self):
    x = array_ops.ones([4, 1], name="x")
    y = array_ops.ones([4, 2], name="y")
    assertion = check_ops.assert_shapes([
        (x, ("num_observations", "input_dim")),
        (y, ("num_observations", "output_dim")),
    ])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_implicit_mismatch_using_iterable_alternatives(self):
    x = array_ops.ones([2, 2], name="x")
    y = array_ops.ones([1, 3], name="y")
    styles = [[
        (x, ("A", "B")),
        (y, ("A", "C")),
    ], [
        (x, "AB"),
        (y, "AC")
    ], [
        (x, ["A", "B"]),
        (y, ["A", "C"]),
    ], [
        (x, np.array(["A", "B"])),
        (y, np.array(["A", "C"]))
    ], [
        (x, ("A", "B")),
        (y, "AC")
    ]]
    for shapes in styles:
      self.raises_static_error(
          shapes=shapes,
          regex=(r"Specified by tensor .* dimension 0.  "
                 "Tensor .* dimension 0 must have size 2.  "
                 "Received size 1"))

  @test_util.run_in_graph_and_eager_modes
  def test_raise_explicit_mismatch_using_iterable_alternatives(self):
    x = array_ops.ones([2, 2], name="x")
    y = array_ops.ones([1, 3], name="y")
    styles = [[
        (x, (2, 2)),
        (y, (2, 3)),
    ], [
        (x, "22"),
        (y, "23")
    ], [
        (x, [2, 2]),
        (y, [2, 3]),
    ], [
        (x, np.array([2, 2])),
        (y, np.array([2, 3]))
    ], [
        (x, (2, 2)),
        (y, "23")
    ]]
    for shapes in styles:
      self.raises_static_error(
          shapes=shapes,
          regex=(r"Specified explicitly.  "
                 "Tensor .* dimension 0 must have size 2.  "
                 "Received size 1"))

  @test_util.run_in_graph_and_eager_modes
  def test_dim_size_specified_as_unknown(self):
    x = array_ops.ones([1, 2, 3], name="x")
    y = array_ops.ones([2, 1], name="y")
    a1 = check_ops.assert_shapes([
        (x, (None, 2, None)),
        (y, (None, 1)),
    ])
    a2 = check_ops.assert_shapes([
        (x, (".", 2, ".")),
        (y, (".", 1)),
    ])
    a3 = check_ops.assert_shapes([
        (x, ".2."),
        (y, ".1"),
    ])
    with ops.control_dependencies([a1, a2, a3]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_explicit_mismatch_innermost_dims(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    s1 = [
        (x, (3, "Q")),
        (y, (Ellipsis, 3, "D")),
    ]
    s2 = [
        (x, "3Q"),
        (y, "*3D"),
    ]
    regex = (r"Specified explicitly.  "
             r"Tensor .* dimension -2 must have size 3.  "
             r"Received size 2")
    self.raises_static_error(shapes=s1, regex=regex)
    self.raises_static_error(shapes=s2, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_correctly_matching_innermost_dims(self):
    x = array_ops.ones([1, 2, 3, 2], name="x")
    y = array_ops.ones([2, 3, 3], name="y")
    a1 = check_ops.assert_shapes([
        (x, (Ellipsis, "N", "Q")),
        (y, (Ellipsis, "N", "D")),
    ])
    a2 = check_ops.assert_shapes([
        (x, "*NQ"),
        (y, "*ND"),
    ])
    with ops.control_dependencies([a1, a2]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_variable_num_outer_dims_prefix_misuse(self):
    x = array_ops.ones([1, 2], name="x")
    s1 = [
        (x, ("N", Ellipsis, "Q")),
    ]
    s2 = [
        (x, "N*Q"),
    ]
    regex = (r"Tensor .* specified shape index .*.  "
             r"Symbol `...` or `\*` for a variable number of "
             r"unspecified dimensions is only allowed as the first entry")
    self.raises_static_error(shapes=s1, regex=regex)
    self.raises_static_error(shapes=s2, regex=regex)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_shapes_dict_no_op(self):
    assertion = check_ops.assert_shapes([])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(0)
    self.evaluate(out)

  def raises_static_error(self, shapes, regex):
    with self.assertRaisesRegex(ValueError, regex):
      check_ops.assert_shapes(shapes)

  def raises_dynamic_error(self, shapes, regex, feed_dict):
    with self.session() as sess:
      with self.assertRaisesRegex(errors.InvalidArgumentError, regex):
        assertion = check_ops.assert_shapes(shapes)
        with ops.control_dependencies([assertion]):
          out = array_ops.identity(0)
        sess.run(out, feed_dict=feed_dict)


class AssertShapesSparseTensorTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_scalar_target_success(self):
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_float, [])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_float)
    self.evaluate(out)

  def test_assert_shapes_sparse_tensor_nonscalar_target_fail(self):
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))
    with self.assertRaisesRegexp(ValueError,
                                 r"must have rank 2.*Received rank 0"):
      assertion = check_ops.assert_shapes([(sparse_float, [None, None])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_float)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_success(self):
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_float, [500])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_float)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_fail(self):
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))
    with self.assertRaisesRegexp(ValueError, r"dimension 0 must have size 499"):
      assertion = check_ops.assert_shapes([(sparse_float, [499])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_float)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_success(self):
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_int, [None, 40])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_int)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_symbolic_match_success(self):
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6, 7], [8, 9, 10]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 30, 40], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_int, ["N", "N", "D"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_int)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_fail(self):
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 41"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, 41])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_rank_fail(self):
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    with self.assertRaisesRegexp(ValueError,
                                 r"must have rank 3\..* Received rank 2"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, None, 40])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_symbolic_match_fail(self):
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_int, ["D", "D"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_success(self):
    sparse_scalar = sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))
    sparse_2d = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 30], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_2d)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_fail(self):
    sparse_scalar = sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))
    sparse_2d = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_2d)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_success(self):
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 30], dtypes.int64))
    assertion = check_ops.assert_shapes([(dense_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_2d)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_fail(self):
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 40], dtypes.int64))
    with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(dense_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_2d)
      self.evaluate(out)


class IsStrictlyIncreasingTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self.evaluate(check_ops.is_strictly_increasing([1, 1, 1])))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self.evaluate(
        check_ops.is_strictly_increasing([1, 0, -1])))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(
        self.evaluate(check_ops.is_strictly_increasing([[1, 3], [2, 4]])))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_tensor_is_increasing(self):
    self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1, 2, 3])))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(
        self.evaluate(check_ops.is_strictly_increasing([[-1, 2], [3, 4]])))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_strictly_increasing(self):
    self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1])))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_strictly_increasing(self):
    self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([])))


class IsNonDecreasingTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_non_decreasing(self):
    self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 1, 1])))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self.evaluate(check_ops.is_non_decreasing([3, 2, 1])))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self.evaluate(
        check_ops.is_non_decreasing([[1, 3], [2, 4]])))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_one_tensor_is_non_decreasing(self):
    self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 2, 3])))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self.evaluate(
        check_ops.is_non_decreasing([[-1, 2], [3, 3]])))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_non_decreasing(self):
    self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1])))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_non_decreasing(self):
    self.assertTrue(self.evaluate(check_ops.is_non_decreasing([])))


class FloatDTypeTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_assert_same_float_dtype(self):
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype(None, None))
    self.assertIs(dtypes.float32, check_ops.assert_same_float_dtype([], None))
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype([], dtypes.float32))
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype(None, dtypes.float32))
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype([None, None], None))
    self.assertIs(
        dtypes.float32,
        check_ops.assert_same_float_dtype([None, None], dtypes.float32))

    const_float = constant_op.constant(3.0, dtype=dtypes.float32)
    self.assertIs(
        dtypes.float32,
        check_ops.assert_same_float_dtype([const_float], dtypes.float32))
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [const_float], dtypes.int32)

    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype([sparse_float],
                                                    dtypes.float32))
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [sparse_float], dtypes.int32)
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [const_float, None, sparse_float], dtypes.float64)

    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype(
                      [const_float, sparse_float]))
    self.assertIs(dtypes.float32,
                  check_ops.assert_same_float_dtype(
                      [const_float, sparse_float], dtypes.float32))

    const_int = constant_op.constant(3, dtype=dtypes.int32)
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [sparse_float, const_int])
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [sparse_float, const_int], dtypes.int32)
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [sparse_float, const_int], dtypes.float32)
    self.assertRaises(ValueError, check_ops.assert_same_float_dtype,
                      [const_int])


class AssertScalarTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_assert_scalar(self):
    check_ops.assert_scalar(constant_op.constant(3))
    check_ops.assert_scalar(constant_op.constant("foo"))
    check_ops.assert_scalar(3)
    check_ops.assert_scalar("foo")
    with self.assertRaisesRegex(ValueError, "Expected scalar"):
      check_ops.assert_scalar(constant_op.constant([3, 4]))


if __name__ == "__main__":
  test.main()