HunterMcGushion/hyperparameter_hunter

View on GitHub
hyperparameter_hunter/callbacks/bases.py

Summary

Maintainability
B
4 hrs
Test Coverage
"""This module defines the base callback classes, from which all other callback classes in
:mod:`hyperparameter_hunter.callbacks` are descendants. Importantly, the specific base callback
classes contained herein are all descendants of
:class:`hyperparameter_hunter.callbacks.bases.BaseCallback`, ensuring all callbacks descend from the
same base class. This module also defines
:func:`hyperparameter_hunter.callbacks.bases.lambda_callback`, which can be used to define custom
callbacks to be executed during Experiments when passed to
:meth:`hyperparameter_hunter.environment.Environment.__init__` via the `experiment_callbacks`
argument

Related
-------
:mod:`hyperparameter_hunter.callbacks`
    The rest of the submodules within this package should define classes which all descend from the
    base callback classes defined in :mod:`hyperparameter_hunter.callbacks.bases`
:mod:`hyperparameter_hunter.experiment_core`
    This is where callback classes are added as bases inherited by
    :class:`hyperparameter_hunter.experiments.BaseExperiment`. This module is the path that links
    :mod:`hyperparameter_hunter.callbacks` to :mod:`hyperparameter_hunter.experiments`
:mod:`hyperparameter_hunter.environment`
    This module provides the means to use custom callbacks made by
    :func:`hyperparameter_hunter.callbacks.bases.lambda_callback` through the `experiment_callbacks`
    argument of :meth:`hyperparameter_hunter.environment.Environment.__init__`"""
##################################################
# Import Own Assets
##################################################
from hyperparameter_hunter.i_o.exceptions import DeprecatedWarning
from hyperparameter_hunter.settings import G

##################################################
# Import Miscellaneous Assets
##################################################
from inspect import signature
import numpy as np
from uuid import uuid4 as uuid
from warnings import warn


# Sentinel used as the default value of the deprecated keyword arguments of `lambda_callback`.
# Any value a caller passes explicitly (including None) differs from it, which is how the use of
# a deprecated argument name is detected
DEPRECATED = object()


class BaseCallback:
    """Root of the callback hierarchy. Every callback, and every intermediate base callback, is a
    descendant of this class. Because :meth:`__init__` of a callback class is never called, work
    that must happen at the onset of an experiment belongs in :meth:`on_exp_start`

    Notes
    -----
    __init__(): A subclass of :class:`BaseCallback` may define :meth:`__init__`, but callback
    classes are NEVER INITIALIZED. They should be regarded as extensions to the Experiment class
    that inherits them, and, being inherited, they have access to all attributes of that Experiment
    class. When a callback class does define :meth:`__init__`, it is only a convention for clearly
    declaring the attributes required by its other methods. To repeat: :meth:`__init__` of classes
    that inherit :class:`BaseCallback` will not be called

    Each hook below calls :meth:`settings.G.debug` to signal that the dynamic callback inheritance
    organized in :class:`experiment_core.ExperimentMeta` has proceeded at least partially
    successfully. Provided that every callback method ends by executing the method of the same name
    in its parent classes, these debug messages should be visible in the "Heartbeat.log" file. If
    one of them is missing from "Heartbeat.log", it is likely that some callback class's
    implementation of the corresponding method does not end with "super().<method_name>()". Such
    cases should be remedied immediately, as the callback stream could be skipping any number of
    other callbacks"""

    # FLAG: Try to implement something like below to ensure other attributes aren't modified (except predictions by Predictors)
    # FLAG: However, since ExperimentMeta makes BaseCallback a superclass of the Experiment classes, they would all pick up...
    # FLAG: ... this method, which would break everything.
    # FLAG: Solution 1: Check name of class that is calling __setattr__. If Callback, apply constraints, else setattr normally
    # FLAG: Solution 2: Try to override __setattr__ method for all classes that require normal behavior; however, ...
    # FLAG: ... uncertain how or where the overridden method might be inherited
    # def __setattr__(self, key, value):
    #     if key == 'stat_aggregates':
    #         self.__dict__[key] = value

    # The hooks are listed from the outermost level (experiment) to the innermost (run) and back
    # out again. Each one is the end of its callback chain and only emits a debug message

    def on_exp_start(self):
        """Terminal hook for the start of an Experiment"""
        G.debug("BaseCallback.on_exp_start()")

    def on_rep_start(self):
        """Terminal hook for the start of a repetition in an Experiment's repeated
        cross-validation scheme"""
        G.debug("BaseCallback.on_rep_start()")

    def on_fold_start(self):
        """Terminal hook for the start of a fold in an Experiment's cross-validation scheme"""
        G.debug("BaseCallback.on_fold_start()")

    def on_run_start(self):
        """Terminal hook for the start of a run in an Experiment's multiple-run-averaging phase"""
        G.debug("BaseCallback.on_run_start()")

    def on_run_end(self):
        """Terminal hook for the end of a run in an Experiment's multiple-run-averaging phase"""
        G.debug("BaseCallback.on_run_end()")

    def on_fold_end(self):
        """Terminal hook for the end of a fold in an Experiment's cross-validation scheme"""
        G.debug("BaseCallback.on_fold_end()")

    def on_rep_end(self):
        """Terminal hook for the end of a repetition in an Experiment's repeated cross-validation
        scheme"""
        G.debug("BaseCallback.on_rep_end()")

    def on_exp_end(self):
        """Terminal hook for the end of an Experiment"""
        G.debug("BaseCallback.on_exp_end()")


