tensorflow/tensorflow

View on GitHub
tensorflow/python/ops/special_math_ops_test.py

Summary

Maintainability
F
1 mo
Test Coverage
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.ops.special_math_ops."""

from absl.testing import parameterized

import numpy as np
import opt_einsum

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging

class LBetaTest(test.TestCase):
  """Tests the values and shapes produced by `special_math_ops.lbeta`.

  Most value checks exponentiate the result and compare it against the Beta
  value worked out in the accompanying comments.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_one_dimensional_arg(self):
    """A rank-1 input gives the expected value and a scalar static shape."""
    # Should evaluate to 1 and 1/2.
    x_one = [1, 1.]
    x_one_half = [2, 1.]
    with self.session():
      self.assertAllClose(
          1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one))))
      self.assertAllClose(
          0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      self.assertEqual([], special_math_ops.lbeta(x_one).get_shape())

  @test_util.run_deprecated_v1
  def test_one_dimensional_arg_dynamic(self):
    """Same values as above when fed through a placeholder with no shape."""
    # Should evaluate to 1 and 1/2.
    x_one = [1, 1.]
    x_one_half = [2, 1.]
    with self.session():
      ph = array_ops.placeholder(dtypes.float32)
      beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
      self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one}))
      self.assertAllClose(0.5,
                          beta_ph.eval(feed_dict={ph: x_one_half}))

  @test_util.run_deprecated_v1
  def test_four_dimensional_arg_with_partial_shape_dynamic(self):
    """Rank-4 input fed to a placeholder whose last dimension is unknown."""
    x_ = np.ones((3, 2, 3, 4))
    # Gamma(1) = 0! = 1
    # Gamma(1 + 1 + 1 + 1) = Gamma(4) = 3! = 6
    # ==> Beta([1, 1, 1, 1])
    #     = Gamma(1) * Gamma(1) * Gamma(1) * Gamma(1) / Gamma(1 + 1 + 1 + 1)
    #     = 1 / 6
    expected_beta_x = 1 / 6 * np.ones((3, 2, 3))
    with self.session():
      x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None])
      beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph))
      self.assertAllClose(expected_beta_x,
                          beta_ph.eval(feed_dict={x_ph: x_}))

  @test_util.run_in_graph_and_eager_modes
  def test_two_dimensional_arg(self):
    """A rank-2 input gives one value per row and a static shape of (2,)."""
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      self.assertAllClose(
          [0.5, 0.5],
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape())

  @test_util.run_deprecated_v1
  def test_two_dimensional_arg_dynamic(self):
    """Rank-2 values when fed through a placeholder with no shape."""
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      ph = array_ops.placeholder(dtypes.float32)
      beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
      self.assertAllClose([0.5, 0.5],
                          beta_ph.eval(feed_dict={ph: x_one_half}))

  @test_util.run_in_graph_and_eager_modes
  def test_two_dimensional_proper_shape(self):
    """Value, evaluated shape and static shape all agree for rank-2 input."""
    # Should evaluate to 1/2.
    x_one_half = [[2, 1.], [2, 1.]]
    with self.session():
      self.assertAllClose(
          [0.5, 0.5],
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
      # The evaluated shape is an ndarray, so compare it element-wise instead
      # of relying on the truth value of `tuple == ndarray`, which only works
      # for single-element arrays.
      self.assertAllEqual(
          (2,),
          self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half))))
      self.assertEqual(
          tensor_shape.TensorShape([2]),
          special_math_ops.lbeta(x_one_half).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_complicated_shape(self):
    """A (3, 2, 2) input produces a (3, 2) result, dynamically and statically."""
    with self.session():
      x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
      self.assertAllEqual(
          (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x))))
      self.assertEqual(
          tensor_shape.TensorShape([3, 2]),
          special_math_ops.lbeta(x).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_length_1_last_dimension_results_in_one(self):
    """A single coefficient always exponentiates to one."""
    # If there is only one coefficient, the formula still works, and we get one
    # as the answer, always.
    x_a = [5.5]
    x_b = [0.1]
    with self.session():
      self.assertAllClose(
          1,
          self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a))),
          rtol=3e-6)
      self.assertAllClose(
          1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b))))
      self.assertEqual((), special_math_ops.lbeta(x_a).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank1_returns_negative_infinity(self):
    """An empty rank-1 input gives a scalar -inf."""
    with self.session():
      x = constant_op.constant([], shape=[0])
      lbeta_x = special_math_ops.lbeta(x)
      expected_result = constant_op.constant(-np.inf, shape=())

      self.assertAllEqual(self.evaluate(expected_result),
                          self.evaluate(lbeta_x))
      self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self):
    """A zero-sized last dimension gives -inf for every batch entry."""
    with self.session():
      event_size = 0
      for batch_size in [0, 1, 2]:
        x = constant_op.constant([], shape=[batch_size, event_size])
        lbeta_x = special_math_ops.lbeta(x)
        expected_result = constant_op.constant(-np.inf, shape=[batch_size])

        self.assertAllEqual(self.evaluate(expected_result),
                            self.evaluate(lbeta_x))
        self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def test_empty_rank2_with_zero_batch_dim_returns_empty(self):
    """A zero-sized batch dimension gives an empty result of shape [0]."""
    with self.session():
      batch_size = 0
      for event_size in [0, 1, 2]:
        x = constant_op.constant([], shape=[batch_size, event_size])
        lbeta_x = special_math_ops.lbeta(x)

        expected_result = constant_op.constant([], shape=[batch_size])

        self.assertAllEqual(self.evaluate(expected_result),
                            self.evaluate(lbeta_x))
        self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())


