kjappelbaum/pyepal

View on GitHub
src/pyepal/pal/pal_finite_ensemble.py

Summary

Maintainability
A
1 hr
Test Coverage
# -*- coding: utf-8 -*-
# Copyright 2020 PyePAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Run PAL with the same models for finite ensemble models
and infinite-width models (`PALNT`).
"""

from typing import Sequence

import numpy as np
from sklearn.preprocessing import StandardScaler

from .pal_base import PALBase
from .validate_inputs import validate_nt_models, validate_optimizers, validate_positive_integer_list
from ..models.nt import JaxOptimizer, NTModel


# Again, the idea of keeping the core as pure functions outside of the class is
# that this makes it easier to parallelize them
def _ensemble_train_one_finite_width(  # pylint:disable=too-many-arguments, too-many-locals
    i: int,
    models: Sequence[NTModel],
    design_space: np.ndarray,
    objectives: np.ndarray,
    sampled: np.ndarray,
    optimizers: Sequence[JaxOptimizer],
    key: object,
    training_steps: Sequence[int],
    ensemble_size: Sequence[int],
):
    """Train an ensemble of finite-width networks for the i-th objective.

    Args:
        i (int): Index that selects the model, the optimizer, the objective
            column, the number of training steps and the ensemble size.
        models (Sequence[NTModel]): Models; ``models[i]`` provides ``init_fn``
            and ``apply_fn``.
        design_space (np.ndarray): Feature matrix. The rows selected by
            ``sampled[:, i]`` are used as training inputs.
        objectives (np.ndarray): Target values. Column ``i``, restricted to the
            selected rows, is used as training target.
        sampled (np.ndarray): Array whose i-th column selects the training
            points (presumably a boolean mask -- verify against the caller).
        optimizers (Sequence[JaxOptimizer]): Optimizers; ``optimizers[i]``
            provides ``opt_init``, ``opt_update`` and ``get_params``.
        key (object): JAX PRNG key that is split into one key
            per ensemble member.
        training_steps (Sequence[int]): ``training_steps[i]`` is the number of
            optimizer updates applied to every ensemble member.
        ensemble_size (Sequence[int]): ``ensemble_size[i]`` is the number
            of ensemble members.

    Returns:
        tuple: The parameters of all ensemble members (result of mapping the
            training function over the member keys with ``vmap``) and the
            ``StandardScaler`` that was fitted on the training targets.
    """
    from jax import grad, jit, vmap  # pylint:disable=import-outside-toplevel
    from jax import random  # pylint:disable=import-outside-toplevel

    network = models[i]
    opt = optimizers[i]

    @jit
    def loss(params, x, y):
        # half of the mean squared error
        residual = network.apply_fn(params, x) - y
        return 0.5 * np.mean(residual ** 2)

    @jit
    def grad_loss(state, x, y):
        # gradient of the loss w.r.t. the parameters stored in the optimizer state
        return grad(loss)(opt.get_params(state), x, y)

    selector = sampled[:, i]
    x_train = design_space[selector]

    # the targets are standardized; the fitted scaler is handed back to the caller
    scaler = StandardScaler()
    y_train = scaler.fit_transform(objectives[selector, i].reshape(-1, 1))

    def train_network(member_key):
        _, member_params = network.init_fn(member_key, (-1, x_train.shape[1]))
        state = opt.opt_init(member_params)

        for step in range(training_steps[i]):
            gradients = grad_loss(state, x_train, y_train)
            state = opt.opt_update(step, gradients, state)

        return opt.get_params(state)

    # one PRNG key per ensemble member; training is mapped over these keys
    member_keys = random.split(key, ensemble_size[i])
    ensemble_params = vmap(train_network)(member_keys)

    return ensemble_params, scaler


def _ensemble_predict_one_finite_width(i: int, models: Sequence[NTModel], design_space):
    """Predict mean and standard deviation with the ensemble of the i-th model.

    Args:
        i (int): Index of the model in ``models``.
        models (Sequence[NTModel]): Models; ``models[i].params`` must hold the
            parameters of all ensemble members stacked along the leading axis
            and ``models[i].apply_fn`` is called as ``apply_fn(params, x)``.
        design_space: Inputs for which predictions are made. The same inputs
            are passed to every ensemble member.

    Returns:
        tuple: Two flattened (one-dimensional) arrays, the mean and the
            standard deviation of the predictions over the ensemble members.
    """
    # `vmap` is imported from the public top-level namespace (as in the
    # training function); the `jax.api` module is no longer available in JAX.
    from jax import vmap  # pylint:disable=import-outside-toplevel

    model = models[i]

    # map over the leading (ensemble) axis of the parameters,
    # the design space is not mapped
    ensemble_func = vmap(model.apply_fn, (0, None))(model.params, design_space)

    # statistics over the ensemble axis
    mean_func = np.reshape(np.mean(ensemble_func, axis=0), (-1,))
    std_func = np.reshape(np.std(ensemble_func, axis=0), (-1,))

    return mean_func, std_func


# Public names of this module. `NTModel` and `JaxOptimizer` are imported from
# `..models.nt` and listed here so that they can also be imported from this module.
__all__ = ["PALJaxEnsemble", "NTModel", "JaxOptimizer"]


class PALJaxEnsemble(PALBase):  # pylint:disable=too-many-instance-attributes
    """Use PAL with an ensemble of finite-width neural networks.

    Note that we currently assume that there is one model per output,
    i.e., multihead support is not implemented yet.
    """

    def __init__(self, *args, **kwargs):
        """Construct the PALJaxEnsemble instance.

        The keyword arguments ``optimizers``, ``training_steps``,
        ``ensemble_size`` and ``key`` are consumed by this class. All other
        positional and keyword arguments are forwarded unchanged to the
        constructor of the base class.

        Args:
            X_design (np.array): Design space (feature matrix)
            models (Sequence[NTModel]): You need to provide a sequence of
                NTModel (`pyepal.models.nt.NTModel`).
                The elements of this dataclass are the `apply_fn`, `init_fn`,
                `kernel_fn` and `predict_fn` (for the latter you can typically
                provide `None`).
                Can be constructed with
                :py:func:`pyepal.models.nt.build_dense_network`.
            optimizers (Union[JaxOptimizer, Sequence[JaxOptimizer]]):
                Required keyword argument. Sequence of dataclasses with
                functions for a JAX optimizer, can be constructed with
                :py:func:`pyepal.models.nt.get_optimizer`.
            ndim (int): Number of objectives
            epsilon (Union[list, float], optional): Epsilon hyperparameter.
                Defaults to 0.01.
            delta (float, optional): Delta hyperparameter. Defaults to 0.05.
            beta_scale (float, optional): Scaling parameter for beta.
                If not equal to 1, the theoretical guarantees do not
                necessarily hold. Also note that the parametrization depends
                on the kernel type. Defaults to 1/9.
            goals (List[str], optional): If a list, provide "min" for every
                objective that shall be minimized and "max" for every
                objective that shall be maximized. Defaults to None,
                which means that the code maximizes all objectives.
            coef_var_threshold (float, optional): Use only points with
                a coefficient of variation below this threshold
                in the classification step. Defaults to 3.
            key (int): Seed used to generate the key for the JAX
                pseudo-random number generator. Defaults to 10.
            training_steps (Union[int, Sequence[int]]): Number of epochs
                for which the networks are trained. Defaults to 500.
            ensemble_size (Union[int, Sequence[int]]): Size of the ensemble,
                i.e., over how many randomly initialized neural networks
                we average to obtain estimates of mean and standard deviation.
                Automatically vectorized using `vmap`.
                Defaults to 100.

        Raises:
            KeyError: If the ``optimizers`` keyword argument is missing.
        """
        from jax import random  # pylint:disable=import-outside-toplevel

        # NOTE(review): ndim is only looked up among the keyword arguments; if it
        # is passed positionally the validators below receive None -- confirm
        # that they handle this case.
        ndim = kwargs.get("ndim")

        # these keyword arguments are removed from kwargs so that they are
        # not forwarded to the base class
        self.optimizers = validate_optimizers(kwargs.pop("optimizers"), ndim)
        self.training_steps = validate_positive_integer_list(
            kwargs.pop("training_steps", 500), ndim
        )
        self.ensemble_size = validate_positive_integer_list(
            kwargs.pop("ensemble_size", 100), ndim
        )

        seed = kwargs.pop("key", 10)
        self.key = random.PRNGKey(seed)

        self.design_space_scaler = StandardScaler()

        super().__init__(*args, **kwargs)

        # the models are validated after the base class has set
        # self.models and self.ndim
        self.models = validate_nt_models(self.models, self.ndim)

    def _set_data(self):
        """Fit the design space scaler and replace the design space
        with its standardized version."""
        standardized = self.design_space_scaler.fit_transform(self.design_space)
        self.design_space = standardized

    def _train(self):
        """Train the ensemble of every model on the sampled points.

        For every model, the trained parameters and the fitted target scaler
        are stored on the model and the corresponding column of ``self.y``
        is overwritten with its scaled version. The same ``self.key`` is
        used for every model.
        """
        for index, model in enumerate(self.models):
            params, scaler = _ensemble_train_one_finite_width(
                i=index,
                models=self.models,
                design_space=self.design_space,
                objectives=self.y,
                sampled=self.sampled,
                optimizers=self.optimizers,
                key=self.key,
                training_steps=self.training_steps,
                ensemble_size=self.ensemble_size,
            )
            model.params = params
            model.scaler = scaler

            # NOTE(review): the transformed column is written back to self.y, so
            # every call rescales the already scaled values again -- confirm
            # that this is intended.
            column = self.y[:, index].reshape(-1, 1)
            self.y[:, index] = scaler.transform(column).flatten()

    def _predict(self):
        """Predict means and standard deviations for the full design space.

        The predictions of the models are stacked column-wise and stored
        in ``self._means`` and ``self.std``.
        """
        predictions = [
            _ensemble_predict_one_finite_width(index, self.models, self.design_space)
            for index, _ in enumerate(self.models)
        ]

        # one column per model
        self._means = np.hstack([mean.reshape(-1, 1) for mean, _ in predictions])
        self.std = np.hstack([std.reshape(-1, 1) for _, std in predictions])