if self.use_prediction_distributions:
            # Predictions are per-class scores (log-probabilities per the original note),
            # so reduce them to one predicted class index per row.
            # NOTE(review): assumes a torch tensor of shape (batch, num_classes), making
            # dim 1 the class axis — TODO confirm against the producer of key_predictions.
            # Get indices of the max log-probability: .max(1) yields (values, indices)
            # and [1] keeps the indices (the argmax along dim 1).
            # .data / .cpu() / .numpy() then hand back a host-side NumPy array.
            preds = data_streams[self.key_predictions].max(1)[1].data.cpu().numpy()
        else: 
            # Predictions are taken as-is and only converted to a NumPy array.
            # NOTE(review): presumably already class indices in this mode — verify.
            preds = data_streams[self.key_predictions].data.cpu().numpy()