@test_util.run_all_in_graph_and_eager_modes
class DawsnTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.dawsn`.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.
  """

  def test_dawsn_boundary(self):
    """dawsn(0) is 0 and a NaN input gives NaN."""
    self.assertAllClose(0., special_math_ops.dawsn(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.dawsn(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_odd(self, dtype):
    """dawsn(x) equals -dawsn(-x) on random inputs."""
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.dawsn(x)),
        self.evaluate(-special_math_ops.dawsn(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_small(self, dtype):
    """Matches `scipy.special.dawsn` for small x; only warns without scipy."""
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    # Only the import is guarded, so an ImportError raised while evaluating the
    # op fails the test instead of being logged and ignored.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.dawsn(x), self.evaluate(special_math_ops.dawsn(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_dawsn_larger(self, dtype):
    """Matches `scipy.special.dawsn` for larger x; only warns without scipy."""
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    # Only the import is guarded; see test_dawsn_small.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.dawsn(x), self.evaluate(special_math_ops.dawsn(x)))

  def test_dawsn_gradient(self):
    """Analytical and numerical gradients agree to within 1e-4."""
    inputs = [np.random.uniform(-50., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.dawsn, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)


@test_util.run_all_in_graph_and_eager_modes
class ExpintTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.expint`.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.
  """

  def test_expint_boundary(self):
    """expint(0) is -inf; NaN and negative inputs give NaN."""
    self.assertAllClose(-np.inf, special_math_ops.expint(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.expint(np.nan))))
    # Check that the domain of definition is [0, inf)
    negative_x = np.random.uniform(-20., -1., size=int(1e3))
    self.assertTrue(
        np.all(np.isnan(self.evaluate(special_math_ops.expint(negative_x)))))

  @parameterized.parameters(np.float32, np.float64)
  def test_expint_small(self, dtype):
    """Matches `scipy.special.expi` for small x; only warns without scipy."""
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    # Only the import is guarded, so an ImportError raised while evaluating the
    # op fails the test instead of being logged and ignored.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.expi(x), self.evaluate(special_math_ops.expint(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_expint_larger(self, dtype):
    """Matches `scipy.special.expi` for larger x; only warns without scipy."""
    x = np.random.uniform(1., 50., size=int(1e4)).astype(dtype)
    # Only the import is guarded; see test_expint_small.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.expi(x), self.evaluate(special_math_ops.expint(x)))

  def test_expint_gradient(self):
    """Analytical and numerical gradients agree to within 5e-3."""
    inputs = [np.random.uniform(1., 10., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.expint, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)


@test_util.run_all_in_graph_and_eager_modes
class FresnelCosTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.fresnel_cos`.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.
  """

  def test_fresnel_cos_boundary(self):
    """fresnel_cos(0) is 0 and a NaN input gives NaN."""
    self.assertAllClose(0., special_math_ops.fresnel_cos(0.))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.fresnel_cos(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_odd(self, dtype):
    """fresnel_cos(x) equals -fresnel_cos(-x) on random inputs."""
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.fresnel_cos(x)),
        self.evaluate(-special_math_ops.fresnel_cos(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_small(self, dtype):
    """Matches `scipy.special.fresnel(x)[1]` for small x; warns without scipy."""
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    # Only the import is guarded, so an ImportError raised while evaluating the
    # op fails the test instead of being logged and ignored.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.fresnel(x)[1], self.evaluate(special_math_ops.fresnel_cos(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_cos_larger(self, dtype):
    """Matches `scipy.special.fresnel(x)[1]` for larger x with rtol=1e-5."""
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    # Only the import is guarded; see test_fresnel_cos_small.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.fresnel(x)[1],
          self.evaluate(special_math_ops.fresnel_cos(x)),
          rtol=1e-5)

  def test_fresnel_cos_gradient(self):
    """Analytical and numerical gradients agree to within 5e-3."""
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.fresnel_cos, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)


