# Float64 array of shape (2, 2, 3): two stacked 2x3 matrices, one per
# outer entry. NOTE(review): the name suggests this is batched input for a
# "set diagonal" operation — confirm against the code that consumes it.
mat_set_diag_batch = np.array(
    [
        [
            [-1.0, 0.0, 3.0],
            [0.0, -2.0, 0.0],
        ],
        [
            [-4.0, 0.0, 4.0],
            [0.0, -5.0, 0.0],
        ],
    ]
)