##################################################
# LambdaCallback
##################################################
def lambda_callback(
    on_exp_start=None,
    on_exp_end=None,
    on_rep_start=None,
    on_rep_end=None,
    on_fold_start=None,
    on_fold_end=None,
    on_run_start=None,
    on_run_end=None,
    agg_name=None,
    do_reshape_aggs=True,
    method_agg_keys=False,
    on_experiment_start=DEPRECATED,
    on_experiment_end=DEPRECATED,
    on_repetition_start=DEPRECATED,
    on_repetition_end=DEPRECATED,
):
    r"""Utility for creating custom callbacks to be declared by :class:`Environment` and used by
    Experiments. The callable "on_<...>_<start/end>" parameters provided will receive as input
    whichever attributes of the Experiment are included in the signature of the given callable. If
    `\*\*kwargs` is given in the callable's signature, a dict of all of the Experiment's attributes
    will be provided. This can be helpful for trying to figure out how to build a custom callback,
    but should not be used unless absolutely necessary. If the Experiment does not have an attribute
    specified in the callable's signature, the following placeholder will be given: "INVALID KWARG"

    Parameters
    ----------
    on_exp_start: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at Experiment start
    on_exp_end: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at Experiment end
    on_rep_start: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at repetition start
    on_rep_end: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at repetition end
    on_fold_start: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at fold start
    on_fold_end: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at fold end
    on_run_start: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at run start
    on_run_end: Callable, or None, default=None
        Callable that receives Experiment's values for parameters in the signature at run end
    agg_name: Str, or None, default=None
        This parameter is only used if the callables are behaving like AggregatorCallbacks by
        returning values (see the "Notes" section below for details on this). If the callables do
        return values, they will be stored under a key named ("_" + `agg_name`) in a dict in
        :attr:`hyperparameter_hunter.experiments.BaseExperiment.stat_aggregates`. The purpose of
        this parameter is to make it easier to understand an Experiment's description file, as
        `agg_name` will default to a UUID (:func:`uuid.uuid4`) if it is not given
    do_reshape_aggs: Boolean, default=True
        Whether to reshape the aggregated values to reflect the nested repetitions/folds/runs
        structure used for other aggregated values. If False, lists of aggregated values are left in
        their original shapes. This parameter is only used if the callables are behaving like
        AggregatorCallbacks (see the "Notes" section below and `agg_name` for details on this)
    method_agg_keys: Boolean, default=False
        If True, the aggregate keys for the items added to the dict at `agg_name` are equivalent to
        the names of the "on_<...>_<start/end>" pseudo-methods whose values are being aggregated. In
        other words, the pool of all possible aggregate keys goes from ["runs", "folds", "reps",
        "final"] to the names of the eight "on_<...>_<start/end>" kwargs of :func:`lambda_callback`.
        See the "Notes" section below for further details and a rough outline
    on_experiment_start: ...
        .. deprecated:: 3.0.0
            Renamed to `on_exp_start`. Will be removed in 3.2.0
    on_experiment_end: ...
        .. deprecated:: 3.0.0
            Renamed to `on_exp_end`. Will be removed in 3.2.0
    on_repetition_start: ...
        .. deprecated:: 3.0.0
            Renamed to `on_rep_start`. Will be removed in 3.2.0
    on_repetition_end: ...
        .. deprecated:: 3.0.0
            Renamed to `on_rep_end`. Will be removed in 3.2.0

    Returns
    -------
    LambdaCallback: :class:`LambdaCallback`
        Uninitialized class, whose methods are the callables of the corresponding "on..." kwarg

    Notes
    -----
    For all of the "on_<...>_<start/end>" callables provided as input to `lambda_callback`, consider
    the following guidelines (for example function "f", which can represent any of the callables):

    - All input parameters in the signature of "f" are attributes of the Experiment being executed

        - If "\*\*kwargs" is a parameter, a dict of all the Experiment's attributes will be provided

    - "f" will be treated as a method of a parent class of the Experiment

        - Take care when modifying attributes, as changes are reflected in the Experiment itself

    - Exceptions raised by "f" (including `TypeError`) are not suppressed; they propagate to the caller of the callback method. An "on_<...>_<start/end>" argument that is not callable (such as the default None) is skipped

    - If "f" returns something, it will automatically behave like an AggregatorCallback (see :mod:`hyperparameter_hunter.callbacks.aggregators`). Specifically, the following will occur:

        - A new key (named by `agg_name` if given, else a UUID) with a dict value is added to :attr:`hyperparameter_hunter.experiments.BaseExperiment.stat_aggregates`

            - This new dict can have up to four keys: "runs" (list), "folds" (list), "reps" (list), and "final" (object)

        - If "f" is an "on_run..." function, the returned value is appended to the "runs" list in the new dict
        - Similarly, if "f" is an "on_fold..." or "on_rep..." function, the returned value is appended to the "folds", or "reps" list, respectively
        - If "f" is an "on_exp..." function, the "final" key in the new dict is set to the returned value
        - If values were aggregated in the aforementioned manner, the lists of collected values will be reshaped according to runs/folds/reps on Experiment end
        - The aggregated values will be saved in the Experiment's description file

            - This is because :attr:`hyperparameter_hunter.experiments.BaseExperiment.stat_aggregates` is saved in its entirety

    What follows is a rough outline of the structure produced when using an aggregator-like callback
    that automatically populates :attr:`experiments.BaseExperiment.stat_aggregates` with results of
    the functions used as arguments to :func:`lambda_callback`::

        BaseExperiment.stat_aggregates = dict(
            ...,
            <`agg_name`>=dict(
                <agg_key "runs">  = [...],
                <agg_key "folds"> = [...],
                <agg_key "reps">  = [...],
                <agg_key "final"> = object(),
                ...
            ),
            ...
        )

    In the above outline, the actual `agg_key`s included in the dict at `agg_name` depend on which
    "on_<...>_<start/end>" callables are behaving like aggregators. For example, if neither
    `on_run_start` nor `on_run_end` explicitly returns something, then the "runs" `agg_key` is not
    included in the `agg_name` dict. Similarly, if, for example, neither `on_exp_start` nor
    `on_exp_end` is provided, then the "final" `agg_key` is not included. If `method_agg_keys=True`,
    then the agg keys used in the dict are modified to be named after the method called. For
    example, if `method_agg_keys=True` and `on_fold_start` and `on_fold_end` are both callables
    returning values to be aggregated, then the `agg_key`s used for each will be "on_fold_start"
    and "on_fold_end", respectively. In this example, if `method_agg_keys=False` (default) and
    `do_reshape_aggs=False`, then the single "folds" `agg_key` would contain the combined contents
    returned by both methods in the order in which they were returned

    For examples using `lambda_callback` to create custom callbacks, see
    :mod:`hyperparameter_hunter.callbacks.recipes`

    Examples
    --------
    >>> from hyperparameter_hunter.environment import Environment
    >>> def printer_helper(_rep, _fold, _run, last_evaluation_results):
    ...     print(f"{_rep}.{_fold}.{_run}   {last_evaluation_results}")
    >>> my_lambda_callback = lambda_callback(
    ...     on_exp_end=printer_helper,
    ...     on_rep_end=printer_helper,
    ...     on_fold_end=printer_helper,
    ...     on_run_end=printer_helper,
    ... )
    ... # env = Environment(
    ... #     train_dataset="i am a dataset",
    ... #     results_path="path/to/HyperparameterHunterAssets",
    ... #     metrics=["roc_auc_score"],
    ... #     experiment_callbacks=[my_lambda_callback]
    ... # )
    ... # ... Now execute an Experiment, or an Optimization Protocol...

    See :mod:`hyperparameter_hunter.examples.lambda_callback_example` for more information"""

    # The sentinel is compared by identity so that a user-supplied object's own `__eq__`/`__ne__`
    # can never influence whether a deprecated argument counts as "given"
    if on_experiment_start is not DEPRECATED:
        on_exp_start = on_experiment_start
        warn(DeprecatedWarning("on_experiment_start", "3.0.0", "3.2.0", "Use `on_exp_start`"))
    if on_experiment_end is not DEPRECATED:
        on_exp_end = on_experiment_end
        warn(DeprecatedWarning("on_experiment_end", "3.0.0", "3.2.0", "Use `on_exp_end`"))
    if on_repetition_start is not DEPRECATED:
        on_rep_start = on_repetition_start
        warn(DeprecatedWarning("on_repetition_start", "3.0.0", "3.2.0", "Use `on_rep_start`"))
    if on_repetition_end is not DEPRECATED:
        on_rep_end = on_repetition_end
        warn(DeprecatedWarning("on_repetition_end", "3.0.0", "3.2.0", "Use `on_rep_end`"))

    # (method name, user callable, default aggregate key) for each generated callback method
    methods = [
        ("on_exp_start", on_exp_start, "final"),
        ("on_rep_start", on_rep_start, "reps"),
        ("on_fold_start", on_fold_start, "folds"),
        ("on_run_start", on_run_start, "runs"),
        ("on_run_end", on_run_end, "runs"),
        ("on_fold_end", on_fold_end, "folds"),
        ("on_rep_end", on_rep_end, "reps"),
        ("on_exp_end", on_exp_end, "final"),
    ]

    LambdaCallback = type("LambdaCallback", (BaseCallback,), dict())
    agg_name = "_{}".format(agg_name or str(uuid()))
    # `does_aggregate` and `agg_shapes` live in this closure, so they are shared by every instance
    # of the returned class rather than stored on the Experiment
    does_aggregate = False
    agg_shapes = dict(runs=None, folds=None)

    for meth_name, meth_content, agg_key in methods:
        if method_agg_keys:
            agg_key = meth_name

        # Default arguments bind the current loop values, avoiding late-binding closure issues
        def _method_factory(_meth_name=meth_name, _meth_content=meth_content, _agg_key=agg_key):
            """Provide `_meth_name`, `_meth_content`, and `_agg_key` for :func:`_method`"""

            def _method(self):
                """Execute `_meth_content`, then call parent's method of name `_meth_name`"""
                nonlocal does_aggregate

                #################### Execute Custom Callback Method ####################
                # Hooks that were not provided (None, or anything else that is not callable) are
                # skipped. Errors raised by a provided callable are deliberately NOT caught, so a
                # `TypeError` inside user code surfaces instead of silently discarding the result
                return_value = None
                if callable(_meth_content):
                    requested_parameters = {
                        name: getattr(
                            self, name, (self.__dict__ if name == "kwargs" else "INVALID KWARG")
                        )
                        for name in signature(_meth_content).parameters
                    }
                    return_value = _meth_content(**requested_parameters)

                #################### Handle Return Values ####################
                if return_value is not None:
                    does_aggregate = True
                    self.stat_aggregates.setdefault(agg_name, dict())

                    if _agg_key in ("final", "on_exp_start", "on_exp_end"):
                        # Experiment-level values are stored directly, not collected in a list
                        self.stat_aggregates[agg_name][_agg_key] = return_value
                    else:
                        self.stat_aggregates[agg_name].setdefault(_agg_key, []).append(return_value)
                        # Record shapes of aggregated return values
                        agg_shapes[_agg_key] = np.shape(return_value)

                #################### Reshape Aggregated Values ####################
                # `does_aggregate` may have been set through a different instance of this class, so
                # also confirm that this instance actually holds aggregated values before indexing
                if (
                    _meth_name == "on_exp_end"
                    and does_aggregate
                    and do_reshape_aggs
                    and agg_name in self.stat_aggregates
                ):
                    self.stat_aggregates[agg_name] = _reshape_aggregates(
                        self.stat_aggregates[agg_name], agg_shapes, self._rep, self._fold, self._run
                    )

                #################### Call Next Callback Method in Chain ####################
                getattr(super(LambdaCallback, self), _meth_name)()

            return _method

        setattr(LambdaCallback, meth_name, _method_factory())

    return LambdaCallback