@test_util.run_all_in_graph_and_eager_modes
class FresnelSinTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.fresnel_sin`.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.
  """

  def test_fresnel_sin_boundary(self):
    """fresnel_sin(0) is 0 and a NaN input gives NaN."""
    self.assertAllClose(0., special_math_ops.fresnel_sin(0.))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.fresnel_sin(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_odd(self, dtype):
    """fresnel_sin(x) equals -fresnel_sin(-x) on random inputs."""
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.fresnel_sin(x)),
        self.evaluate(-special_math_ops.fresnel_sin(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_small(self, dtype):
    """Matches `scipy.special.fresnel(x)[0]` for small x; warns without scipy."""
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    # Only the import is guarded, so an ImportError raised while evaluating the
    # op fails the test instead of being logged and ignored.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.fresnel(x)[0], self.evaluate(special_math_ops.fresnel_sin(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_fresnel_sin_larger(self, dtype):
    """Matches `scipy.special.fresnel(x)[0]` for larger x with rtol=1e-5."""
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    # Only the import is guarded; see test_fresnel_sin_small.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.fresnel(x)[0],
          self.evaluate(special_math_ops.fresnel_sin(x)),
          rtol=1e-5)

  def test_fresnel_sin_gradient(self):
    """Analytical and numerical gradients agree to within 5e-3."""
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.fresnel_sin, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 5e-3)


@test_util.run_all_in_graph_and_eager_modes
class SpenceTest(test.TestCase, parameterized.TestCase):
  """Tests for `special_math_ops.spence`.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.
  """

  def test_spence_boundary(self):
    """spence(0) is pi^2/6, spence(1) is 0; NaN and negatives give NaN."""
    self.assertAllClose(np.pi**2 / 6., special_math_ops.spence(0.))
    self.assertAllClose(0., special_math_ops.spence(1.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.spence(np.nan))))
    # Check that the domain of definition is [0, inf)
    negative_x = np.random.uniform(-20., -1., size=int(1e3))
    self.assertTrue(
        np.all(np.isnan(self.evaluate(special_math_ops.spence(negative_x)))))

  @parameterized.parameters(np.float32, np.float64)
  def test_spence_small(self, dtype):
    """Matches `scipy.special.spence` for small x; only warns without scipy."""
    x = np.random.uniform(0., 1., size=int(1e4)).astype(dtype)
    # Only the import is guarded, so an ImportError raised while evaluating the
    # op fails the test instead of being logged and ignored.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.spence(x), self.evaluate(special_math_ops.spence(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_spence_larger(self, dtype):
    """Matches `scipy.special.spence` for larger x; only warns without scipy."""
    x = np.random.uniform(1., 100., size=int(1e4)).astype(dtype)
    # Only the import is guarded; see test_spence_small.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.spence(x), self.evaluate(special_math_ops.spence(x)))

  def test_spence_gradient(self):
    """Analytical and numerical gradients agree to within 1e-4."""
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.spence, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_spence_gradient_at_one(self):
    """The analytical gradient at x = 1 is -1."""
    analytical, _ = gradient_checker_v2.compute_gradient(
        special_math_ops.spence, [1.])
    self.assertAllClose([[[-1.]]], analytical)


@test_util.run_all_in_graph_and_eager_modes
class BesselTest(test.TestCase, parameterized.TestCase):
  """Tests for the `special_math_ops.bessel_*` ops.

  The class decorator already runs every test method in graph and eager
  modes, so individual methods carry no mode decorator of their own.

  In the scipy comparison tests only the scipy import is guarded: when scipy
  is missing a warning is logged and the comparison is not run, while an
  ImportError raised during op evaluation still fails the test.
  """

  def test_besseli_boundary(self):
    """i0/i0e are 1 and i1/i1e are 0 at x = 0; NaN inputs give NaN."""
    self.assertAllClose(1., special_math_ops.bessel_i0(0.))
    self.assertAllClose(1., special_math_ops.bessel_i0e(0.))
    self.assertAllClose(0., special_math_ops.bessel_i1(0.))
    self.assertAllClose(0., special_math_ops.bessel_i1e(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i0(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_i0e(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_i1(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_i1e(np.nan))))

  def test_besselj_boundary(self):
    """j0 is 1 and j1 is 0 at x = 0; NaN inputs give NaN."""
    self.assertAllClose(1., special_math_ops.bessel_j0(0.))
    self.assertAllClose(0., special_math_ops.bessel_j1(0.))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j0(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_j1(np.nan))))

  def test_besselk_boundary(self):
    """k0/k0e/k1/k1e are infinite at x = 0; NaN inputs give NaN."""
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k0e(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1(0.))))
    self.assertTrue(np.isinf(self.evaluate(special_math_ops.bessel_k1e(0.))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k0(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_k0e(np.nan))))
    self.assertTrue(np.isnan(self.evaluate(special_math_ops.bessel_k1(np.nan))))
    self.assertTrue(
        np.isnan(self.evaluate(special_math_ops.bessel_k1e(np.nan))))

  @parameterized.parameters(np.float32, np.float64)
  def test_i0j0_even(self, dtype):
    """i0, i0e and j0 give the same value at x and -x."""
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i0(x)),
        self.evaluate(special_math_ops.bessel_i0(-x)))

    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i0e(x)),
        self.evaluate(special_math_ops.bessel_i0e(-x)))

    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_j0(x)),
        self.evaluate(special_math_ops.bessel_j0(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_i1j1_odd(self, dtype):
    """i1, i1e and j1 satisfy f(x) == -f(-x)."""
    x = np.random.uniform(-100., 100., size=int(1e4)).astype(dtype)
    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i1(x)),
        self.evaluate(-special_math_ops.bessel_i1(-x)))

    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_i1e(x)),
        self.evaluate(-special_math_ops.bessel_i1e(-x)))

    self.assertAllClose(
        self.evaluate(special_math_ops.bessel_j1(x)),
        self.evaluate(-special_math_ops.bessel_j1(-x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besseli_small(self, dtype):
    """i0, i1, i0e and i1e match scipy for small x."""
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.i0(x), self.evaluate(special_math_ops.bessel_i0(x)))
      self.assertAllClose(
          special.i1(x), self.evaluate(special_math_ops.bessel_i1(x)))
      self.assertAllClose(
          special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
      self.assertAllClose(
          special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselj_small(self, dtype):
    """j0 and j1 match scipy for small x."""
    x = np.random.uniform(-1., 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
      self.assertAllClose(
          special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselk_small(self, dtype):
    """k0, k0e, k1 and k1e match scipy for small positive x."""
    # The lower bound is the dtype's machine epsilon rather than 0.
    x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
      self.assertAllClose(
          special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
      self.assertAllClose(
          special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
      self.assertAllClose(
          special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_bessely_small(self, dtype):
    """y0 and y1 match scipy for small positive x."""
    # The lower bound is the dtype's machine epsilon rather than 0.
    x = np.random.uniform(np.finfo(dtype).eps, 1., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
      self.assertAllClose(
          special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besseli_larger(self, dtype):
    """i0e and i1e match scipy for larger x."""
    x = np.random.uniform(1., 20., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.i0e(x), self.evaluate(special_math_ops.bessel_i0e(x)))
      self.assertAllClose(
          special.i1e(x), self.evaluate(special_math_ops.bessel_i1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselj_larger(self, dtype):
    """j0 and j1 match scipy for larger x."""
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.j0(x), self.evaluate(special_math_ops.bessel_j0(x)))
      self.assertAllClose(
          special.j1(x), self.evaluate(special_math_ops.bessel_j1(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_besselk_larger(self, dtype):
    """k0, k0e, k1 and k1e match scipy for larger x."""
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.k0(x), self.evaluate(special_math_ops.bessel_k0(x)))
      self.assertAllClose(
          special.k0e(x), self.evaluate(special_math_ops.bessel_k0e(x)))
      self.assertAllClose(
          special.k1(x), self.evaluate(special_math_ops.bessel_k1(x)))
      self.assertAllClose(
          special.k1e(x), self.evaluate(special_math_ops.bessel_k1e(x)))

  @parameterized.parameters(np.float32, np.float64)
  def test_bessely_larger(self, dtype):
    """y0 and y1 match scipy for larger x."""
    x = np.random.uniform(1., 30., size=int(1e4)).astype(dtype)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s', e)
    else:
      self.assertAllClose(
          special.y0(x), self.evaluate(special_math_ops.bessel_y0(x)))
      self.assertAllClose(
          special.y1(x), self.evaluate(special_math_ops.bessel_y1(x)))

  def test_besseli_gradient(self):
    """Gradient check for i0/i1 (tol 1e-3) and i0e/i1e (tol 1e-4)."""
    inputs = [np.random.uniform(-10., 10., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i0e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-3)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_i1e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_besselj_gradient(self):
    """Gradient check for j0 and j1 (tol 1e-4)."""
    inputs = [np.random.uniform(-50., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_j0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_j1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_besselk_gradient(self):
    """Gradient check for k0, k0e, k1 and k1e (tol 1e-4)."""
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k0e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_k1e, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def test_bessely_gradient(self):
    """Gradient check for y0 and y1 (tol 1e-4)."""
    inputs = [np.random.uniform(1., 50., size=int(1e2))]
    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_y0, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)

    analytical, numerical = gradient_checker_v2.compute_gradient(
        special_math_ops.bessel_y1, inputs)
    self.assertLess(gradient_checker_v2.max_error(analytical, numerical), 1e-4)


@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32(
    'Tests einsum, which sometimes does a matmul with cuBLAS')
class EinsumTest(test.TestCase):
  """Tests for `special_math_ops.einsum`.

  Most cases evaluate the TF op on seeded random inputs and compare the result
  against `np.einsum` (see `_check`).
  """

  # Dimension size per subscript label used by the long-equation tests:
  # 'a' -> 1, 'b' -> 2, ..., 'j' -> 10.
  _LABEL_DIMS = {c: ord(c) - ord('a') + 1 for c in 'abcdefghij'}

  def _check(self, s, *input_shapes, **kwargs):
    """Asserts that TF einsum matches `np.einsum` on seeded random inputs.

    Args:
      s: The einsum equation string.
      *input_shapes: One shape tuple per operand of the equation.
      **kwargs: Only `dtype` is accepted: the NumPy dtype of the generated
        operands (default `np.float32`). For `np.complex64`/`np.complex128` a
        random imaginary part is added to each operand.

    Raises:
      TypeError: If a keyword argument other than `dtype` is passed.
    """
    dtype = kwargs.pop('dtype', np.float32)
    if kwargs:
      # Reject unknown keywords so that a typo such as `dtpye=` cannot silently
      # fall back to the float32 default.
      raise TypeError(
          'Unexpected keyword arguments: {}'.format(sorted(kwargs)))
    r = np.random.RandomState(0)
    inputs = []
    for shape in input_shapes:
      # np.array() is needed because randn() returns a Python float for
      # shape ().
      arr = np.array(r.randn(*shape)).astype(dtype)
      if dtype in (np.complex64, np.complex128):
        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
      inputs.append(arr)
    input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
    a = np.einsum(s, *inputs)
    b = self.evaluate(special_math_ops.einsum(s, *input_tensors))
    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

  def _input_shapes(self, equation):
    """Returns one shape tuple per operand of `equation`.

    Whitespace in the input subscripts is ignored, and the size of each axis is
    looked up by its label in `_LABEL_DIMS`.

    Args:
      equation: An einsum equation whose labels are all keys of `_LABEL_DIMS`
        (no ellipsis).

    Returns:
      A list of shape tuples, in operand order.
    """
    input_subscripts = equation.split('->')[0].replace(' ', '')
    return [
        tuple(self._LABEL_DIMS[label] for label in subscript)
        for subscript in input_subscripts.split(',')
    ]

  def test_invalid_keyword_arguments(self):
    """Unknown keyword arguments to einsum must raise TypeError."""
    r = np.random.RandomState(0)
    a = array_ops.placeholder_with_default(r.randn(2, 3), shape=(2, 3))
    b = array_ops.placeholder_with_default(r.randn(3, 4), shape=(3, 4))
    with self.assertRaises(TypeError):
      _ = special_math_ops.einsum(
          'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2')

  def test_unary(self):
    self._check('a', (3,))
    self._check('aa', (3, 3))
    self._check('ab->', (3, 3))
    self._check('ab->ab', (3, 3))
    self._check('abc->b', (3, 4, 5))
    self._check('abc->ca', (3, 4, 5))
    self._check('abc->cab', (3, 4, 5))

    # Empty cases.
    self._check('', ())
    self._check('->', ())

    # Repeated indices cases.
    self._check('aa->', (3, 3))
    self._check('aa->a', (3, 3))
    self._check('aaa->', (3, 3, 3))
    self._check('aaa->a', (3, 3, 3))
    self._check('aab->a', (3, 3, 4))
    self._check('aabcc->a', (3, 3, 5, 4, 4))
    self._check('aabcc->ac', (3, 3, 5, 4, 4))
    self._check('aabcd->ad', (3, 3, 5, 4, 4))

  def test_unary_ellipsis(self):
    self._check('...->', ())
    self._check('...ijk->...ki', (3, 4, 5))
    self._check('...ijk->...ki', (1, 3, 4, 5))
    self._check('...ijk->...ki', (2, 2, 3, 4, 5))
    self._check('...ij->...ji', (5, 2, 3))  # batch matrix transpose
    self._check('...ij->...', (5, 2, 3))  # batch sum

    self._check('...->...', ())
    self._check('->...', ())

    # Repeated indices.
    self._check('i...ii->...i', (3, 2, 3, 3))
    self._check('i...i->i...', (2, 2))
    self._check('i...i->', (2, 2))
    self._check('i...i->...', (2, 5, 1, 2))
    self._check('i...i->i...', (2, 1, 2))
    self._check('i...i->i...', (2, 3, 4, 5, 2))

  def test_binary_simple(self):
    # Binary cases in XLA mode must have either (a) each index appearing exactly
    # once in both the inputs (batch or contraction index), or (b) appearing
    # exactly once in an input and in the output (free index).
    self._check(',->', (), ())
    self._check('a,a->', (3,), (3,))
    self._check('a,a->a', (3,), (3,))
    self._check('ab,b->a', (3, 4), (4,))
    self._check('ab,ab->', (3, 4), (3, 4))
    self._check('ab,bc->ac', (3, 4), (4, 5))
    self._check('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check('sa,shb->shab', (2, 1), (2, 3, 4))
    # Infer the output subscripts.
    self._check('ab,b', (3, 4), (4,))
    self._check('cab,b', (1, 3, 4), (4,))

  def test_reduced_indices(self):
    self._check('ba,b->', (3, 2), (3,))
    self._check('ab,ab->', (3, 4), (3, 4))

  def test_repeated_indices(self):
    # Repeated indices.
    self._check('ijj,k->ik', (2, 3, 3), (4,))
    self._check('aba,a->b', (3, 4, 3), (3,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check('aab,bc->ac', (2, 2, 3), (3, 4))
    self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

  def test_binary_ellipsis(self):
    # Batch matmul with ellipsis but without broadcasting.
    self._check('...mk,...kn->...mn', (5, 1, 2, 3), (5, 1, 3, 4))
    # Empty batch dimensions.
    self._check('...mk,...kn->...mn', (2, 3), (3, 4))
    # Tensor contraction with transpose.
    self._check('...ija,aijb...->ba...ij', (1, 2, 2, 3, 1), (1, 2, 3, 4, 1, 2))
    # Output subscripts may omit ellipsis when batch shape is empty.
    self._check('...mk,...kn->mn', (2, 3), (3, 4))
    self._check('...mk,kn->mn', (2, 3), (3, 4))
    self._check('mk,...kn->mn', (2, 3), (3, 4))
    self._check('...,...->...', (2, 3), (2, 3))  # hadamard product
    self._check('...i,...j->...ij', (5, 2), (5, 3))  # outer product

  def test_broadcasting(self):
    # Batch matmul with broadcasting.
    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))

    # Broadcasting with repeated indices.
    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
    # Following 2 from https://stackoverflow.com/a/19203475/1611416
    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))

  def test_dtypes(self):
    # Named `np_dtypes` so that the imported `dtypes` module is not shadowed.
    np_dtypes = [np.float64, np.float32, np.complex64, np.complex128]
    for dtype in np_dtypes:
      self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)

  def test_multiple_inputs(self):
    self._check('ijk,ijl,ikl->i', (1, 2, 3), (1, 2, 4), (1, 3, 4))
    self._check('i,ijk,j->k', (1,), (1, 2, 4), (2,))
    self._check('ij,ij,jk,kl->il', (1, 2), (1, 2), (2, 3), (3, 4))
    # Tests from dask.
    self._check('a,b,c', (5,), (7,), (9,))
    self._check('ab,ab,c->c', (5, 6), (5, 6), (2,))

  @test_util.disable_xla('b/131919749')
  def test_placeholder(self):
    """Checks einsum on placeholders with partially or fully unknown shapes."""

    def check(equation, *input_and_placeholder_shapes):
      # Each argument is an (actual_shape, placeholder_shape) pair: the data is
      # generated with `actual_shape`, while the placeholder only advertises
      # `placeholder_shape` as its static shape.
      r = np.random.RandomState(0)
      inputs = []
      input_placeholders = []
      for actual_shape, placeholder_shape in input_and_placeholder_shapes:
        input_np = np.array(r.randn(*actual_shape))
        inputs.append(input_np)
        input_placeholders.append(
            array_ops.placeholder_with_default(input_np, placeholder_shape))

      a = np.einsum(equation, *inputs)
      b = self.evaluate(special_math_ops.einsum(equation, *input_placeholders))
      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

    check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
          ((9, 3, 4, 7), (None, None, 4, None)))
    check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
          ((4, 3), (None, 3)))

    # Ellipsis with unknown rank.
    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))

  def test_numpy_input(self):
    # In addition to Tensors, we also support raw numpy arrays as inputs.
    r = np.random.RandomState(0)
    s = 'ijk,ijl,ikl->i'
    x = r.randn(1, 2, 3)
    y = r.randn(1, 2, 4)
    z = r.randn(1, 3, 4)

    a = np.einsum(s, x, y, z)
    b = self.evaluate(special_math_ops.einsum(s, x, y, z))
    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

  def test_long_cases(self):
    cases = [
        'efc,dbc,acf,fd->abe',
        'ea,fb,gc,hd,abcd->efgh',
        'abhe,hidj,jgba,hiab,gab->ed',
        # Cases with whitespace.
        'efc, dbc, acf, fd -> abe',
        'abhe, hidj, jgba, hiab, gab',
        # Repeated equations for cache hit on the opt_einsum call.
        'ea,fb,abcd,gc,hd->efgh',
        'ea,fb,abcd,gc,hd->efgh',
    ]
    for equation in cases:
      self._check(equation, *self._input_shapes(equation))

  def test_opt_einsum_cached(self):
    # Checks call_count to opt_einsum which are only reflected in eager mode.
    if not context.executing_eagerly():
      return

    input_1 = ('ijk,ijl,ikl->i', (1, 2, 3), (1, 2, 4), (1, 3, 4))
    input_2 = ('ij,ij,jk,kl->il', (1, 2), (1, 2), (2, 3), (3, 4))

    with test.mock.patch.object(
        opt_einsum, 'contract_path',
        wraps=opt_einsum.contract_path) as mock_contract_path:

      # explicitly clear the lru_cache contents for the method
      #   special_math_ops._get_opt_einsum_contract_path
      # We need to do this because other tests in this file invoke that method
      # with the same input args (as input_1 and input_2 above), and if
      # those tests run before this test, then the call_count for the method
      # mock_contract_path will not increment.
      special_math_ops._get_opt_einsum_contract_path.cache_clear()

      self.assertEqual(mock_contract_path.call_count, 0)
      self._check(*input_1)
      self.assertEqual(mock_contract_path.call_count, 1)
      # The same input results in no extra call if we're caching the
      # opt_einsum.contract_path call. We only cache in Python3.
      self._check(*input_1)
      self.assertEqual(mock_contract_path.call_count, 1)
      # New input results in another call to opt_einsum.
      self._check(*input_2)
      self.assertEqual(mock_contract_path.call_count, 2)
      # No more extra calls as the inputs should be cached.
      self._check(*input_1)
      self._check(*input_2)
      self._check(*input_1)
      self.assertEqual(mock_contract_path.call_count, 2)

  @test_util.disable_xla('b/131919749')
  def test_long_cases_with_repeated_labels(self):
    cases = [
        # Tests from dask.
        'fdf,cdd,ccd,afe->ae',
        'fff,fae,bef,def->abd',
    ]
    for equation in cases:
      self._check(equation, *self._input_shapes(equation))

  @test_util.disable_xla('b/131919749')
  @test_util.run_in_graph_and_eager_modes
  def test_invalid_equation(self):
    """Malformed equations or mismatched operands must raise an error.

    Every case is tried twice: once with the raw NumPy operands and once with
    placeholders of unknown shape.
    """
    r = np.random.RandomState(0)
    cases = [
        # invalid equation format.
        ('a0->a', r.randn(5, 3)),
        ('a->a,a', r.randn(5)),
        ('a->a->a', r.randn(5)),
        ('ijk ijk', r.randn(1, 2, 3), r.randn(1, 2, 3)),
        ('ij.jk->ik', r.randn(2, 3), r.randn(3, 4)),
        # output label not present in input.
        ('a->b', r.randn(5)),
        ('ij,jk->im', r.randn(2, 3), r.randn(3, 4)),
        # wrong shape.
        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
        # inconsistent dimensions.
        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
        # output has repeated subscripts.
        ('ij,jk->iik', r.randn(2, 3), r.randn(3, 4)),
        # too many ellipses
        ('...ij...,jk...->ik...', r.randn(2, 3), r.randn(3, 4)),
        ('...ij,jk...->...ik...', r.randn(2, 3), r.randn(3, 4)),
        # invalid broadcast dimensions.
        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
        # output should have ellipsis when broadcasting shape is non-empty.
        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
    ]
    for args in cases:
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = special_math_ops.einsum(*args)

      placeholders = [
          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
      ]
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = self.evaluate(special_math_ops.einsum(args[0], *placeholders))

  @test_util.disable_xla('b/131919749')
  def test_empty(self):
    """Checks einsum on operands that have zero-sized dimensions."""

    def check(equation, input_shapes, output_shape):
      # All these cases result in an output filled with zeros, so we don't call
      # np.einsum. Also np.einsum doesn't support generalized diagonals which
      # are needed for EinsumOp gradients.
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape)) for shape in input_shapes]
      input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
      output = self.evaluate(special_math_ops.einsum(equation, *input_tensors))
      self.assertAllClose(output, np.zeros(output_shape), atol=1e-4, rtol=1e-4)

    # Contractions along zero-sized dimensions.
    check('ab,bc->ac', [(0, 10), (10, 10)], (0, 10))
    # From transformer xl.
    check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))

    # Generalized traces with zero-sized dimensions.
    check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
    check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))


@test_util.run_all_in_graph_and_eager_modes
class EinsumGradTest(test.TestCase):
  """Gradient tests for `special_math_ops.einsum`.

  Every case compares the analytical and numerical Jacobians returned by
  `gradient_checker_v2.compute_gradient` (see `_check_gradient`).
  """

  # Dimension size per subscript label used by the long-equation tests. Sizes
  # cycle through 1, 2, 3: 'a' -> 1, 'b' -> 2, 'c' -> 3, 'd' -> 1, ...
  _LABEL_DIMS = {c: ((ord(c) - ord('a')) % 3) + 1 for c in 'abcdefghij'}

  def _check_gradient(self, s, *input_shapes):
    """Asserts that einsum's analytical and numerical gradients agree.

    Args:
      s: The einsum equation string.
      *input_shapes: One shape tuple per operand; operands are filled with
        seeded standard-normal values.
    """
    with self.cached_session():
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape)) for shape in input_shapes]
      input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
      analytical, numerical = gradient_checker_v2.compute_gradient(
          lambda *xs: special_math_ops.einsum(s, *xs), input_tensors)
      self.assertLess(
          gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  def _input_shapes(self, equation):
    """Returns one shape tuple per operand of `equation`.

    Whitespace in the input subscripts is ignored, and the size of each axis is
    looked up by its label in `_LABEL_DIMS`.

    Args:
      equation: An einsum equation whose labels are all keys of `_LABEL_DIMS`
        (no ellipsis).

    Returns:
      A list of shape tuples, in operand order.
    """
    input_subscripts = equation.split('->')[0].replace(' ', '')
    return [
        tuple(self._LABEL_DIMS[label] for label in subscript)
        for subscript in input_subscripts.split(',')
    ]

  @test_util.disable_xla('b/131919749')
  def test_unary(self):
    self._check_gradient('->', ())
    self._check_gradient('aaa->a', (3, 3, 3))
    self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
    self._check_gradient('abcd->da', (3, 5, 4, 2))

  @test_util.disable_xla('b/131919749')
  def test_unary_ellipsis(self):
    self._check_gradient('...->...', ())
    self._check_gradient('...->', ())
    self._check_gradient('->...', ())

    # Tests from dask
    self._check_gradient('a...a->a...', (2, 2))
    self._check_gradient('a...a->', (2, 2))
    self._check_gradient('a...a->...', (2, 5, 1, 2))
    self._check_gradient('a...a->a...', (2, 1, 2))
    self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))

    self._check_gradient('...ijk->...ki', (3, 4, 5))
    self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
    self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
    self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))

  def test_binary_simple(self):
    # Binary cases in XLA mode must have either (a) each index appearing
    # exactly once in both the inputs (batch or contraction index), or
    # (b) appearing exactly once in an input and in the output (free index).
    self._check_gradient(',->', (), ())
    self._check_gradient('a,a->', (3,), (3,))
    self._check_gradient('a,a->a', (3,), (3,))
    self._check_gradient('ab,b->a', (3, 4), (4,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
    self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))

  def test_empty(self):
    # From Transformer XL.
    self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))

  @test_util.disable_xla('b/131919749')
  def test_reduced_indices(self):
    self._check_gradient('ba,b->', (3, 2), (3,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))

  @test_util.disable_xla('b/131919749')
  def test_repeated_indices(self):
    # Repeated indices.
    self._check_gradient('aba,a->b', (3, 4, 3), (3,))
    self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
    self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
    self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

  @test_util.disable_xla('b/131919749')
  def test_empty_with_repeated_indices(self):
    self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
    self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
    self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))

  @test_util.disable_xla('b/131919749')
  def test_broadcasting(self):
    self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
    self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
    self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
    self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
    # Tests from dask.
    self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
                         (1, 1, 1, 1, 9))
    self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))

  def test_long_cases(self):
    cases = [
        'abhe,hidj,jgba,hiab,gab->ed',
        # Tests from dask.
        'ea,fb,abcd,gc,hd->efgh',
    ]
    for equation in cases:
      self._check_gradient(equation, *self._input_shapes(equation))

  @test_util.disable_xla('b/131919749')
  def test_long_cases_with_repeated_labels(self):
    cases = [
        # Tests from dask.
        'fdf,cdd,ccd,afe->ae',
        'fff,fae,bef,def->abd',
    ]
    for equation in cases:
      self._check_gradient(equation, *self._input_shapes(equation))


class EinsumBenchmark(test.Benchmark):
  """Benchmarks `special_math_ops.einsum` on the CPU device."""

  # Each entry is [equation, dim]; every axis of every operand has size `dim`.
  cases = [
      # Unary cases.
      ['ijk->i', 100],
      ['ijk->kji', 100],
      # Regular matmul or batch matmul.
      ['ij,jk->ik', 500],
      ['ji,kj->ik', 500],
      ['bij,bjk->bik', 100],
      ['bji,bjk->bki', 100],
      ['ikl,kji->kl', 100],
      ['klj,lki->ij', 100],
      ['ijk,ilj->kli', 100],
      ['ijk,jklm->il', 50],
      # Larger binary contractions.
      ['efabc,eabcd->efd', 20],
      ['fabec,abcde->fde', 20],
      ['efabc,edabc->efd', 20],
      ['eadbf,dfebc->ecfad', 20],
      ['abcdef,bcdfg->abcdeg', 20],
      # Chain matmul.
      ['ij,jk,kl->il', 1000],
      # Long cases. Path optimization should kick in.
      ['ea,fb,abcd,gc,hd->efgh', 10],
      ['bca,cdb,dbf,afc->', 10],
      ['efc,dbc,acf,fd->abe', 10],
      ['abhe,hidj,jgba,hiab,gab->ed', 10],
  ]

  def benchmark_einsum(self):
    """Runs an op benchmark for every entry of `cases`.

    Each case gets a fresh graph and session. Equations with at most two
    operands are benchmarked once; longer ones are benchmarked once per
    `optimize` strategy ('greedy', then 'auto').
    """
    for equation, dim in self.cases:
      with ops.Graph().as_default():
        with session.Session(config=benchmark.benchmark_config()) as sess:
          with ops.device('/cpu:0'):
            rng = np.random.RandomState(0)
            operand_subscripts = equation.split('->')[0].split(',')
            # One float32 variable per operand, with `dim` along every axis.
            operands = [
                variables.Variable(
                    np.array(rng.randn(*((dim,) * len(labels))), np.float32))
                for labels in operand_subscripts
            ]
            self.evaluate(variables.global_variables_initializer())

            if len(operands) <= 2:
              self.run_op_benchmark(
                  sess,
                  special_math_ops.einsum(equation, *operands),
                  min_iters=50,
                  name='einsum_cpu_({})_{}'.format(equation, dim))
              continue

            for optimize in ['greedy', 'auto']:
              self.run_op_benchmark(
                  sess,
                  special_math_ops.einsum(
                      equation, *operands, optimize=optimize),
                  min_iters=50,
                  name='einsum_cpu_({})_{}_{}'.format(
                      equation, optimize, dim))


# Script entry point: when this file is executed directly, hand control to
# `test.main()` from tensorflow.python.platform.test.
if __name__ == '__main__':
  test.main()