MushroomRL/mushroom-rl

View on GitHub
mushroom_rl/approximators/ensemble.py

Summary

Maintainability
A
3 hrs
Test Coverage
C
77%
import numpy as np
import torch
from sklearn.exceptions import NotFittedError

from mushroom_rl.core.serialization import Serializable


class Ensemble(Serializable):
    """
    This class is used to create an ensemble of regressors.

    """
    def __init__(self, model, n_models, prediction='mean', **params):
        """
        Constructor.

        Args:
            approximator (class): the model class to approximate the
                Q-function.
            n_models (int): number of regressors in the ensemble;
            prediction (str, ['mean', 'sum', 'min', 'max']): the type of
                prediction to make;
            **params: parameters dictionary to create each regressor.

        """
        self._model = list()
        self._prediction = prediction

        for _ in range(n_models):
            self._model.append(model(**params))

        self._add_save_attr(
            _model=self._get_serialization_method(model),
            _prediction='primitive'
        )

    def fit(self, *z, idx=None, **fit_params):
        """
        Fit the ``idx``-th model of the ensemble if ``idx`` is provided, every
        model otherwise.

        Args:
            *z: a list containing the inputs to use to predict with each
                regressor of the ensemble;
            idx (int, None): index of the model to fit;
            **fit_params: other params.

        """
        if idx is None:
            for i in range(len(self)):
                self[i].fit(*z, **fit_params)
        else:
            self[idx].fit(*z, **fit_params)

    def predict(self, *z, idx=None, prediction=None, compute_variance=False,
                **predict_params):
        """
        Predict.

        Args:
            *z: a list containing the inputs to use to predict with each
                regressor of the ensemble;
            idx (int, None): index of the model to use for prediction;
            prediction (str, None): the type of prediction to make. When
                provided, it overrides the ``prediction`` class attribute;
            compute_variance (bool, False): whether to compute the variance
                of the prediction or not;
            **predict_params: other parameters used by the predict method
                the regressor.

        Returns:
            The predictions of the model.

        """
        if idx is None:
            idx = [x for x in range(len(self))]

        torch_output = False

        if isinstance(idx, int):
            try:
                results = self[idx].predict(*z, **predict_params)
            except NotFittedError:
                raise NotFittedError
        else:
            predictions = list()
            for i in idx:
                try:
                    predictions.append(self[i].predict(*z, **predict_params))
                except NotFittedError:
                    raise NotFittedError

            prediction = self._prediction if prediction is None else prediction
            if isinstance(predictions[0], np.ndarray):
                predictions = np.array(predictions)
            else:
                predictions = torch.stack(predictions, axis=0)
                torch_output = True

            if prediction == 'mean':
                results = predictions.mean(0)
            elif prediction == 'sum':
                results = predictions.sum(0)
            elif prediction == 'min':
                results = predictions.min(0)
            elif prediction == 'max':
                results = predictions.max(0)
            else:
                raise ValueError
            if compute_variance:
                results = [results, predictions.var(0)]

        if torch_output and (prediction == 'min' or prediction == 'max'):
            return results.values
        else:
            return results

    def reset(self):
        """
        Reset the model parameters.

        """
        try:
            for m in self.model:
                m.reset()
        except AttributeError:
            raise NotImplementedError('Attempt to reset weights of a'
                                      ' non-parametric regressor.')

    @property
    def model(self):
        """
        Returns:
            The list of the models in the ensemble.

        """
        return self._model

    def __len__(self):
        return len(self._model)

    def __getitem__(self, idx):
        return self._model[idx]