def distribution(self, context=None):
        """Return a multivariate normal built from this object's parameters.

        Args:
            context: Optional conditioning input. It is passed unchanged to
                ``self._get_mean_and_chol``.

        Returns:
            A ``torch.distributions.MultivariateNormal`` whose ``loc`` and
            ``scale_tril`` are the two values returned by
            ``self._get_mean_and_chol(context)``. The second value is
            presumably the Cholesky factor of the covariance (inferred from
            the original local name ``chol_sigma``; confirm in the helper).
        """
        mean, scale = self._get_mean_and_chol(context)
        mvn_cls = torch.distributions.MultivariateNormal
        # validate_args=False turns off torch's argument/constraint checks for
        # this distribution instance, so ``scale`` is not checked to be a valid
        # lower-Cholesky factor when the object is constructed.
        return mvn_cls(mean, scale_tril=scale, validate_args=False)