# Batch of two 3x3 matrices stacked along the first axis (shape (2, 3, 3)),
# built as float literals and then cast to `dtype`.
# Main diagonals: [-1, 0, -3] for the first matrix, [-4, -5, -6] for the second.
# NOTE(review): `dtype` is supplied by the enclosing scope (not visible in this
# chunk) — presumably a test/loop parameter; confirm against the caller.
mat_set_diag_batch = np.array([
    [[-1.0, 0.0, 3.0],
     [0.0, 0.0, 0.0],
     [1.0, 0.0, -3.0]],
    [[-4.0, 0.0, 4.0],
     [0.0, -5.0, 0.0],
     [2.0, 0.0, -6.0]],
]).astype(dtype)