for diags, (solution, _) in tests.items():
          # Each key of `tests` is passed as the `k` diagonal offset(s) to
          # matrix_diag_part. Only the first element of each value pair is
          # used as the expected result; the second element is ignored here.
          # Only element 0 of `mat` and of `solution` is checked (presumably
          # the first matrix of a batch -- confirm against the enclosing test).
          actual = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
          expected = solution[0]
          # Check the static shape first, then the values.
          self.assertEqual(actual.get_shape(), expected.shape)
          self.assertAllEqual(actual, expected)