if torch.isnan(self.K).any():
            print("Encounter K with nan, replace K by identity matrix")
            self.K = (
                torch.eye(self.K.shape[1])
                .to(self.K.device)