tensorflow/tensorflow

View on GitHub
tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py

Summary

Maintainability
F
3 wks
Test Coverage
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR sparse matrix tests."""

from absl.testing import parameterized
import numpy as np
from scipy import sparse

from tensorflow.core.framework import tensor_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging

# TensorFlow device name strings for the first CPU and the first GPU.
CPU, GPU = "/device:CPU:0", "/device:GPU:0"


def dense_to_csr_sparse_matrix(dense):
  """Builds a CSRSparseMatrix from a dense input.

  Every element whose absolute value is strictly greater than zero becomes a
  stored entry of the result; exact zeros are left out of the sparsity
  structure.

  Args:
    dense: A Tensor, or anything `ops.convert_to_tensor` accepts.

  Returns:
    The CSRSparseMatrix variant tensor produced by
    `sparse_csr_matrix_ops.dense_to_csr_sparse_matrix`.
  """
  dense_tensor = ops.convert_to_tensor(dense)
  is_nonzero = math_ops.abs(dense_tensor) > 0
  # The coordinates only describe structure; keep them out of any gradient.
  nonzero_coords = array_ops.stop_gradient(array_ops.where(is_nonzero))
  return sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
      dense_tensor, nonzero_coords)


def _swap(a, i, j):
  a[i], a[j] = a[j], a[i]


def twist_matrix(matrix, permutation_indices):
  """Permutes the rows and columns of a matrix (or batch of matrices) together.

  Args:
    matrix: A rank-2 Tensor, or a rank-3 Tensor whose leading dimension indexes
      a batch of matrices. The static rank must be known.
    permutation_indices: For a rank-2 `matrix`, a permutation vector. For a
      rank-3 `matrix`, one permutation per batch entry (presumably shape
      `[batch_size, n]`; the batch size is read from `matrix.shape[0]`).

  Returns:
    A Tensor shaped like `matrix` in which entry `(i, j)` of each matrix has
    moved to `(perm[i], perm[j])`.

  Raises:
    ValueError: If the static rank of `matrix` is not 2 or 3.
  """
  rank = matrix.shape.ndims
  if rank not in (2, 3):
    raise ValueError("Input matrix must have rank 2 or 3. Got: {}".format(rank))

  if rank == 2:
    # `gather` computes out[k] = matrix[inverse[k]], so gathering with the
    # inverse permutation moves index i to permutation_indices[i].
    inverse = array_ops.invert_permutation(permutation_indices)
    for axis in (0, 1):
      matrix = array_ops.gather(matrix, inverse, axis=axis)
    return matrix

  # Batched case: invert each batch entry's permutation separately.
  inverse = map_fn.map_fn(array_ops.invert_permutation, permutation_indices)
  batch_size = matrix.shape[0]
  batch_ids = array_ops.broadcast_to(
      math_ops.range(batch_size)[:, None], permutation_indices.shape)
  gather_coords = array_ops_stack.stack([batch_ids, inverse], axis=-1)
  # gather_nd only permutes rows conveniently: permute the rows, swap the last
  # two axes, and repeat. The second swap restores the original orientation.
  for _ in range(2):
    matrix = array_ops.transpose(
        array_ops.gather_nd(matrix, gather_coords), perm=[0, 2, 1])
  return matrix


class CSRSparseMatrixOpsTest(test.TestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):  # pylint: disable=g-missing-super-call
    """Probes for a GPU once and caches the answer on the class.

    Tests that need a GPU check `self._gpu_available` and return early when it
    is False.
    """
    gpu_found = test_util.is_gpu_available()
    cls._gpu_available = gpu_found

  # TODO(ebrevdo): This will work once we find a way to get rendezvous
  # working for CSRSparseMatrix and can remove the HostMemory
  # annotations for the other ops.
  @test_util.run_in_graph_and_eager_modes
  def DISABLEDtestFromProto(self):
    """Decodes a hand-built CSRSparseMatrix variant proto into a dense tensor.

    Not collected by the test loader because the name does not start with
    `test`.
    """
    if not self._gpu_available:
      return

    indices = np.array([[0, 0], [2, 3]])
    values = np.asarray([1.0, 5.0], dtype=np.float32)
    dense_shape = np.asarray([5, 6], dtype=np.int64)
    csr = sparse.coo_matrix(
        (values, (indices[:, 0], indices[:, 1])), shape=dense_shape).tocsr()

    # Layout of a serialized CSRSparseMatrix variant:
    #   type_name: "tensorflow::CSRSparseMatrix"
    #   metadata:  serialized `validated` flag (True here)
    #   tensors:   [dense_shape, row_ptrs, col_indices, values]
    component_protos = [
        tensor_util.make_tensor_proto(component)
        for component in (dense_shape, csr.indptr, csr.indices, values)
    ]
    variant_data = tensor_pb2.VariantTensorDataProto(
        type_name="tensorflow::CSRSparseMatrix",
        metadata=np.asarray(True).tobytes(),
        tensors=component_protos)
    scalar_variant = tensor_pb2.TensorProto(
        dtype=dtypes.variant.as_datatype_enum,
        tensor_shape=tensor_shape.TensorShape([]).as_proto())
    scalar_variant.variant_val.extend([variant_data])

    dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        constant_op.constant(scalar_variant), type=dtypes.float32)
    self.evaluate(dense)

  @test_util.run_in_graph_and_eager_modes
  def testSparseTensorConversion(self):
    """Round-trips a 2-D SparseTensor through a CSRSparseMatrix."""
    indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
    values = [1.0, 5.0, -1.0, -2.0]
    shape = [5, 6]
    expected_csr = sparse.coo_matrix(
        (values, (indices[:, 0], indices[:, 1])), shape=shape).tocsr()

    st = math_ops.cast(
        sparse_tensor.SparseTensor(indices, values, shape), dtypes.float32)
    sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
        st.indices, st.values, st.dense_shape)

    # SparseTensor -> CSR: the single matrix (batch 0) must match scipy.
    row_ptrs, col_inds, csr_values = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            sm, 0, type=st.dtype))
    self.assertAllEqual(expected_csr.indices, col_inds)
    self.assertAllEqual(expected_csr.indptr, row_ptrs)
    self.assertAllClose(values, csr_values)

    # CSR -> SparseTensor must reproduce the original input.
    round_trip = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            sm, type=st.dtype))
    self.assertAllEqual(indices, round_trip.indices)
    self.assertAllClose(values, round_trip.values)
    self.assertAllEqual(shape, round_trip.dense_shape)

  def testSparseTensorConversionInvalidInputShapes(self):
    """Inconsistent or oversized inputs are rejected by the ST -> CSR op."""
    values = constant_op.constant(0.554979503, shape=[5], dtype=dtypes.float32)
    expected_errors = (ValueError, errors.InvalidArgumentError)
    int32max = 2**31 - 1

    def int64_const(value, shape):
      return constant_op.constant(value, shape=shape, dtype=dtypes.int64)

    def convert(indices, dense_shape, sparse_values=values):
      self.evaluate(
          sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
              indices=indices, values=sparse_values, dense_shape=dense_shape))

    # Scalar dense_shape.
    with self.assertRaisesRegex(expected_errors, "must be rank 1"):
      convert(int64_const(0, [5, 2]), int64_const(53, []))

    # Rank-1 indices.
    with self.assertRaisesRegex(expected_errors, "must be rank 2"):
      convert(int64_const(0, [5]), int64_const(53, [1]))

    # Batch dimension too large for int32.
    with self.assertRaisesRegex(expected_errors,
                                "batch_size must be < Int32Max"):
      convert(int64_const(0, [5, 3]), int64_const([int32max, 1, 1], [3]))

    # Row-pointer storage for the whole batch would exceed int32.
    with self.assertRaisesRegex(expected_errors,
                                "csr row index size.*must be <= Int32Max"):
      convert(int64_const(0, [5, 3]),
              int64_const([(int32max // 2), 10, 1], [3]))

    # Indices have 3 columns but dense_shape has only 2 entries.
    with self.assertRaisesRegex(expected_errors,
                                "Index rank .* and shape rank .* do not match"):
      convert([[0, 0, 0], [0, 0, 1]], [33, 73], sparse_values=[10.0, 20.0])

  # TODO(b/139491352): Add handle_data propagation to array_ops.identity.
  @test_util.run_deprecated_v1
  def testCSRSparseMatrixResourceVariable(self):
    """A CSRSparseMatrix can be stored in a resource variable and reassigned.

    Decorated with `run_deprecated_v1` (TF1-style graph execution). Returns
    early when no GPU is available.
    """
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m > 0)
    dense_shape = [53, 65, 127]
    a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)

    a_sm = dense_to_csr_sparse_matrix(a_mats)
    with ops.device("/gpu:0"):
      v = variable_scope.get_variable("sm", initializer=a_sm, use_resource=True)
      v_id = array_ops.identity(v)
      # The dense shape must still be recoverable after `identity`.
      self.assertEqual(
          sparse_csr_matrix_ops.dense_shape_and_type(v_id).shape, a_mats.shape)
      a_rt = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          v, type=dtypes.float32)
    # Op that assigns the variable its own current value.
    v_reassign = state_ops.assign(v, v_id).op
    # Reading before the initializer has run must fail.
    with self.assertRaisesOpError("uninitialized"):
      self.evaluate(a_rt)
    self.evaluate(v.initializer)
    a_rt_value = self.evaluate(a_rt)
    self.assertAllClose(a_mats, a_rt_value)
    # After the self-assignment the contents must be unchanged.
    self.evaluate(v_reassign)
    a_rt_reassigned_value = self.evaluate(a_rt)
    self.assertAllClose(a_mats, a_rt_reassigned_value)

  @test_util.run_in_graph_and_eager_modes
  def testBatchSparseTensorConversion(self):
    """Round-trips a batched (3-D) SparseTensor through a CSRSparseMatrix."""
    indices = np.array([[0, 0, 0], [0, 2, 3], [2, 0, 1]])
    values = [1.0, 5.0, 6.0]
    dense_shape = [3, 5, 6]
    matrix_shape = dense_shape[1:]
    # Expected per-batch CSR structure; batch 1 is intentionally empty.
    expected_csrs = [
        sparse.coo_matrix(data, shape=matrix_shape).tocsr() for data in (
            ([1.0, 5.0], ([0, 2], [0, 3])),
            ([], ([], [])),
            ([6.0], ([0], [1])),
        )
    ]

    st = math_ops.cast(
        sparse_tensor.SparseTensor(indices, values, dense_shape),
        dtypes.float32)
    sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
        st.indices, st.values, st.dense_shape)

    components = self.evaluate([
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            sm, batch, type=st.dtype) for batch in range(dense_shape[0])
    ])
    for batch, (got, want) in enumerate(zip(components, expected_csrs)):
      tf_logging.info("Comparing batch %d" % batch)
      self.assertAllEqual(want.indptr, got.row_ptrs)
      self.assertAllEqual(want.indices, got.col_inds)
      self.assertAllClose(want.data, got.values)

    # Back to a 3-D SparseTensor: must reproduce the original input.
    round_trip = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            sm, type=st.dtype))
    self.assertAllEqual(indices, round_trip.indices)
    self.assertAllClose(values, round_trip.values)
    self.assertAllEqual(dense_shape, round_trip.dense_shape)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseTensorConversion(self):
    """Round-trips two large random batches through CSR and back.

    Both conversions are built together so that the ST -> SM ops are checked
    when several of them may execute in parallel.
    """
    dense_shape = [53, 65, 127]

    def sparsify(m):
      return m * (m > 0)

    mats = [
        sparsify(np.random.randn(*dense_shape)).astype(np.float32)
        for _ in range(2)
    ]
    # Per-batch scipy CSR references, one list per random batch.
    expected_csrs = [[sparse.csr_matrix(m) for m in mat] for mat in mats]

    sparse_tensors = []
    for mat in mats:
      mat_t = ops.convert_to_tensor(mat)
      positive_locs = array_ops.where(mat_t > 0)
      sparse_tensors.append(
          sparse_tensor.SparseTensor(
              positive_locs, array_ops.gather_nd(mat_t, positive_locs),
              dense_shape))
    sparse_matrices = [
        sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
            st.indices, st.values, st.dense_shape) for st in sparse_tensors
    ]
    nnz = [
        sparse_csr_matrix_ops.sparse_matrix_nnz(sm) for sm in sparse_matrices
    ]
    components = [[
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            sm, b, type=dtypes.float32) for b in range(dense_shape[0])
    ] for sm in sparse_matrices]

    nnz_values, component_values = self.evaluate((nnz, components))
    for i, (got_batches, want_batches) in enumerate(
        zip(component_values, expected_csrs)):
      for b, (got, want) in enumerate(zip(got_batches, want_batches)):
        tf_logging.info("Comparing matrix %d batch %d" % (i, b))
        self.assertEqual(want.nnz, nnz_values[i][b])
        self.assertAllEqual(want.indptr, got.row_ptrs)
        self.assertAllEqual(want.indices, got.col_inds)
        self.assertAllClose(want.data, got.values)

    # Back to SparseTensor: indices, values and shape must survive unchanged.
    round_trips = [
        sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            sm, type=dtypes.float32) for sm in sparse_matrices
    ]
    st_values, round_trip_values = self.evaluate((sparse_tensors, round_trips))
    for st_value, rt_value in zip(st_values, round_trip_values):
      self.assertAllEqual(st_value.indices, rt_value.indices)
      self.assertAllClose(st_value.values, rt_value.values)
      self.assertAllEqual(dense_shape, rt_value.dense_shape)

  @test_util.run_in_graph_and_eager_modes
  def testDenseConversion(self):
    """Round-trips a 2-D dense matrix through a CSRSparseMatrix."""
    indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
    values = np.array([1.0, 5.0, -1.0, -2.0]).astype(np.float32)
    coo = sparse.coo_matrix(
        (values, (indices[:, 0], indices[:, 1])), shape=[5, 6])
    expected_csr = coo.tocsr()
    dense = coo.todense()

    sm = dense_to_csr_sparse_matrix(dense)

    # Dense -> CSR: the single matrix (batch 0) must match scipy's layout.
    row_ptrs, col_inds, csr_values = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            sm, 0, type=dtypes.float32))
    self.assertAllEqual(expected_csr.indices, col_inds)
    self.assertAllEqual(expected_csr.indptr, row_ptrs)
    self.assertAllClose(values, csr_values)

    # CSR -> dense must reproduce the original matrix exactly.
    round_trip = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(sm, dtypes.float32))
    self.assertAllEqual(dense, round_trip)

  @test_util.run_in_graph_and_eager_modes
  def testBatchDenseConversion(self):
    """Round-trips a batched (3-D) dense tensor through a CSRSparseMatrix.

    Batches 1 and 3 are entirely zero, so matrices with no stored entries are
    exercised, including in the last batch position.
    """
    a_dense_shape = [4, 5, 6]
    a_sparse_mats = [
        sparse.coo_matrix(([1.0, 5.0], ([0, 2], [0, 3])),
                          shape=a_dense_shape[1:]),
        sparse.coo_matrix(([], ([], [])), shape=a_dense_shape[1:]),
        sparse.coo_matrix(([6.0], ([0], [1])), shape=a_dense_shape[1:]),
        sparse.coo_matrix(([], ([], [])), shape=a_dense_shape[1:]),
    ]
    a_csr_mats = [m.tocsr() for m in a_sparse_mats]
    a_dense = np.asarray([m.todense() for m in a_sparse_mats], dtype=np.float32)

    # Convert the 3D dense tensor to a batched CSR matrix.
    a_sm = dense_to_csr_sparse_matrix(a_dense)

    # Get row pointers, column indices and values for every batch. The batch
    # count is taken from the shape (previously hard-coded to 3), so the
    # trailing empty batch is checked as well.
    a_sm_components = [
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            a_sm, i, type=dtypes.float32) for i in range(a_dense_shape[0])
    ]

    a_sm_values = self.evaluate(a_sm_components)

    for i, (a_sm_val, a_csr_mat) in enumerate(zip(a_sm_values, a_csr_mats)):
      tf_logging.info("Comparing batch %d", i)
      self.assertAllEqual(a_csr_mat.indptr, a_sm_val.row_ptrs)
      self.assertAllEqual(a_csr_mat.indices, a_sm_val.col_inds)
      self.assertAllClose(a_csr_mat.data, a_sm_val.values)

    # Convert the batched CSR matrix back to a 3D dense tensor.
    a_rt = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        a_sm, type=dtypes.float32)
    a_rt_value = self.evaluate(a_rt)

    self.assertAllEqual(a_dense, a_rt_value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchDenseConversion(self):
    """Round-trips two large random batches dense -> CSR -> dense.

    Both conversions are built together so that the ops are checked when
    several of them may execute in parallel.
    """
    dense_shape = [53, 65, 127]

    def sparsify(m):
      return m * (m > 0)

    mats = [
        sparsify(np.random.randn(*dense_shape)).astype(np.float32)
        for _ in range(2)
    ]
    # Per-batch scipy CSR references, one list per random batch.
    expected_csrs = [list(map(sparse.csr_matrix, mat)) for mat in mats]

    sparse_matrices = []
    for mat in mats:
      mat_t = ops.convert_to_tensor(mat)
      sparse_matrices.append(
          sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
              mat_t, array_ops.where(mat_t > 0)))
    nnz = [
        sparse_csr_matrix_ops.sparse_matrix_nnz(sm) for sm in sparse_matrices
    ]
    components = [[
        sparse_csr_matrix_ops.csr_sparse_matrix_components(
            sm, b, type=dtypes.float32) for b in range(dense_shape[0])
    ] for sm in sparse_matrices]

    nnz_values, component_values = self.evaluate((nnz, components))
    for i, (got_batches, want_batches) in enumerate(
        zip(component_values, expected_csrs)):
      for b, (got, want) in enumerate(zip(got_batches, want_batches)):
        tf_logging.info("Comparing matrix %d batch %d" % (i, b))
        self.assertEqual(want.nnz, nnz_values[i][b])
        self.assertAllEqual(want.indptr, got.row_ptrs)
        self.assertAllEqual(want.indices, got.col_inds)
        self.assertAllClose(want.data, got.values)

    # Back to dense: each batch must match its source exactly.
    dense_round_trips = self.evaluate([
        sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
            sm, type=dtypes.float32) for sm in sparse_matrices
    ])
    for mat, round_trip in zip(mats, dense_round_trips):
      self.assertAllEqual(mat, round_trip)

  @test_util.run_in_graph_and_eager_modes
  def testSparseMatrixAdd(self):
    """alpha * A + beta * B on 2-D CSR matrices matches scipy (GPU only)."""
    if not self._gpu_available:
      return

    def coo_f32(indices, values, shape):
      indices = np.array(indices)
      return sparse.coo_matrix(
          (np.array(values).astype(np.float32),
           (indices[:, 0], indices[:, 1])),
          shape=shape)

    a_coo = coo_f32([[0, 0], [2, 3]], [1.0, 5.0], [5, 6])
    b_coo = coo_f32([[1, 0], [1, 4], [2, 3], [4, 1]], [1.0, 0.5, -5.0, 2.0],
                    [5, 6])
    a_dense = a_coo.todense()
    b_dense = b_coo.todense()

    for alpha, beta in [(1.0, 1.0), (1.0, -1.0), (0.25, 0.5)]:
      expected = (alpha * a_coo + beta * b_coo).todense()
      a_sm = dense_to_csr_sparse_matrix(a_dense)
      b_sm = dense_to_csr_sparse_matrix(b_dense)
      c_sm = sparse_csr_matrix_ops.sparse_matrix_add(
          a_sm, b_sm, alpha=np.float32(alpha), beta=np.float32(beta))
      actual = self.evaluate(
          sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
              c_sm, dtypes.float32))
      self.assertAllClose(expected, actual)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixAdd(self):
    """alpha * A + beta * B on large random CSR batches matches NumPy.

    Returns early when no GPU is available.
    """
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m > 0)
    dense_shape = [53, 65, 127]
    a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)
    b_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)
    # The operands do not depend on (alpha, beta); convert them only once.
    a_sm = dense_to_csr_sparse_matrix(a_mats)
    b_sm = dense_to_csr_sparse_matrix(b_mats)
    for (alpha, beta) in [(1.0, 1.0), (1.0, -1.0), (0.25, 0.5)]:
      # %g instead of %d: %d truncated 0.25 / 0.5 to 0 in the log.
      tf_logging.info(
          "testLargeBatchSparseMatrixAdd, comparing alpha, beta (%g, %g)",
          alpha, beta)
      alpha = np.float32(alpha)
      beta = np.float32(beta)
      c_sm = sparse_csr_matrix_ops.sparse_matrix_add(
          a_sm, b_sm, alpha=alpha, beta=beta)
      c_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          c_sm, dtypes.float32)
      c_dense_value = self.evaluate(c_dense)

      self.assertAllClose(c_dense_value, alpha * a_mats + beta * b_mats)

  @test_util.run_in_graph_and_eager_modes
  def testSparseMatrixMatMul(self):
    """Sparse (CSR) @ dense matmul agrees with scipy for 1- and 2-column b."""
    indices = np.array([[0, 0], [2, 3]])
    values = np.array([1.0, 5.0]).astype(np.float32)
    for a_shape, b_shape in [((5, 6), (6, 1)), ((5, 6), (6, 2))]:
      a_coo = sparse.coo_matrix(
          (values, (indices[:, 0], indices[:, 1])), shape=a_shape)
      # Multiply sparse a (a_shape) by a random dense b (b_shape).
      b = np.random.randn(*b_shape).astype(np.float32)
      product = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
          a=dense_to_csr_sparse_matrix(a_coo.todense()), b=b)
      actual = self.evaluate(product)
      self.assertAllClose(a_coo.dot(b), actual)

  @test_util.run_in_graph_and_eager_modes
  def testSparseMatrixMatMulConjugateOutput(self):
    """`conjugate_output=True` returns conj(a @ b) for a complex sparse a."""
    indices = np.array([[0, 0], [2, 3]])
    values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
    for a_shape, b_shape in [((5, 6), (6, 1)), ((5, 6), (6, 2))]:
      a_dense = sparse.coo_matrix(
          (values, (indices[:, 0], indices[:, 1])), shape=a_shape).todense()
      # Random dense right-hand side (real-valued, stored as complex64).
      b = np.random.randn(*b_shape).astype(np.complex64)

      actual = self.evaluate(
          sparse_csr_matrix_ops.sparse_matrix_mat_mul(
              a=dense_to_csr_sparse_matrix(a_dense),
              b=b,
              conjugate_output=True))
      expected = self.evaluate(
          math_ops.conj(test_util.matmul_without_tf32(a_dense, b)))
      self.assertAllClose(expected, actual)

  @parameterized.product(
      dtype=[np.float32, np.complex64],
      transpose=[(False, False), (False, True), (True, False), (True, True)],
      adjoint=[(False, False), (False, True), (True, False), (True, True)],
      shapes=[([53, 127, 65], [53, 65, 1]), ([53, 127, 1], [53, 1, 65]),
              ([53, 127, 65], [53, 65, 127])])
  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixMatMul(self, dtype, transpose, adjoint, shapes):
    """Batched CSR @ dense matmul matches dense matmul for every op flag.

    `shapes` holds the operand shapes after the requested transpose/adjoint;
    the stored operands have their last two dimensions swapped whenever a
    transpose or adjoint is requested. Cases that both transpose and adjoint
    the same operand are skipped.
    """
    transpose_a, transpose_b = transpose
    adjoint_a, adjoint_b = adjoint
    if (transpose_a and adjoint_a) or (transpose_b and adjoint_b):
      return

    def sparsify(m):
      return m * (m > 0)

    def random_complex(shape):
      return np.random.randn(*shape) + 1.j * np.random.randn(*shape)

    # Work on copies: the lists inside `shapes` are shared by all test cases.
    a_dense_shape = list(shapes[0])
    b_dense_shape = list(shapes[1])
    if transpose_a or adjoint_a:
      _swap(a_dense_shape, -2, -1)
    if transpose_b or adjoint_b:
      _swap(b_dense_shape, -2, -1)

    a_mats = sparsify(random_complex(a_dense_shape)).astype(dtype)
    b_mats = random_complex(b_dense_shape).astype(dtype)
    tf_logging.info(
        "testLargeBatchSparseMatrixMatMul transpose_a %s transpose_b "
        "%s adjoint_a %s adjoint_b %s" %
        (transpose_a, transpose_b, adjoint_a, adjoint_b))

    flags = dict(
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b)
    c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
        dense_to_csr_sparse_matrix(a_mats),
        b_mats,
        transpose_output=False,
        conjugate_output=False,
        **flags)
    c_dense_t = test_util.matmul_without_tf32(a_mats, b_mats, **flags)
    self.assertAllEqual(c_dense_t.shape, c_t.shape)
    actual, expected = self.evaluate((c_t, c_dense_t))
    self.assertAllClose(actual, expected, rtol=1e-6, atol=2e-5)

  @parameterized.product(
      dtype=[np.float32, np.complex64],
      transpose=[(False, False), (False, True), (True, False), (True, True)],
      adjoint=[(False, False), (False, True), (True, False), (True, True)],
      shapes=[[[53, 127, 65], [53, 65, 1]], [[53, 127, 1], [53, 1, 65]],
              [[53, 127, 65], [53, 65, 127]]])
  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixMatMulTransposed(self, dtype, transpose,
                                                 adjoint, shapes):
    """`transpose_output=True` yields transpose(op(a) @ op(b)).

    The dense reference uses t(op(a) @ op(b)) == t(op(b)) @ t(op(a)), where
    transposing an adjoint leaves only a conjugate, e.g.
    t(adj(a) @ b) == t(b) @ conj(a). Cases that both transpose and adjoint
    the same operand are skipped.
    """
    sparsify = lambda m: m * (m > 0)
    (transpose_a, transpose_b) = transpose
    (adjoint_a, adjoint_b) = adjoint
    if (transpose_a and adjoint_a) or (transpose_b and adjoint_b):
      return
    # Make copies so we don't update the lists inside the decorator arguments.
    a_dense_shape = shapes[0][:]
    b_dense_shape = shapes[1][:]
    if transpose_a or adjoint_a:
      _swap(a_dense_shape, -2, -1)
    if transpose_b or adjoint_b:
      _swap(b_dense_shape, -2, -1)
    a_mats = sparsify((np.random.randn(*a_dense_shape) +
                       1.j * np.random.randn(*a_dense_shape))).astype(dtype)
    b_mats = (np.random.randn(*b_dense_shape) +
              1.j * np.random.randn(*b_dense_shape)).astype(dtype)
    # Log under this test's own name (was copy-pasted from the non-transposed
    # test, making the two suites indistinguishable in logs).
    tf_logging.info(
        "testLargeBatchSparseMatrixMatMulTransposed transpose_a %s "
        "transpose_b %s adjoint_a %s adjoint_b %s", transpose_a, transpose_b,
        adjoint_a, adjoint_b)
    a_sm = dense_to_csr_sparse_matrix(a_mats)
    c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
        a_sm,
        b_mats,
        transpose_output=True,
        conjugate_output=False,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b)

    # Example: t(adj(a) . b) = t(b) . conj(a)
    c_dense_t = test_util.matmul_without_tf32(
        math_ops.conj(b_mats) if adjoint_b else b_mats,
        math_ops.conj(a_mats) if adjoint_a else a_mats,
        transpose_a=not (transpose_b or adjoint_b),
        transpose_b=not (transpose_a or adjoint_a),
        adjoint_a=False,
        adjoint_b=False)
    self.assertAllEqual(c_t.shape, c_dense_t.shape)
    c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
    self.assertAllClose(c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixMatMulConjugate(self):
    """Batched complex CSR @ dense with `conjugate_output=True`."""

    def complex_normal(shape):
      return np.random.randn(*shape) + 1.j * np.random.randn(*shape)

    a_raw = complex_normal([53, 65, 127])
    a_mats = (a_raw * (a_raw > 0)).astype(np.complex64)
    b_mats = complex_normal([53, 127, 67]).astype(np.complex64)

    c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
        dense_to_csr_sparse_matrix(a_mats), b_mats, conjugate_output=True)
    # Reference: the conjugate of the ordinary dense product.
    c_dense_t = math_ops.conj(test_util.matmul_without_tf32(a_mats, b_mats))
    self.assertAllEqual(c_t.shape, c_dense_t.shape)
    actual, expected = self.evaluate((c_t, c_dense_t))

    self.assertAllClose(actual, expected, atol=1e-5, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testSparseMatrixSparseMatMul(self):
    """CSR @ CSR matmul of two small 2-D matrices agrees with scipy."""
    a_coo = sparse.coo_matrix(
        (np.array([1.0, 5.0]).astype(np.float32), ([0, 2], [0, 3])),
        shape=[5, 6])
    b_coo = sparse.coo_matrix(
        (np.array([2.0, 7.0, 8.0]).astype(np.float32),
         ([0, 3, 3], [0, 0, 1])),
        shape=[6, 7])

    c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
        a=dense_to_csr_sparse_matrix(a_coo.todense()),
        b=dense_to_csr_sparse_matrix(b_coo.todense()),
        type=dtypes.float32)
    actual = self.evaluate(
        sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(c_sm, dtypes.float32))

    self.assertAllClose(a_coo.dot(b_coo).todense(), actual)

  @test_util.run_in_graph_and_eager_modes
  def testSparseMatrixSparseMatMul_NumericZerosNotPruned(self):
    """Numeric zeros produced by a CSR @ CSR product stay in its structure."""
    # Tests that numeric zeros appearing from the sparse-sparse matrix
    # multiplication are not pruned from the sparse structural representation
    # of the result.
    a_indices = np.array([[0, 0], [0, 2]])
    a_values = np.array([2.0, -1.0]).astype(np.float32)
    a_dense_shape = [2, 3]
    a_sparse_mat = sparse.coo_matrix(
        (a_values, (a_indices[:, 0], a_indices[:, 1])), shape=a_dense_shape)
    a_dense = a_sparse_mat.todense()

    b_indices = np.array([[0, 1], [2, 1]])
    b_values = np.array([3.0, 6.0]).astype(np.float32)
    b_dense_shape = [3, 2]
    b_sparse_mat = sparse.coo_matrix(
        (b_values, (b_indices[:, 0], b_indices[:, 1])), shape=b_dense_shape)
    b_dense = b_sparse_mat.todense()

    # Convert to CSRSparseMatrix while removing numeric zeros from the
    # structural representation.
    a_sm = dense_to_csr_sparse_matrix(a_dense)
    b_sm = dense_to_csr_sparse_matrix(b_dense)

    # Compute the matmul.
    c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
        a=a_sm, b=b_sm, type=dtypes.float32)
    c_nnz = sparse_csr_matrix_ops.sparse_matrix_nnz(c_sm)
    c_nnz_value = self.evaluate(c_nnz)

    # Expect a single stored entry at index (0, 1) if zeros are not pruned,
    # since 2.0 * 3.0 + (-1.0) * 6.0 = 0.0. nnz is an integer count, so it is
    # compared exactly rather than within a floating-point tolerance.
    self.assertAllEqual(1, c_nnz_value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixSparseMatMul(self):
    """Batched sparse x sparse matmul agrees with a dense matmul.

    Tries every combination of transpose/adjoint flags, skipping those that
    set both transpose and adjoint on the same operand. Each combination is
    run for every pairing of batch sizes 1 and 53 on the two operands.
    """
    sparsify = lambda m: m * (m > 0)
    flag_pairs = ((False, False), (False, True), (True, False), (True, True))
    batch_sizes = (1, 53)

    for transpose_a, transpose_b in flag_pairs:
      for adjoint_a, adjoint_b in flag_pairs:
        if ((transpose_a and adjoint_a) or
            (transpose_b and adjoint_b)):
          continue
        # A transposed/adjointed operand is stored with its inner dims swapped.
        a_inner = [127, 65] if (transpose_a or adjoint_a) else [65, 127]
        b_inner = [67, 127] if (transpose_b or adjoint_b) else [127, 67]
        flags = dict(
            transpose_a=transpose_a,
            adjoint_a=adjoint_a,
            transpose_b=transpose_b,
            adjoint_b=adjoint_b)

        for a_batch_size in batch_sizes:
          for b_batch_size in batch_sizes:
            a_mats = sparsify(
                np.random.randn(a_batch_size, *a_inner)).astype(np.float32)
            b_mats = sparsify(
                np.random.randn(b_batch_size, *b_inner).astype(np.float32))

            product_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
                dense_to_csr_sparse_matrix(a_mats),
                dense_to_csr_sparse_matrix(b_mats),
                type=dtypes.float32,
                **flags)
            product_from_sparse = (
                sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
                    product_sm, dtypes.float32))
            product_reference = test_util.matmul_without_tf32(
                a_mats, b_mats, **flags)

            reference_value, sparse_value = self.evaluate(
                (product_reference, product_from_sparse))
            self.assertAllClose(sparse_value, reference_value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchRegisteredAddN(self):
    """math_ops.add_n over CSR sparse matrices matches the dense numpy sum.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    num_terms = 16
    dense_shape = [53, 65, 127]

    def random_sparse_dense():
      # Keep only the positive standard-normal samples.
      samples = np.random.randn(*dense_shape)
      return (samples * (samples > 0)).astype(np.float32)

    matrices = [random_sparse_dense() for _ in range(num_terms)]
    total_sm = math_ops.add_n(
        [dense_to_csr_sparse_matrix(mat) for mat in matrices])
    total_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        total_sm, dtypes.float32)

    # Summing the stacked list over axis 0 adds the 16 batched arrays
    # elementwise, which is what add_n is expected to compute.
    self.assertAllClose(np.sum(matrices, axis=0), self.evaluate(total_dense))

  @test_util.run_in_graph_and_eager_modes
  def testCSRZeros(self):
    """sparse_matrix_zeros densifies to all zeros for rank-2 and rank-3 shapes.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return
    shapes = ([65, 127], [53, 127, 67])
    for dtype in (dtypes.float32, dtypes.float64, dtypes.complex64,
                  dtypes.complex128):
      zero_sms = [
          sparse_csr_matrix_ops.sparse_matrix_zeros(shape, type=dtype)
          for shape in shapes
      ]
      round_trips = [
          sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(sm, type=dtype)
          for sm in zero_sms
      ]
      for shape, value in zip(shapes, self.evaluate(round_trips)):
        self.assertAllEqual(value, np.zeros(shape))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchZerosLike(self):
    """zeros_like of a batched CSR matrix yields empty per-batch components.

    Each of the 53 batch entries must have an all-zero int32 row-pointer
    vector of length rows + 1, and empty column-index and value vectors.
    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    batch_size, rows, cols = 53, 128, 67
    expected_row_ptrs = np.zeros(rows + 1, np.int32)
    expected_col_inds = np.empty([0], np.int32)
    for dtype in (dtypes.float32, dtypes.float64, dtypes.complex64,
                  dtypes.complex128):
      zeros_sm = sparse_csr_matrix_ops.sparse_matrix_zeros(
          [batch_size, rows, cols], type=dtype)
      zeros_like_sm = array_ops.zeros_like(zeros_sm)
      components = self.evaluate([
          sparse_csr_matrix_ops.csr_sparse_matrix_components(
              zeros_like_sm, index, type=dtype)
          for index in range(batch_size)
      ])
      expected_values = np.empty([0], dtype.as_numpy_dtype)
      for component in components:
        self.assertAllEqual(component.row_ptrs, expected_row_ptrs)
        self.assertAllEqual(component.col_inds, expected_col_inds)
        self.assertAllEqual(component.values, expected_values)

  @test_util.run_in_graph_and_eager_modes
  def testTranspose(self):
    """sparse_matrix_transpose matches numpy's (conjugate) transpose, rank 2."""
    shape = [127, 65]
    for dtype in (dtypes.float32, dtypes.float64, dtypes.complex64,
                  dtypes.complex128):
      raw = np.random.randn(*shape) + 1.j * np.random.randn(*shape)
      mats = (raw * (raw > 0)).astype(dtype.as_numpy_dtype)
      for conjugate in (False, True):
        expected = np.conj(mats.T) if conjugate else np.transpose(mats)
        sparse_matrices = dense_to_csr_sparse_matrix(
            math_ops.cast(mats, dtype))
        transposed_sm = sparse_csr_matrix_ops.sparse_matrix_transpose(
            sparse_matrices, conjugate=conjugate, type=dtype)
        dense_transposed = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
            transposed_sm, dtype)
        self.assertAllClose(expected, self.evaluate(dense_transposed))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchTranspose(self):
    """Batched sparse_matrix_transpose matches numpy's (conjugate) transpose.

    The expected value is derived independently for each `conjugate` setting.
    It is not conjugated in place across loop iterations, so the check does
    not depend on the order in which the flag values are visited.
    """
    sparsify = lambda m: m * (m > 0)
    dense_shape = [53, 65, 127]
    data_types = [
        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
    ]
    for dtype in data_types:
      mats = sparsify(
          (np.random.randn(*dense_shape) +
           1.j * np.random.randn(*dense_shape))).astype(dtype.as_numpy_dtype)
      # Swap only the two innermost (matrix) axes; the batch axis stays put.
      transposed = np.transpose(mats, (0, 2, 1))
      for conjugate in False, True:
        expected = np.conj(transposed) if conjugate else transposed
        matrices = math_ops.cast(mats, dtype)
        sparse_matrices = dense_to_csr_sparse_matrix(matrices)
        transpose_sparse_matrices = \
            sparse_csr_matrix_ops.sparse_matrix_transpose(
                sparse_matrices, conjugate=conjugate, type=dtype)
        dense_transposed = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
            transpose_sparse_matrices, dtype)
        dense_transposed_values = self.evaluate(dense_transposed)
        self.assertAllClose(expected, dense_transposed_values)

  @test_util.run_in_graph_and_eager_modes
  def testSoftmax(self):
    """Sparse softmax equals dense softmax with absent entries set to -inf.

    The test compares the densified sparse softmax against nn_ops.softmax
    applied to the same logits, with every zero logit replaced by -inf.
    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    raw = np.random.randn(127, 65)
    logits = raw * (raw > 0)
    logits_with_ninf = np.where(logits == 0, -np.inf, logits)
    for dtype in (dtypes.float32, dtypes.float64):
      expected = nn_ops.softmax(math_ops.cast(logits_with_ninf, dtype))
      sparse_logits = dense_to_csr_sparse_matrix(math_ops.cast(logits, dtype))
      sparse_softmax = sparse_csr_matrix_ops.sparse_matrix_softmax(
          sparse_logits, type=dtype)
      dense_softmax = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          sparse_softmax, dtype)
      actual_value, expected_value = self.evaluate((dense_softmax, expected))
      self.assertAllClose(expected_value, actual_value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSoftmax(self):
    """Batched sparse softmax equals dense softmax with zeros mapped to -inf.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    raw = np.random.randn(53, 65, 127)
    logits = raw * (raw > 0)
    logits_with_ninf = np.where(logits == 0, -np.inf, logits)
    for dtype in (dtypes.float32, dtypes.float64):
      expected = nn_ops.softmax(math_ops.cast(logits_with_ninf, dtype))
      sparse_logits = dense_to_csr_sparse_matrix(math_ops.cast(logits, dtype))
      sparse_softmax = sparse_csr_matrix_ops.sparse_matrix_softmax(
          sparse_logits, type=dtype)
      dense_softmax = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          sparse_softmax, dtype)
      actual_value, expected_value = self.evaluate((dense_softmax, expected))
      self.assertAllClose(expected_value, actual_value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSoftmaxEmpty(self):
    """Softmax of an all-empty batched CSR matrix densifies to all zeros.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    float32 = dtypes.float32
    empty_logits = sparse_csr_matrix_ops.sparse_matrix_zeros(
        [53, 65, 127], type=float32)
    empty_softmax = sparse_csr_matrix_ops.sparse_matrix_softmax(
        empty_logits, type=float32)
    densified = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        empty_softmax, float32)
    result = self.evaluate(densified)
    self.assertAllEqual(np.zeros_like(result), result)

  @test_util.run_in_graph_and_eager_modes
  def testSoftmaxGrad(self):
    """sparse_matrix_softmax_grad matches the closed-form softmax gradient.

    With softmax output s and upstream gradient g, the expected result is
    (g - sum(g * s, axis=-1, keepdims=True)) * s. Returns early (no checks)
    unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    shape = [127, 65]
    raw_softmax = np.random.randn(*shape)
    softmax = raw_softmax * (raw_softmax > 0)
    raw_grad = np.random.randn(*shape)
    grad_softmax = raw_grad * (raw_grad > 0)
    row_dot = np.sum(grad_softmax * softmax, -1, keepdims=True)
    expected = (grad_softmax - row_dot) * softmax
    for dtype in (dtypes.float32, dtypes.float64):
      softmax_sm = dense_to_csr_sparse_matrix(math_ops.cast(softmax, dtype))
      grad_sm = dense_to_csr_sparse_matrix(math_ops.cast(grad_softmax, dtype))
      grads_sm = sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
          softmax_sm, grad_sm, dtype)
      dense_grads = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          grads_sm, dtype)
      self.assertAllClose(expected, self.evaluate(dense_grads))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSoftmaxGrad(self):
    """Batched sparse softmax gradient matches the closed-form expression.

    Expected: (g - sum(g * s, axis=-1, keepdims=True)) * s, computed per
    batch entry over the last axis. Returns early (no checks) unless
    `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    shape = [53, 65, 127]
    raw_softmax = np.random.randn(*shape)
    softmax = raw_softmax * (raw_softmax > 0)
    raw_grad = np.random.randn(*shape)
    grad_softmax = raw_grad * (raw_grad > 0)
    row_dot = np.sum(grad_softmax * softmax, -1, keepdims=True)
    expected = (grad_softmax - row_dot) * softmax
    for dtype in (dtypes.float32, dtypes.float64):
      softmax_sm = dense_to_csr_sparse_matrix(math_ops.cast(softmax, dtype))
      grad_sm = dense_to_csr_sparse_matrix(math_ops.cast(grad_softmax, dtype))
      grads_sm = sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
          softmax_sm, grad_sm, dtype)
      dense_grads = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          grads_sm, dtype)
      self.assertAllClose(expected, self.evaluate(dense_grads))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSoftmaxGradEmpty(self):
    """Softmax grad densifies to zeros whenever either input is empty.

    Checks the (empty, nonempty), (nonempty, empty) and (empty, empty)
    argument pairs. Returns early (no checks) unless `self._gpu_available`
    is set.
    """
    if not self._gpu_available:
      return

    dense_shape = [53, 65, 127]
    raw = np.random.randn(*dense_shape)
    not_empty = (raw * (raw > 0)).astype(np.float32)
    sparse_empty = sparse_csr_matrix_ops.sparse_matrix_zeros(
        dense_shape, type=dtypes.float32)
    sparse_not_empty = dense_to_csr_sparse_matrix(not_empty)

    argument_pairs = (
        (sparse_empty, sparse_not_empty),
        (sparse_not_empty, sparse_empty),
        (sparse_empty, sparse_empty),
    )
    grads = [
        sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
            softmax_arg, grad_arg, dtypes.float32)
        for softmax_arg, grad_arg in argument_pairs
    ]
    dense_grads = [
        sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(grad, dtypes.float32)
        for grad in grads
    ]
    for value in self.evaluate(dense_grads):
      self.assertAllEqual(np.zeros(dense_shape), value)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchConj(self):
    """math_ops.conj on batched CSR matrices matches numpy's conj.

    Entries are kept where their real part is positive. The mask is applied
    to the full complex matrix so complex dtypes are sparse as well; masking
    only the real part would leave every imaginary component nonzero. For
    real dtypes, `astype` keeps just the real part. Returns early (no checks)
    unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (np.real(m) > 0)
    dense_shape = [53, 65, 127]
    matrices = sparsify(
        np.random.randn(*dense_shape) + 1j * np.random.randn(*dense_shape))
    data_types = [
        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
    ]
    for dtype in data_types:
      matrices_t = matrices.astype(dtype.as_numpy_dtype)
      expected = np.conj(matrices_t)
      sparse_matrices = dense_to_csr_sparse_matrix(matrices_t)
      conj_sparse_matrices = math_ops.conj(sparse_matrices)
      dense_conj_matrices = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          conj_sparse_matrices, dtype)
      conj_values = self.evaluate(dense_conj_matrices)
      self.assertAllClose(expected, conj_values)

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixMulScalar(self):
    """sparse_matrix_mul by a scalar scales every entry of every batch member.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    raw = np.random.randn(53, 65, 127)
    a_mats = (raw * (raw > 0)).astype(np.float32)
    scale = np.float32(3.5)
    product_sm = sparse_csr_matrix_ops.sparse_matrix_mul(
        dense_to_csr_sparse_matrix(a_mats), scale)
    product = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        product_sm, dtypes.float32)
    self.assertAllClose(a_mats * scale, self.evaluate(product))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseMatrixMulVec(self):
    """sparse_matrix_mul by a [53, 1, 1] array scales each batch member.

    Returns early (no checks) unless `self._gpu_available` is set.
    """
    if not self._gpu_available:
      return

    raw = np.random.randn(53, 65, 127)
    a_mats = (raw * (raw > 0)).astype(np.float32)
    # One factor per batch entry; numpy broadcasts it over rows and columns.
    per_batch_scale = np.random.randn(53, 1, 1).astype(np.float32)
    product_sm = sparse_csr_matrix_ops.sparse_matrix_mul(
        dense_to_csr_sparse_matrix(a_mats), per_batch_scale)
    product = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        product_sm, dtypes.float32)
    self.assertAllClose(a_mats * per_batch_scale, self.evaluate(product))

  @test_util.run_in_graph_and_eager_modes
  def testSparseCholesky(self):
    """Sparse Cholesky with AMD ordering reconstructs a 6x6 input matrix.

    The matrix is written as its lower triangle. The full matrix used for
    the comparison adds the conjugate transpose of the strictly-lower part.
    For complex dtypes, the strictly-lower part additionally gets an
    imaginary component. Each dtype's matrix is derived from an unmodified
    base rather than by updating a shared array in place; in-place updates
    would compound the complex perturbation across loop iterations.
    """
    base_matrix = np.array([
        [2, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0],
        [1, 1, 7, 0, 0, 0],
        [0, 0, 0, 4, 0, 0],
        [0, 0, 1, 0, 5, 0],
        [0, 0, 2, 0, 1, 6],
    ]).astype(np.complex128)

    data_types = [
        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
    ]
    for dtype in data_types:
      with test_util.force_cpu():
        dense_matrix = base_matrix
        if dtype.is_complex:
          dense_matrix = base_matrix + 0.5j * np.tril(base_matrix, -1)

        sparse_matrix = dense_to_csr_sparse_matrix(
            math_ops.cast(dense_matrix, dtype))
        # Obtain the Sparse Cholesky factor using AMD Ordering for reducing
        # fill-in.
        ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(
            sparse_matrix)
        cholesky_sparse_matrices = (
            sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
                sparse_matrix, ordering_amd, type=dtype))
        dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
            cholesky_sparse_matrices, dtype)
        # Compute L * Lh where L is the Sparse Cholesky factor.
        verification = test_util.matmul_without_tf32(
            dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
        verification = twist_matrix(verification, ordering_amd)
        # Assert that input matrix A satisfies A = L * Lh.
        verification_values = self.evaluate(verification)
        full_dense_matrix = (
            dense_matrix +
            np.conjugate(np.transpose(np.tril(dense_matrix, -1))))
        self.assertAllClose(full_dense_matrix, verification_values)

  @test_util.run_in_graph_and_eager_modes
  def testBatchSparseCholesky(self):
    """Batched sparse Cholesky with AMD ordering reconstructs each input.

    For every dtype, it factors the batch and forms L @ L^H. After applying
    `twist_matrix` (defined elsewhere in this file) with the AMD ordering,
    the result must equal the input. For real dtypes, the casts keep only
    the real parts of the test matrices.
    """
    diagonal = [
        [1, 0, 0, 0],
        [0, 2, 0, 0],
        [0, 0, 3, 0],
        [0, 0, 0, 4],
    ]
    tridiagonal_hermitian = [
        [5 + 0j, 1 + 0j, 0 + 0j, 0 + 0j],
        [1 + 0j, 4 + 0j, 1 + 2j, 0 + 0j],
        [0 + 0j, 1 - 2j, 9 + 0j, 3 - 3j],
        [0 + 0j, 0 + 0j, 3 + 3j, 7 + 0j],
    ]
    # A diagonal matrix with a corner element; for which OrderingAMD returns
    # a non-identity permutation.
    diagonal_with_corner = [
        [1, 0, 0, 1.],
        [0, 2, 0, 0.],
        [0, 0, 3, 0.],
        [1, 0, 0, 4.],
    ]
    dense_mat = np.array(
        [diagonal, tridiagonal_hermitian,
         diagonal_with_corner]).astype(np.complex128)

    for dtype in (dtypes.float32, dtypes.float64, dtypes.complex64,
                  dtypes.complex128):
      sparse_matrix = dense_to_csr_sparse_matrix(
          math_ops.cast(dense_mat, dtype))
      ordering = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(
          sparse_matrix)
      factor_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
          sparse_matrix, ordering, type=dtype)
      factor = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          factor_sm, dtype)
      factor_h = array_ops.transpose(factor, perm=[0, 2, 1], conjugate=True)
      reconstructed = twist_matrix(
          test_util.matmul_without_tf32(factor, factor_h), ordering)
      self.assertAllClose(
          dense_mat.astype(dtype.as_numpy_dtype), self.evaluate(reconstructed))

  @test_util.run_in_graph_and_eager_modes
  def testLargeBatchSparseCholesky(self):
    """Sparse Cholesky on a batch of random SPD matrices reconstructs them."""
    density = 0.1
    batch_size = 53
    num_rows = 147
    dense_shape = [batch_size, num_rows, num_rows]

    # Keep uniform [0, 1) samples above 1 - density, so roughly `density` of
    # the entries survive.
    samples = np.random.uniform(size=dense_shape)
    random_part = (samples * (samples > 1 - density)).astype(np.float32)

    # 0.5 * (A + A^T) + n * I is symmetric, and every off-diagonal entry is
    # below 1, so each row is diagonally dominant and the matrix is SPD.
    spd_matrix = (
        0.5 * (random_part + array_ops.transpose(random_part, perm=[0, 2, 1]))
        + num_rows * linalg_ops.eye(num_rows, batch_shape=[batch_size]))

    # Factor with the AMD fill-in reducing permutation.
    sparse_matrix = dense_to_csr_sparse_matrix(spd_matrix)
    ordering = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(sparse_matrix)
    factor = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
        sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
            sparse_matrix, ordering, type=dtypes.float32),
        dtypes.float32)

    # L @ L^T, with the ordering applied via twist_matrix, should give back A.
    reconstructed = twist_matrix(
        test_util.matmul_without_tf32(
            factor, array_ops.transpose(factor, perm=[0, 2, 1])),
        ordering)
    self.assertAllClose(
        spd_matrix, self.evaluate(reconstructed), atol=1e-5, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testSparseCholesky_InvalidMatrix(self):
    """Sparse Cholesky of a non-SPD matrix raises InvalidArgumentError."""
    zero_matrix = np.array([
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
    ])
    zero_diagonal_entry = np.array([
        [9., 0., 5., 0.],
        [0., 0., 0., 1.],
        [5., 0., 8., 0.],
        [0., 1., 0., 7.],
    ])
    not_positive_definite = np.array([
        [2., -2., 0., 0.],
        [-2., 2., 0., 0.],
        [0., 0., 3., -3.],
        [0., 0., -3., 3.],
    ])

    def factor_to_dense(matrix):
      # Factor with the AMD fill-in reducing ordering, then densify so that
      # evaluating the result runs the factorization.
      sparse_matrix = dense_to_csr_sparse_matrix(matrix.astype(np.float32))
      ordering = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(
          sparse_matrix)
      factor_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
          sparse_matrix, ordering, type=dtypes.float32)
      return sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          factor_sm, type=dtypes.float32)

    with test_util.force_cpu():
      for invalid_matrix in (zero_matrix, zero_diagonal_entry,
                             not_positive_definite):
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(factor_to_dense(invalid_matrix))

  @test_util.run_in_graph_and_eager_modes
  def testOrderingAMD(self):
    """AMD ordering is a valid permutation and reduces Cholesky fill-in here."""
    num_rows = 6
    # Only the lower triangle is written out. Per the original test, AMD
    # ordering reduces fill-in for this SPD matrix.
    dense_matrix = np.array([
        [7, 0, 0, 0, 0, 0],
        [1, 4, 0, 0, 0, 0],
        [1, 1, 3, 0, 0, 0],
        [0, 0, 0, 4, 0, 0],
        [2, 0, 0, 0, 5, 0],
        [1, 2, 2, 0, 0, 6],
    ]).astype(np.float32)

    with test_util.force_cpu():
      sparse_matrix = dense_to_csr_sparse_matrix(dense_matrix)

      def cholesky_nnz(ordering):
        factor = sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
            sparse_matrix, ordering, type=dtypes.float32)
        return sparse_csr_matrix_ops.sparse_matrix_nnz(factor)

      # The identity permutation means no fill-in reducing reordering.
      nnz_without_ordering = cholesky_nnz(math_ops.range(num_rows))
      ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(
          sparse_matrix)
      nnz_with_amd = cholesky_nnz(ordering_amd)

      ordering_value, with_amd_value, without_ordering_value = self.evaluate(
          [ordering_amd, nnz_with_amd, nnz_without_ordering])

      # Sorting a permutation of 0..n-1 must give back 0..n-1.
      self.assertAllClose(np.arange(num_rows), np.sort(ordering_value))
      # For this matrix, AMD must give strictly fewer nonzeros in the factor.
      self.assertLess(with_amd_value, without_ordering_value)

  @test_util.run_in_graph_and_eager_modes
  def testNoMatrixNoCrash(self):
    """sparse_matrix_nnz rejects an empty variant tensor without crashing."""
    # Round-about way of creating an empty variant tensor that works in both
    # graph and eager modes: reshape a single CSR matrix to shape [1], then
    # slice away its only element.
    single_matrix = dense_to_csr_sparse_matrix([[0.0]])
    no_matrix = array_ops.reshape(single_matrix, [1])[0:0]
    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(
        expected_errors, "(Invalid input matrix)|(Shape must be rank 0)"):
      sparse_csr_matrix_ops.sparse_matrix_nnz(no_matrix)


class CSRSparseMatrixOpsBenchmark(test.Benchmark):

  def benchmark_sparse_matrix_mat_mul_gpu(self):
    """Benchmarks X @ W on the GPU with W stored as CSR versus dense.

    X is [batch_size, 2000] for batch_size in {1, 8, 16}. W is [2000, 4000]
    with standard-normal entries, zeroed wherever they are <= 2. The sparse
    product is computed as (W^T X^T)^T. Returns immediately if no GPU is
    available.
    """
    if not test_util.is_gpu_available():
      return

    def sparsify(m):
      return array_ops.where(m > 2, m, array_ops.zeros_like(m))

    w_dense_shape = [2000, 4000]
    name_template = "sparse_matrix_mat_mul_gpu_%s_W_2000x4000_batch_size_%d"

    for batch_size in (1, 8, 16):
      with ops.Graph().as_default(), ops.device("/gpu:0"):
        x_mats = random_ops.random_normal([batch_size, 2000],
                                          dtype=dtypes.float32)
        w_mats = sparsify(
            random_ops.random_normal(w_dense_shape, dtype=dtypes.float32))
        nnz = array_ops.shape(array_ops.where(w_mats))[0]
        ratio = math_ops.cast(nnz, dtypes.float32) / np.prod(w_dense_shape)
        w_sm = dense_to_csr_sparse_matrix(w_mats)

        # Store W (sparse and dense) and X in resource variables; the
        # benchmarked ops read from these variables.
        with ops.name_scope("w_sm_var"):
          w_sm_var = variable_scope.get_variable(
              "sm", initializer=w_sm, use_resource=True)
          w_sm_var_v = w_sm_var.read_value()
        with ops.name_scope("w_var"):
          w_var = variable_scope.get_variable(
              "sm_dense", initializer=w_mats, use_resource=True)
          w_var_v = w_var.read_value()
        with ops.name_scope("b"):
          x = variable_scope.get_variable(
              "b", initializer=x_mats, use_resource=True)
          x_v = x.read_value()

        # X W == (W^T X^T)^T, which lets the sparse operand be the first
        # argument of sparse_matrix_mat_mul.
        xw_sparse = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
            w_sm_var_v,
            x_v,
            transpose_a=True,
            transpose_b=True,
            transpose_output=True)
        xw_dense = math_ops.matmul(x_v, w_var_v)

        with session.Session() as sess:
          self.evaluate(
              [w_var.initializer, w_sm_var.initializer, x.initializer])
          nnz_value, ratio_value = self.evaluate((nnz, ratio))
          for kind, product in (("sparse", xw_sparse), ("dense", xw_dense)):
            self.run_op_benchmark(
                sess,
                product.op,
                name=name_template % (kind, batch_size),
                extras={
                    "percentage_nonzero": ratio_value,
                    "num_nonzero": nnz_value
                },
                min_iters=50)

  def benchmark_sparse_matrix_mat_vec_mul(self):
    """Benchmarks sparse-matrix x dense-vector products, CSR vs SparseTensor.

    For each (num_rows, device, transpose) case and each intra-op thread
    count, it times two ops on the same data: sparse_matrix_mat_mul on a
    CSRSparseMatrix, and sparse_tensor_dense_matmul on the equivalent
    SparseTensor. W is a num_rows x num_rows scipy random matrix with 1%
    density, and x is a [num_rows, 1] random normal vector. GPU cases are
    skipped when no GPU is available.
    """
    # num_rows, device, transpose.
    cases = [
        [2000, CPU, False],
        [8000, CPU, False],
        [12000, CPU, False],
        [2000, CPU, True],
        [8000, CPU, True],
        [12000, CPU, True],
    ]
    seed = 42

    for num_rows, device, transpose in cases:
      if device == GPU and not test_util.is_gpu_available():
        continue
      # These depend only on the case, not on the thread count.
      device_str = "cpu" if device == CPU else "gpu"
      w_dense_shape = [num_rows, num_rows]
      x_dense_shape = [num_rows, 1]
      for num_threads in [1, 2, 4, 6, 8, 10]:
        with ops.Graph().as_default(), ops.device(device):
          random_seed.set_random_seed(seed)
          x = random_ops.random_normal(x_dense_shape, dtype=dtypes.float32)
          w_np = sparse.rand(
              w_dense_shape[0],
              w_dense_shape[1],
              density=0.01,
              dtype=np.float32,
              random_state=np.random.RandomState(seed))
          # SparseTensor indices must be a concrete [nnz, 2] array. A lazy
          # `zip` iterator is neither a list nor an ndarray, so TF's tensor
          # conversion rejects it under Python 3.
          w_indices = np.stack([w_np.row, w_np.col], axis=1)
          w_st = sparse_tensor.SparseTensor(w_indices, w_np.data, w_np.shape)
          w_st = sparse_ops.sparse_reorder(w_st)

          nnz = array_ops.shape(w_st.values)[0]
          ratio = math_ops.cast(nnz, dtypes.float32) / np.prod(w_np.shape)

          w_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
              w_st.indices, w_st.values, w_st.dense_shape)
          xw_sparse_matrix = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
              w_sm,
              x,
              transpose_a=transpose,
              transpose_b=False,
              transpose_output=False)
          xw_sparse_tensor = sparse_ops.sparse_tensor_dense_matmul(
              w_st, x, adjoint_a=transpose, adjoint_b=False)

          with session.Session(
              config=config_pb2.ConfigProto(
                  intra_op_parallelism_threads=num_threads)) as sess:
            nnz_value, ratio_value = sess.run((nnz, ratio))
            name_template = ("mat_vec_mul_%s_%s_W_%d_transpose_%s_threads_%d")
            self.run_op_benchmark(
                sess,
                xw_sparse_matrix.op,
                name=name_template %
                (device_str, "sparse_matrix", num_rows, transpose, num_threads),
                extras={
                    "percentage_nonzero": ratio_value,
                    "num_nonzero": nnz_value,
                },
                min_iters=10)
            self.run_op_benchmark(
                sess,
                xw_sparse_tensor.op,
                name=name_template %
                (device_str, "sparse_tensor", num_rows, transpose, num_threads),
                extras={
                    "percentage_nonzero": ratio_value,
                    "num_nonzero": nnz_value,
                },
                min_iters=10)

  def benchmark_sparse_matrix_sparse_matmul(self):
    """Benchmarks batched sparse x sparse matmul of 250x250 CSR matrices.

    Both operands keep uniform [0, 1) samples above 1 - density (about 5% of
    entries). Runs on CPU, and on GPU when one is available, for batch sizes
    1 and 16 and intra-op thread counts 1, 4 and 12.
    """
    density = 0.05
    name_template = (
        "sparse_matrix_sparse_matmul_%s_N_%d_batch_size_%d_threads_%d")

    def sparsify(m):
      return array_ops.where(m > 1. - density, m, array_ops.zeros_like(m))

    for batch_size in (1, 16):
      dense_shape = [batch_size, 250, 250]
      for num_threads in (1, 4, 12):
        for device in (CPU, GPU):
          if device == GPU and not test_util.is_gpu_available():
            continue
          device_str = "cpu" if device == CPU else "gpu"

          with ops.Graph().as_default(), ops.device(device):
            x_mats = sparsify(
                random_ops.random_uniform(dense_shape, dtype=dtypes.float32))
            y_mats = sparsify(
                random_ops.random_uniform(dense_shape, dtype=dtypes.float32))

            # Total nonzeros across both operands, as a fraction of the
            # combined number of entries.
            x_nnz = array_ops.shape(array_ops.where(x_mats))[0]
            y_nnz = array_ops.shape(array_ops.where(y_mats))[0]
            nnz = x_nnz + y_nnz
            ratio = math_ops.cast(nnz, dtypes.float32) / (
                2 * np.prod(dense_shape))

            xy_sparse = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
                dense_to_csr_sparse_matrix(x_mats),
                dense_to_csr_sparse_matrix(y_mats),
                type=dtypes.float32)

            config = config_pb2.ConfigProto(
                intra_op_parallelism_threads=num_threads)
            with session.Session(config=config) as sess:
              nnz_value, ratio_value = self.evaluate((nnz, ratio))
              self.run_op_benchmark(
                  sess,
                  xy_sparse.op,
                  name=name_template %
                  (device_str, dense_shape[-1], batch_size, num_threads),
                  extras={
                      "percentage_nonzero": ratio_value,
                      "num_nonzero": nnz_value
                  },
                  min_iters=50)

  def benchmark_sparse_dense_conversion(self):
    """Benchmarks dense -> CSR and CSR -> dense matrix conversions.

    For every batch size, intra-op thread count and available device, a batch
    of random 750x750 float32 matrices is converted to a CSRSparseMatrix.
    Only the entries above `1 - density` are kept. The result is then
    converted back to a dense tensor.

    The two conversion ops are benchmarked separately. Each timed run also
    executes the ops upstream of the timed one. The extras report the nonzero
    count of the sparse form and its fraction of all entries.
    """
    # Approximate fraction of entries kept as nonzeros.
    density = 0.05

    for batch_size in [1, 16]:
      for num_threads in [1, 4, 12]:
        dense_shape = [batch_size, 750, 750]

        for device in [CPU, GPU]:
          if device == GPU and not test_util.is_gpu_available():
            continue
          device_str = "cpu" if device == CPU else "gpu"
          name_suffix = (
              f"{device_str}_N_{dense_shape[-1]}_batch_size_{batch_size}"
              f"_num_threads_{num_threads}")

          with ops.Graph().as_default(), ops.device(device):
            dense = random_ops.random_uniform(dense_shape, dtype=dtypes.float32)
            keep_locs = array_ops.where(dense > 1.0 - density)

            to_sparse = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
                dense, keep_locs)
            to_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
                to_sparse, type=dtypes.float32)
            nnz = math_ops.reduce_sum(
                sparse_csr_matrix_ops.sparse_matrix_nnz(to_sparse))
            ratio = math_ops.cast(nnz, dtypes.float32) / np.prod(dense_shape)

            with session.Session(
                config=config_pb2.ConfigProto(
                    intra_op_parallelism_threads=num_threads)) as sess:
              nnz_value, ratio_value = self.evaluate((nnz, ratio))
              for prefix, target in (("dense_to_sparse_matrix_", to_sparse),
                                     ("sparse_matrix_to_dense_", to_dense)):
                # NOTE(review): a fresh extras dict is built for every call in
                # case run_op_benchmark adds entries to the dict it receives.
                # "percentage_nonzero" holds a fraction in [0, 1].
                self.run_op_benchmark(
                    sess,
                    target.op,
                    name=prefix + name_suffix,
                    extras={
                        "percentage_nonzero": ratio_value,
                        "num_nonzero": nnz_value,
                    },
                    min_iters=50)

  def benchmark_sparse_cholesky(self):
    """Benchmarks AMD ordering and sparse Cholesky factorization on CPU.

    For each batch size and intra-op thread count, this builds a batch of
    random 500x500 float32 matrices that are symmetric and diagonally
    dominant, and therefore positive definite. They are converted to CSR form.
    `sparse_matrix_ordering_amd` and `sparse_matrix_sparse_cholesky` are then
    benchmarked separately; the Cholesky run also executes the ordering it
    depends on.

    The extras report the nonzero count of the CSR input and its fraction of
    all entries.
    """
    # TODO(anudhyan): Use conversions from SparseTensor instead to get this
    # benchmark working for larger matrices. For this to work without GPU, we
    # need to write CPU kernels for SparseTensor conversions.
    num_rows = 500
    density = 0.01

    def sparsify(dense):
      # Zero every entry not above 1 - density.
      return array_ops.where(dense > 1. - density, dense,
                             array_ops.zeros_like(dense))

    for batch_size in [1, 16]:
      for num_threads in [1, 4, 12]:
        dense_shape = [batch_size, num_rows, num_rows]
        name_suffix = (
            f"cpu_N_{dense_shape[-1]}_batch_size_{batch_size}"
            f"_threads_{num_threads}")

        with ops.Graph().as_default(), ops.device(CPU):
          # Random SPD input. A has entries between 0 and 1 at the requested
          # density. 0.5 * (A + A^T) + n * I is symmetric, and its diagonal
          # (>= n) exceeds each row's off-diagonal sum (< n - 1), so it is
          # diagonally dominant and hence positive definite.
          a = sparsify(
              random_ops.random_uniform(dense_shape, dtype=dtypes.float32))
          a_sym = 0.5 * (a + array_ops.transpose(a, perm=[0, 2, 1]))
          spd = a_sym + num_rows * linalg_ops.eye(
              dense_shape[-1], batch_shape=[batch_size])

          # Convert to CSR, compute an AMD fill-reducing ordering, then
          # factorize using that ordering.
          csr = dense_to_csr_sparse_matrix(spd)
          ordering = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(csr)
          factor = sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
              csr, ordering, type=dtypes.float32)

          nnz = math_ops.reduce_sum(
              sparse_csr_matrix_ops.sparse_matrix_nnz(csr))
          ratio = math_ops.cast(nnz, dtypes.float32) / np.prod(dense_shape)

          with session.Session(
              config=config_pb2.ConfigProto(
                  intra_op_parallelism_threads=num_threads)) as sess:
            nnz_value, ratio_value = self.evaluate((nnz, ratio))
            for prefix, target in (("sparse_matrix_ordering_amd_", ordering),
                                   ("sparse_matrix_sparse_cholesky_", factor)):
              # "percentage_nonzero" holds a fraction in [0, 1].
              self.run_op_benchmark(
                  sess,
                  target.op,
                  name=prefix + name_suffix,
                  extras={
                      "percentage_nonzero": ratio_value,
                      "num_nonzero": nnz_value
                  },
                  min_iters=25)


if __name__ == "__main__":
  # Script entry point: hand control to `test.main()`. `test` is presumably
  # tensorflow.python.platform.test (imported outside this view) — confirm.
  test.main()