def _reshape_aggregates(named_aggregates, agg_shapes, rep, fold, run):
    """Reshape specified values of `named_aggregates` to reflect the nested repetitions/folds/runs
    structure used for other aggregated values

    Parameters
    ----------
    named_aggregates: Dict[str, Union[list, object]]
        Aggregated values collected by a particular aggregator-like `LambdaCallback`. Can have up to
        four keys, with expected value types parenthesized: "runs" (list), "folds" (list),
        "reps" (list), and "final" (object). This function will attempt to reshape the values of the
        "runs" and "folds" keys to reflect the nested repetitions/folds/runs structure used for
        other aggregated values
    agg_shapes: Dict[str, Tuple[int, ...]]
        Shapes of the values aggregated by aggregator-like `LambdaCallback` methods. Can have up to
        two keys: "runs", "folds". If the return value of a method is not an iterable, its shape
        will be an empty tuple
    rep: Integer
        The repetition number currently being executed
    fold: Integer
        The fold number currently being executed
    run: Integer
        The run number currently being executed

    Returns
    -------
    named_aggregates: Dict[str, Union[list, object]]
        Original aggregated values, with the values for keys "runs" and "folds" reshaped

    Examples
    --------
    >>> _input = dict(
    ...     runs=list("abcdefghijklmnopqrstuvwx"),
    ...     folds=["abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx"],
    ...     reps=["abcdefghijkl", "mnopqrstuvwx"],
    ...     final="foo"
    ... )
    >>> _expected = dict(
    ...     runs=[
    ...         [["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"], ["j", "k", "l"]],
    ...         [["m", "n", "o"], ["p", "q", "r"], ["s", "t", "u"], ["v", "w", "x"]],
    ...     ],
    ...     folds=[["abc", "def", "ghi", "jkl"], ["mno", "pqr", "stu", "vwx"]],
    ...     reps=["abcdefghijkl", "mnopqrstuvwx"],
    ...     final="foo",
    ... )
    >>> assert _reshape_aggregates(_input, dict(runs=tuple(), folds=tuple()), 1, 3, 2) == _expected
    """
    runs_shape = (rep + 1, fold + 1, run + 1)

    for (key, shape) in [("runs", runs_shape), ("folds", runs_shape[:-1])]:
        if key not in named_aggregates:
            continue
        named_aggregates[key] = np.reshape(named_aggregates[key], shape + agg_shapes[key]).tolist()
    return named_aggregates


