Garve/scikit-bonus

View on GitHub
skbonus/meta/_zero_inflated_regressor.py

Summary

Maintainability
A
0 mins
Test Coverage
from __future__ import annotations

from typing import Any

import numpy as np
from sklearn.base import (
    BaseEstimator,
    RegressorMixin,
    clone,
    is_regressor,
    is_classifier,
)
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted, check_X_y, check_array


class ZeroInflatedRegressor(BaseEstimator, RegressorMixin):
    """
    A meta regressor for zero-inflated datasets, i.e. the targets contain a lot of zeroes.

    `ZeroInflatedRegressor` consists of a classifier and a regressor.

        - The classifier's task is to find of if the target is zero or not.
        - The regressor's task is to output a (usually positive) prediction whenever the classifier indicates that the there should be a non-zero prediction.

    The regressor is only trained on examples where the target is non-zero, which makes it easier for it to focus.

    At prediction time, the classifier is first asked if the output should be zero. If yes, output zero.
    Otherwise, ask the regressor for its prediction and output it.

    Parameters
    ----------
    classifier : Any, scikit-learn classifier
        A classifier that answers the question "Should the output be zero?".

    regressor : Any, scikit-learn regressor
        A regressor for predicting the target. Its prediction is only used if `classifier` says that the output is non-zero.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.ensemble import ExtraTreesClassifier, ExtraTreesRegressor
    >>> np.random.seed(0)
    >>> X = np.random.randn(10000, 4)
    >>> y = ((X[:, 0]>0) & (X[:, 1]>0)) * np.abs(X[:, 2] * X[:, 3]**2)
    >>> z = ZeroInflatedRegressor(
    ... classifier=ExtraTreesClassifier(random_state=0),
    ... regressor=ExtraTreesRegressor(random_state=0)
    ... )
    >>> z.fit(X, y)
    ZeroInflatedRegressor(classifier=ExtraTreesClassifier(random_state=0),
                          regressor=ExtraTreesRegressor(random_state=0))
    >>> z.predict(X)[:5]
    array([4.91483294, 0.        , 0.        , 0.04941909, 0.        ])
    """

    _required_parameters = ["classifier", "regressor"]

    def __init__(self, classifier: Any, regressor: Any) -> None:
        """Initialize."""
        self.classifier = classifier
        self.regressor = regressor

    def _fit_estimator(self, estimator_type: str, X: np.ndarray, y: np.ndarray) -> None:
        if estimator_type == 'classifier':
            estimator = self.classifier
        elif estimator_type == 'regressor':
            estimator = self.regressor
        else:
            raise ValueError(f"Estimator type should be `classifier` or `regressor`, got {estimator_type}.")

        try:
            check_is_fitted(estimator)
            self.estimators_[estimator_type] = estimator
        except NotFittedError:
            self.estimators_[estimator_type] = clone(estimator)
            self.estimators_[estimator_type].fit(X, y)

    def _check_estimators(self) -> None:
        if not is_classifier(self.classifier):
            raise ValueError(
                f"`classifier` has to be a classifier. Received instance of {type(self.classifier)} instead."
            )
        if not is_regressor(self.regressor):
            raise ValueError(
                f"`regressor` has to be a regressor. Received instance of {type(self.regressor)} instead."
            )

    def fit(self, X: np.ndarray, y: np.ndarray) -> ZeroInflatedRegressor:
        """
        Fit the model.

        Parameters
        ----------
        X : np.ndarray of shape (n_samples, n_features)
            The training data.

        y : np.ndarray, 1-dimensional
            The target values.

        Returns
        -------
        ZeroInflatedRegressor
            Fitted regressor.
        """
        X, y = check_X_y(X, y)
        self._check_n_features(X, reset=True)
        self._check_estimators()

        self.estimators_ = {}
        self._fit_estimator('classifier', X, y != 0)
        non_zero_indices = np.where(self.estimators_['classifier'].predict(X) == 1)[0]
        if non_zero_indices.size > 0:
            self._fit_estimator('regressor', X[non_zero_indices], y[non_zero_indices])
        else:
            raise ValueError(
                "The predicted training labels are all zero, making the regressor obsolete. Change the classifier or use a plain regressor instead."
            )

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Get predictions.

        Parameters
        ----------
        X : np.ndarray, shape (n_samples, n_features)
            Samples to get predictions of.

        Returns
        -------
        y : np.ndarray, shape (n_samples,)
            The predicted values.
        """
        check_is_fitted(self)
        X = check_array(X)
        self._check_n_features(X, reset=False)

        output = np.zeros(len(X))
        non_zero_indices = np.where(self.estimators_['classifier'].predict(X))[0]

        if non_zero_indices.size > 0:
            output[non_zero_indices] = self.estimators_['regressor'].predict(X[non_zero_indices])

        return output