if torch.isnan(self.K_step).any():
                print("Encounter multistep K with nan, replace it by identity matrix")
                self.K_step = (
                    torch.eye(self.K_step.shape[1])
                    .to(self.K_step.device)