ErikGartner/hyperdock

View on GitHub
hyperdock/supervisor/search/search.py

Summary

Maintainability
A
55 mins
Test Coverage
from sklearn.model_selection import ParameterGrid

from .sampling import sample_values


class Search:
    """
    The class for the parameter search methods.

    Implemented:
        - Grid search
        - Sampling from distributions
    """

    @classmethod
    def expand(cls, specs, **kwargs):
        """
        Expand one or more parameter specifications into a flat list of
        parameter combinations to try.

        Args:
            specs: A single specification dict, or a list of them. Each
                specification maps a parameter name to its value(s); see
                ``_expand_spec`` for how values are interpreted.
            **kwargs: Forwarded unchanged to ``_expand_spec`` for every
                specification.

        Returns:
            A list of dicts, one per parameter combination, concatenated in
            the order the specifications were given.
        """
        # Use ``cls`` rather than the hard-coded class name so subclasses that
        # override ``list_wrap`` are honoured, just like ``_expand_spec`` is.
        params = []
        for spec in cls.list_wrap(specs):
            params.extend(cls._expand_spec(spec, **kwargs))
        return params

    @staticmethod
    def list_wrap(spec):
        """
        Return ``spec`` unchanged if it is a list, otherwise ``[spec]``.

        Only ``list`` instances count; tuples and other sequences are wrapped.
        """
        if not isinstance(spec, list):
            spec = [spec]
        return spec

    @staticmethod
    def _expand_spec(spec, **kwargs):
        """
        Expand a single parameter specification into all its combinations.

        Values in ``spec`` are interpreted as follows:
            - list: the candidate values of one grid dimension;
            - dict: passed to ``sample_values``; a non-None result is used as
              the candidate values, otherwise the dict itself is kept as a
              fixed value;
            - anything else: a fixed value added to every combination.

        Args:
            spec: Mapping of parameter name to value or value specification.
            **kwargs: Currently unused; accepted so ``expand`` can forward
                options uniformly.

        Returns:
            A list of dicts: the grid over the variable parameters (built with
            ``ParameterGrid``), each merged with all fixed parameters.
        """
        fixed_params = {}
        variable_params = {}
        for k, v in spec.items():
            if isinstance(v, list):
                variable_params[k] = v
            elif isinstance(v, dict):
                # Try handling as a distribution; a None result means the dict
                # is not a distribution spec and is kept as a fixed value.
                res = sample_values(v)
                if res is not None:
                    variable_params[k] = res
                else:
                    fixed_params[k] = v
            else:
                fixed_params[k] = v

        # With no variable parameters the grid yields one empty combination,
        # so a spec containing only fixed values expands to one parameter set.
        params = list(ParameterGrid(variable_params))
        # NOTE: fixed values are shared by reference (not copied) between all
        # returned combinations.
        for p in params:
            p.update(fixed_params)
        return params