##################################################
# Intermediate Base Callbacks
##################################################
class BaseWranglerCallback(BaseCallback):
    """Common ancestor of every callback in :mod:`hyperparameter_hunter.callbacks.wranglers`.
    Adds nothing to :class:`BaseCallback`"""


class BaseInputWranglerCallback(BaseWranglerCallback):
    """Intermediate base callback that adds nothing to :class:`BaseWranglerCallback`. Presumably
    the ancestor of callbacks that wrangle input data (inferred from the name only; confirm)"""


class BaseTargetWranglerCallback(BaseWranglerCallback):
    """Intermediate base callback that adds nothing to :class:`BaseWranglerCallback`. Presumably
    the ancestor of callbacks that wrangle target data (inferred from the name only; confirm)"""


class BasePredictorCallback(BaseWranglerCallback):
    """Intermediate base callback that adds nothing to :class:`BaseWranglerCallback`. Presumably
    the ancestor of callbacks that handle predictions (inferred from the name only; confirm)"""


class BaseLoggerCallback(BaseCallback):
    """Common ancestor of every callback in :mod:`hyperparameter_hunter.callbacks.loggers`.
    Adds nothing to :class:`BaseCallback`"""


class BaseAggregatorCallback(BaseCallback):
    """Common ancestor of every callback in :mod:`hyperparameter_hunter.callbacks.aggregators`.
    Adds nothing to :class:`BaseCallback`"""


class BaseEvaluatorCallback(BaseCallback):
    """Common ancestor of every callback in :mod:`hyperparameter_hunter.callbacks.evaluators`.
    Adds nothing to :class:`BaseCallback`"""


# No-op entry point: executing this module as a script performs no work
if __name__ == "__main__":
    pass