"""Utility functions to build neural tangents models for PALNT

Depending on the dataset there might be some issues with these models,
some tricks are listed in https://github.com/google/neural-tangents/issues/76

1. Use Erf as activation
2. Initialize the weights with larger standard deviation
3. Standardize the data

The first two points are done by default in the `build_dense_network` function

Note that, following the law of total variance, the prior (initialized via
W_std and b_std) gives an upper bound on the std of the posterior.
"""

from dataclasses import dataclass
from typing import Callable, Sequence, Union

@dataclass
class NTModel:
    """Dataclass bundling the functions of a neural tangents model.

    Positional construction order is ``(init_fn, apply_fn, kernel_fn,
    predict_fn, scaler, params)``; only the first three are required.
    """

    # Initialization functions construct parameters for neural networks
    # given a random key and an input shape.
    init_fn: Callable
    # Apply functions do computations with finite-width neural networks.
    apply_fn: Callable
    # Kernel function returned by `neural_tangents.stax`.
    kernel_fn: Callable
    # Optional prediction function; `build_dense_network` leaves it as None.
    predict_fn: Union[Callable, None] = None
    # Used to store Standard Scaler objects (not callables, hence `object`).
    scaler: object = None
    # Used to store parameters for the ensemble models.
    params: Union[list, None] = None

@dataclass
class JaxOptimizer:
    """Dataclass bundling the ``(opt_init, opt_update, get_params)`` triple
    of a JAX optimizer (the convention of JAX's ``optimizers`` module)."""

    # Builds the optimizer state from the initial parameters.
    opt_init: Callable
    # Performs one optimization step on an optimizer state.
    opt_update: Callable
    # Extracts the current parameters from an optimizer state.
    get_params: Callable

def build_dense_network(
    hidden_layers: Sequence[int],
    activations: Union[Sequence, str, None] = "erf",
    w_std: float = 2.5,
    b_std: float = 1.0,
) -> "NTModel":
    """Utility function to build a simple feedforward network with the
    neural tangents library.

    The network is ``Dense -> activation`` for every hidden layer, followed
    by a single-unit ``Dense(1)`` output layer. Calling this function enables
    JAX 64-bit mode globally (``jax_enable_x64``).

    Args:
        hidden_layers (Sequence[int]): Iterable with the number of neurons.
            For example, [512, 512]
        activations (Union[Sequence, str, None], optional):
            Iterable with neural_tangents.stax activations (one per hidden
            layer), or "relu" or "erf" (case-insensitive) to use the same
            activation everywhere. ``None`` means ReLU everywhere.
            Defaults to "erf".
        w_std (float): Standard deviation of the weight distribution.
            Defaults to 2.5.
        b_std (float): Standard deviation of the bias distribution.
            Defaults to 1.0.

    Raises:
        AssertionError: If `hidden_layers` is empty, `activations` is an
            unknown string, contains objects that are not stax layers, or
            its length does not match `hidden_layers`.

    Returns:
        NTModel: init function, jitted apply and kernel functions,
            predict_fn (None)
    """
    assert len(hidden_layers) >= 1, "You must provide at least one hidden layer"

    # Validate everything that does not need JAX before importing it, so bad
    # arguments fail fast and without touching the global JAX configuration.
    if isinstance(activations, str):
        activation_name = activations.lower()
        assert activation_name in ("relu", "erf"), (
            f"Unknown activation {activations!r}; use 'relu', 'erf' or a "
            "sequence of `neural_tangents.stax` activations"
        )
    elif activations is not None:
        activations = list(activations)
        for activation in activations:
            assert _is_stax_layer(
                activation
            ), "You need to provide `neural_tangents.stax` activations"
        assert len(activations) == len(
            hidden_layers
        ), "The number of hidden layers should match the number of nonlinearities"

    import jax  # pylint:disable=import-outside-toplevel

    jax.config.update("jax_enable_x64", True)
    from neural_tangents import stax  # pylint:disable=import-outside-toplevel

    if activations is None:
        activations = [stax.Relu() for _ in hidden_layers]
    elif isinstance(activations, str):
        layer_factory = stax.Relu if activation_name == "relu" else stax.Erf
        activations = [layer_factory() for _ in hidden_layers]

    stack = []
    for hidden_layer, activation in zip(hidden_layers, activations):
        stack.append(stax.Dense(hidden_layer, W_std=w_std, b_std=b_std))
        # Without the nonlinearity the whole network collapses to an affine map.
        stack.append(activation)

    # Single-unit linear output layer.
    stack.append(stax.Dense(1, W_std=w_std, b_std=b_std))

    init_fn, apply_fn, kernel_fn = stax.serial(*stack)

    # static_argnums=(2,) keeps the third kernel_fn argument static under jit
    # (presumably the `get` selector, e.g. "ntk"/"nngp" — TODO confirm).
    return NTModel(
        init_fn, jax.jit(apply_fn), jax.jit(kernel_fn, static_argnums=(2,)), None
    )


def _is_stax_layer(obj) -> bool:
    """Return True if `obj` looks like a `neural_tangents.stax` layer.

    Accepts a callable, or a non-empty tuple of callables such as the
    ``(init_fn, apply_fn, kernel_fn)`` triple that ``stax.serial`` produces.
    """
    if callable(obj):
        return True
    return isinstance(obj, tuple) and len(obj) > 0 and all(callable(fn) for fn in obj)

def get_optimizer(
    learning_rate: float = 1e-4,
    optimizer: str = "sdg",
    optimizer_kwargs: Union[dict, None] = None,
) -> "JaxOptimizer":
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Calling this function enables JAX 64-bit mode globally
    (``jax_enable_x64``).

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type, case-insensitive (Allowed
            types: "adam", "adamax", "adagrad", "rmsprop", "sgd"). The
            historical spelling "sdg" is accepted as an alias of "sgd".
            Defaults to "sdg".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Raises:
        ValueError: If `optimizer` is not one of the allowed types.

    Returns:
        JaxOptimizer: ``opt_init``, a jitted ``opt_update`` and ``get_params``.
    """
    optimizer_name = optimizer.lower()
    # "sdg" is the historical (misspelled) default; keep accepting it.
    if optimizer_name == "sdg":
        optimizer_name = "sgd"
    allowed = ("adam", "adamax", "adagrad", "rmsprop", "sgd")
    if optimizer_name not in allowed:
        raise ValueError(
            f"Unknown optimizer {optimizer!r}; allowed types are: {', '.join(allowed)}"
        )

    import jax  # pylint:disable=import-outside-toplevel

    jax.config.update("jax_enable_x64", True)
    try:
        # Location of the optimizers module in current JAX releases.
        from jax.example_libraries import (  # pylint:disable=import-outside-toplevel
            optimizers,
        )
    except ImportError:
        # Older JAX releases shipped it under `jax.experimental`.
        from jax.experimental import (  # pylint:disable=import-outside-toplevel
            optimizers,
        )

    factories = {
        "adam": optimizers.adam,
        "adamax": optimizers.adamax,
        "adagrad": optimizers.adagrad,
        "rmsprop": optimizers.rmsprop,
        "sgd": optimizers.sgd,
    }
    kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    opt_init, opt_update, get_params = factories[optimizer_name](learning_rate, **kwargs)

    return JaxOptimizer(opt_init, jax.jit(opt_update), get_params)