FilippoAiraldi/casadi-neural-nets

View on GitHub
src/csnn/containers.py

Summary

Maintainability
A
25 mins
Test Coverage
A
96%
from collections import OrderedDict
from collections.abc import Iterable, Iterator
from itertools import islice
from operator import index
from typing import TypeVar, Union

import casadi as cs

from csnn.module import Module

# Type variable constrained to CasADi's two symbolic types (`cs.SX` and `cs.MX`);
# it parametrizes `Module` so that a container and the modules it holds are
# annotated with the same symbolic type.
SymType = TypeVar("SymType", cs.SX, cs.MX)


class Sequential(Module[SymType]):
    """Container that applies its modules one after the other.

    Modules are registered in the order in which they are given to the constructor,
    either as a plain iterable (in which case they are named after their position)
    or as a dict mapping names to modules. Calling `forward` feeds the input to the
    first module, passes each output on as the input of the next module, and returns
    the output of the last one."""

    def __init__(
        self, modules: Union[dict[str, Module[SymType]], Iterable[Module[SymType]]]
    ) -> None:
        """Instantiates the sequential module.

        Parameters
        ----------
        modules : dict[str, Module] or iterable of Module
            A dict mapping names to modules, or an iterable of modules. Modules
            from an iterable are registered under the names "0", "1", ...
        """
        super().__init__()
        # Bring both accepted input forms to one stream of (name, module) pairs so
        # that registration happens in a single place.
        if isinstance(modules, dict):
            named_modules = modules.items()
        else:
            named_modules = (
                (str(position), layer) for position, layer in enumerate(modules)
            )
        for name, layer in named_modules:
            self.add_module(name, layer)

    def forward(self, input: SymType) -> SymType:
        """Passes `input` through every contained module, in order.

        Parameters
        ----------
        input : SymType
            The input handed to the first module.

        Returns
        -------
        SymType
            The output of the last module, or `input` itself if the container is
            empty.
        """
        output = input
        for _name, layer in self:
            output = layer.forward(output)
        return output

    def __iter__(self) -> Iterator[tuple[str, Module[SymType]]]:
        """Returns an iterator over the (name, module) pairs, in registration
        order."""
        named_modules = self._modules.items()
        return iter(named_modules)

    def __getitem__(
        self, idx: Union[slice, int]
    ) -> Union["Sequential[SymType]", Module[SymType]]:
        """Retrieves one module by position, or a sub-container by slice.

        Parameters
        ----------
        idx : slice or int
            A slice selecting a range of modules, or an integer position (negative
            values count from the end).

        Returns
        -------
        Sequential or Module
            For a slice, a new container of the same class holding the selected
            modules under their original names; for an integer, the module at that
            position.

        Raises
        ------
        IndexError
            If an integer `idx` lies outside `[-len, len)`.
        """
        if isinstance(idx, slice):
            selected = list(self._modules.items())[idx]
            return self.__class__(OrderedDict(selected))

        count = len(self._modules)
        position = index(idx)  # accepts any object implementing `__index__`
        if position < -count or position >= count:
            raise IndexError(f"index {position} is out of range")
        if position < 0:
            # translate a from-the-end position into a from-the-start one
            position += count
        return next(islice(self._modules.values(), position, None))