FilippoAiraldi/casadi-neural-nets

View on GitHub
src/csnn/__init__.py

Summary

Maintainability: A (0 mins)
Test Coverage: A (91%)
__version__ = "1.0.3rc1"

# Public API of the package, assembled group by group. The `+=` form is still
# understood by static analyzers as a declaration of re-exported names.
# Base class, container and layer building blocks.
__all__ = ["Module", "Sequential", "Linear"]
# Activation functions.
__all__ += ["ReLU", "Sigmoid", "SoftPlus"]
# Helpers to query/configure the casadi symbolic type used by the networks.
__all__ += ["get_sym_type", "set_sym_type"]


from typing import Literal, Union

import casadi as cs

from csnn.activation import ReLU, Sigmoid, SoftPlus
from csnn.containers import Sequential
from csnn.linear import Linear
from csnn.module import Module


def get_sym_type() -> Union[type[cs.SX], type[cs.MX]]:
    """Gets the casadi's symbolic type used to build the networks.

    Returns
    -------
    type[cs.SX] or type[cs.MX]
        The current symbolic type, either `casadi.SX` or `casadi.MX`, as stored in
        the class attribute `Module.sym_type`.
    """
    return Module.sym_type


def set_sym_type(type: Literal["SX", "MX"]) -> None:
    """Sets the casadi's symbolic type to be used in building the networks.

    Parameters
    ----------
    type : "SX" or "MX"
        The name of the symbolic type to set.

    Raises
    ------
    ValueError
        If `type` is neither `"SX"` nor `"MX"`.
    """
    # NOTE: the parameter shadows the builtin `type`; the name is kept unchanged
    # because callers may pass it by keyword.
    # Validate before the lookup: `getattr(cs, ...)` would otherwise accept any
    # other casadi attribute (e.g. `cs.DM`) and silently store something that is
    # not a symbolic type in `Module.sym_type`, affecting every network built
    # afterwards.
    if type not in ("SX", "MX"):
        raise ValueError(f"Expected symbolic type 'SX' or 'MX'; got {type!r}.")
    Module.sym_type = getattr(cs, type)