if hps.output_dist == 'poisson':
      # Enforce correct dtype: a Poisson output distribution is defined over
      # integer counts, so the training data must have an integer dtype.
      # Compare against the abstract ``np.integer`` rather than the builtin
      # ``int``: NumPy resolves ``int`` to the concrete platform default
      # integer (e.g. ``np.int64``), so ``np.issubdtype(<int32/uint8 dtype>,
      # int)`` is False and valid integer count arrays would be rejected.
      # NOTE(review): only the first dataset's 'train_data' is inspected.
      train_dtype = datasets[hps.dataset_names[0]]['train_data'].dtype
      if not np.issubdtype(train_dtype, np.integer):
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with optimizations enabled (-O).
        raise ValueError(
            "Data dtype must be int for poisson output distribution")