FilippoAiraldi/casadi-neural-nets

View on GitHub
examples/pwq.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Reproduces the PieceWise Quadratic (PWQ) neural network from [1] used to approximate
the value function of a constrained linear proble. Thus, the output of the net is meant
to be a single scalar.

References
----------
[1] He, K., Shi, S., Boom, T.V.D. and De Schutter, B., 2022. Approximate Dynamic
    Programming for Constrained Linear Systems: A Piecewise Quadratic Approximation
    Approach. arXiv preprint arXiv:2205.10065.
"""

from typing import Callable, TypeVar

import casadi as cs
import matplotlib.pyplot as plt
import numpy as np

import csnn

SymType = TypeVar("SymType", cs.SX, cs.MX)


class ElementWiseSquare(csnn.Module[SymType]):
    """Squares the input in an element-wise fashion."""

    def forward(self, x: SymType) -> SymType:
        return x * x


class DotProduct(csnn.Linear[SymType]):
    """Compute the dot product of the input with some weights."""

    def __init__(self, in_features: int) -> None:
        super().__init__(in_features, 1, False)


class Pwq(csnn.Module[SymType]):
    """Piecewise quadratic (PWQ) neural network."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act: Callable[[SymType], SymType] = csnn.ReLU(),
    ) -> None:
        """Creates a PWQ model.

        Parameters
        ----------
        in_features : int
            Number of input features.
        hidden_features : int
            Number of features in the single hidden layer.
        act : Callable[[SymType], SymType], optional
            Instance of an activation function class, or a callable that computes an
            activation function. By default, `ReLU` is used.
        """
        super().__init__()
        self.layers = csnn.Sequential(
            (
                csnn.Linear(in_features, hidden_features),
                act,
                ElementWiseSquare(),
                DotProduct(hidden_features),
            )
        )

    def forward(self, x: SymType) -> SymType:
        return self.layers(x)


# create the model
n_in = 2
n_hidden = 32
mdl = Pwq(n_in, n_hidden)

# turn it into a function
x = cs.MX.sym("x", n_in, 1)
y = mdl(x.T)
p = dict(mdl.parameters(skip_none=True))
F = cs.Function("F", [x, cs.vvcat(p.values())], [y], ["x", "p"], ["y"])

# simulate some weights
np_random = np.random.default_rng(69)
p_num = {k: np_random.normal(size=v.shape) for k, v in p.items()}

# force convexity of nn function
p_num["layers.3.weight"] = np.abs(p_num["layers.3.weight"])

# force value at origin to be close to zero
p_num["layers.0.bias"] = -np.abs(p_num["layers.0.bias"])

# plot function
p_num = cs.vvcat(p_num.values())
o = np.linspace(-3, 3, 100)
X1, X2 = np.meshgrid(o, o)
X = np.stack((X1, X2)).reshape(n_in, -1)
Y = F(X, p_num).full().reshape(o.size, o.size)

print("Value in origin:", F(0, p_num).full().item(), "(should be zero)")
min_idx = np.argmin(Y)
r = np.unravel_index(min_idx, Y.shape)
print(
    "Minimum value:", Y[r], "at", X[:, min_idx], "(should be zero, and close to origin)"
)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.plot_surface(X1, X2, Y, cmap="RdBu_r")
ax.set_xlabel(r"$x_1$")
ax.set_ylabel(r"$x_2$")
ax.set_zlabel(r"$y$")
plt.show()