Kev1CO/cocofest

View on GitHub
cocofest/optimization/fes_ocp.py

Summary

Maintainability
D
2 days
Test Coverage
import numpy as np

from bioptim import (
    BoundsList,
    ConstraintList,
    ControlType,
    DynamicsList,
    InitialGuessList,
    InterpolationType,
    Node,
    Objective,
    ObjectiveFcn,
    ObjectiveList,
    OdeSolver,
    OptimalControlProgram,
    ParameterList,
    ParameterObjectiveList,
    PhaseDynamics,
    VariableScaling,
)

from ..custom_objectives import CustomObjective
from ..custom_constraints import CustomConstraint
from ..fourier_approx import FourierSeries

from ..models.fes_model import FesModel
from ..models.ding2007 import DingModelPulseDurationFrequency
from ..models.ding2007_with_fatigue import DingModelPulseDurationFrequencyWithFatigue
from ..models.ding2003 import DingModelFrequency
from ..models.hmed2018 import DingModelIntensityFrequency
from ..models.hmed2018_with_fatigue import DingModelIntensityFrequencyWithFatigue


class OcpFes:
    """
    The main class to define an ocp. This class prepares the full program and gives all
    the needed parameters to solve a functional electrical stimulation ocp

    Methods
    -------
    prepare_ocp(model, n_stim, n_shooting, final_time, pulse_event, pulse_duration, pulse_intensity, objective,
                use_sx, ode_solver, n_threads)
        Validates the inputs and builds a multiphase OptimalControlProgram with one phase per stimulation
    """

    @staticmethod
    def prepare_ocp(
        model: FesModel = None,
        n_stim: int = None,
        n_shooting: int = None,
        final_time: int | float = None,
        pulse_event: dict = None,
        pulse_duration: dict = None,
        pulse_intensity: dict = None,
        objective: dict = None,
        use_sx: bool = True,
        ode_solver: OdeSolver = OdeSolver.RK4(n_integration_steps=1),
        n_threads: int = 1,
    ):
        """
        Prepares the Optimal Control Program (OCP) to be solved.

        Parameters
        ----------
        model : FesModel
            The model type used for the OCP.
        n_stim : int
            Number of stimulations that will occur during the OCP, also referred to as phases.
        n_shooting : int
            Number of shooting points for each individual phase.
        final_time : int | float
            The final time of the OCP.
        pulse_event : dict
            Dictionary containing parameters related to the appearance of the pulse.
        pulse_duration : dict
            Dictionary containing parameters related to the duration of the pulse.
            Optional if not using DingModelPulseDurationFrequency or DingModelPulseDurationFrequencyWithFatigue.
        pulse_intensity : dict
            Dictionary containing parameters related to the intensity of the pulse.
            Optional if not using DingModelIntensityFrequency or DingModelIntensityFrequencyWithFatigue.
        objective : dict
            Dictionary containing parameters related to the optimization objective.
        use_sx : bool
            The nature of the CasADi variables. MX are used if False.
        ode_solver : OdeSolver
            The ODE solver to use. The default instance is created once, when the class is defined.
        n_threads : int
            The number of threads to use while solving (multi-threading if > 1).

        Returns
        -------
        OptimalControlProgram
            The prepared Optimal Control Program.

        Raises
        ------
        TypeError, ValueError, NotImplementedError
            When the inputs fail the checks of _sanity_check or _sanity_check_frequency.
        ValueError
            When neither a parameter nor a constraint was built (nothing to optimize).
        """

        (pulse_event, pulse_duration, pulse_intensity, objective) = OcpFes._fill_dict(
            pulse_event, pulse_duration, pulse_intensity, objective
        )

        time_min = pulse_event["min"]
        time_max = pulse_event["max"]
        time_bimapping = pulse_event["bimapping"]
        frequency = pulse_event["frequency"]
        round_down = pulse_event["round_down"]
        pulse_mode = pulse_event["pulse_mode"]

        fixed_pulse_duration = pulse_duration["fixed"]
        pulse_duration_min = pulse_duration["min"]
        pulse_duration_max = pulse_duration["max"]
        pulse_duration_bimapping = pulse_duration["bimapping"]

        fixed_pulse_intensity = pulse_intensity["fixed"]
        pulse_intensity_min = pulse_intensity["min"]
        pulse_intensity_max = pulse_intensity["max"]
        pulse_intensity_bimapping = pulse_intensity["bimapping"]

        force_tracking = objective["force_tracking"]
        end_node_tracking = objective["end_node_tracking"]
        custom_objective = objective["custom"]

        OcpFes._sanity_check(
            model=model,
            n_stim=n_stim,
            n_shooting=n_shooting,
            final_time=final_time,
            pulse_mode=pulse_mode,
            frequency=frequency,
            time_min=time_min,
            time_max=time_max,
            time_bimapping=time_bimapping,
            fixed_pulse_duration=fixed_pulse_duration,
            pulse_duration_min=pulse_duration_min,
            pulse_duration_max=pulse_duration_max,
            pulse_duration_bimapping=pulse_duration_bimapping,
            fixed_pulse_intensity=fixed_pulse_intensity,
            pulse_intensity_min=pulse_intensity_min,
            pulse_intensity_max=pulse_intensity_max,
            pulse_intensity_bimapping=pulse_intensity_bimapping,
            force_tracking=force_tracking,
            end_node_tracking=end_node_tracking,
            custom_objective=custom_objective,
            use_sx=use_sx,
            ode_solver=ode_solver,
            n_threads=n_threads,
        )

        OcpFes._sanity_check_frequency(n_stim=n_stim, final_time=final_time, frequency=frequency, round_down=round_down)

        # Derive whichever of n_stim / final_time was not given from the frequency
        n_stim, final_time = OcpFes._build_phase_parameter(
            n_stim=n_stim, final_time=final_time, frequency=frequency, pulse_mode=pulse_mode, round_down=round_down
        )

        force_fourier_coefficient = (
            None if force_tracking is None else OcpFes._build_fourier_coefficient(force_tracking)
        )

        # One phase per stimulation, all sharing the same model and shooting count
        models = [model] * n_stim
        n_shooting = [n_shooting] * n_stim

        final_time_phase = OcpFes._build_phase_time(
            final_time=final_time,
            n_stim=n_stim,
            pulse_mode=pulse_mode,
            time_min=time_min,
            time_max=time_max,
        )
        parameters, parameters_bounds, parameters_init, parameter_objectives, constraints = OcpFes._build_parameters(
            model=model,
            n_stim=n_stim,
            time_min=time_min,
            time_max=time_max,
            time_bimapping=time_bimapping,
            fixed_pulse_duration=fixed_pulse_duration,
            pulse_duration_min=pulse_duration_min,
            pulse_duration_max=pulse_duration_max,
            pulse_duration_bimapping=pulse_duration_bimapping,
            fixed_pulse_intensity=fixed_pulse_intensity,
            pulse_intensity_min=pulse_intensity_min,
            pulse_intensity_max=pulse_intensity_max,
            pulse_intensity_bimapping=pulse_intensity_bimapping,
            use_sx=use_sx,
        )

        # NOTE(review): if bioptim's ConstraintList.__len__ counts phases rather than entries,
        # len(constraints) is never 0 and this guard cannot trigger — confirm against bioptim.
        if len(constraints) == 0 and len(parameters) == 0:
            raise ValueError(
                "This is not an optimal control problem,"
                " add parameter to optimize or use the IvpFes method to build your problem"
            )

        dynamics = OcpFes._declare_dynamics(models, n_stim)
        x_bounds, x_init = OcpFes._set_bounds(model, n_stim)
        objective_functions = OcpFes._set_objective(
            n_stim, n_shooting, force_fourier_coefficient, end_node_tracking, custom_objective, time_min, time_max
        )

        return OptimalControlProgram(
            bio_model=models,
            dynamics=dynamics,
            n_shooting=n_shooting,
            phase_time=final_time_phase,
            objective_functions=objective_functions,
            x_init=x_init,
            x_bounds=x_bounds,
            constraints=constraints,
            parameters=parameters,
            parameter_bounds=parameters_bounds,
            parameter_init=parameters_init,
            parameter_objectives=parameter_objectives,
            control_type=ControlType.CONSTANT,
            use_sx=use_sx,
            ode_solver=ode_solver,
            n_threads=n_threads,
        )

    @staticmethod
    def _fill_dict(pulse_event, pulse_duration, pulse_intensity, objective):
        """
        Complete the provided dictionaries with default values for the keys that are not set.

        The caller's dictionaries are never modified: new dictionaries are returned, in which
        user-provided values take precedence over the defaults and extra keys are preserved.

        Parameters
        ----------
        pulse_event : dict | None
            Dictionary containing parameters related to the appearance of the pulse.
            Expected keys are 'min', 'max', 'bimapping', 'frequency', 'round_down', and 'pulse_mode'.

        pulse_duration : dict | None
            Dictionary containing parameters related to the duration of the pulse.
            Expected keys are 'fixed', 'min', 'max', and 'bimapping'.

        pulse_intensity : dict | None
            Dictionary containing parameters related to the intensity of the pulse.
            Expected keys are 'fixed', 'min', 'max', and 'bimapping'.

        objective : dict | None
            Dictionary containing parameters related to the objective of the optimization.
            Expected keys are 'force_tracking', 'end_node_tracking', and 'custom'.

        Returns
        -------
        Four new dictionaries: pulse_event, pulse_duration, pulse_intensity, and objective,
        each containing at least every default key.
        """

        default_pulse_event = {
            "min": None,
            "max": None,
            "bimapping": False,
            "frequency": None,
            "round_down": False,
            "pulse_mode": "single",
        }

        default_pulse_duration = {
            "fixed": None,
            "min": None,
            "max": None,
            "bimapping": False,
        }

        default_pulse_intensity = {
            "fixed": None,
            "min": None,
            "max": None,
            "bimapping": False,
        }

        default_objective = {
            "force_tracking": None,
            "end_node_tracking": None,
            "cycling": None,
            "custom": None,
        }

        # None means "use all defaults"; user values override the defaults key by key
        return (
            {**default_pulse_event, **(pulse_event or {})},
            {**default_pulse_duration, **(pulse_duration or {})},
            {**default_pulse_intensity, **(pulse_intensity or {})},
            {**default_objective, **(objective or {})},
        )

    @staticmethod
    def _sanity_check(
        model=None,
        n_stim=None,
        n_shooting=None,
        final_time=None,
        pulse_mode=None,
        frequency=None,
        time_min=None,
        time_max=None,
        time_bimapping=None,
        fixed_pulse_duration=None,
        pulse_duration_min=None,
        pulse_duration_max=None,
        pulse_duration_bimapping=None,
        fixed_pulse_intensity=None,
        pulse_intensity_min=None,
        pulse_intensity_max=None,
        pulse_intensity_bimapping=None,
        force_tracking=None,
        end_node_tracking=None,
        custom_objective=None,
        use_sx=None,
        ode_solver=None,
        n_threads=None,
    ):
        """
        Validate the types and values of the prepare_ocp inputs.

        Raises
        ------
        TypeError
            When an argument has the wrong type.
        ValueError
            When an argument has an invalid value or incompatible arguments are combined.
        NotImplementedError
            For unsupported pulse modes or non-bool bimapping flags.
        """
        if not isinstance(model, FesModel):
            raise TypeError(
                f"The current model type used is {type(model)}, it must be a FesModel type."
                f"Current available models are: DingModelFrequency, DingModelFrequencyWithFatigue,"
                f"DingModelPulseDurationFrequency, DingModelPulseDurationFrequencyWithFatigue,"
                f"DingModelIntensityFrequency, DingModelIntensityFrequencyWithFatigue"
            )

        # "is not None" (not truthiness) so that an explicit 0 is rejected here instead of being
        # treated as "not provided" and crashing later on.
        if n_stim is not None:
            if isinstance(n_stim, int):
                if n_stim <= 0:
                    raise ValueError("n_stim must be positive")
            else:
                raise TypeError("n_stim must be int type")

        if n_shooting is not None:
            if isinstance(n_shooting, int):
                if n_shooting <= 0:
                    raise ValueError("n_shooting must be positive")
            else:
                raise TypeError("n_shooting must be int type")

        if final_time is not None:
            if isinstance(final_time, int | float):
                if final_time <= 0:
                    raise ValueError("final_time must be positive")
            else:
                raise TypeError("final_time must be int or float type")

        if pulse_mode:
            if pulse_mode != "single":
                raise NotImplementedError(f"Pulse mode '{pulse_mode}' is not yet implemented")

        if frequency is not None:
            if isinstance(frequency, int | float):
                if frequency <= 0:
                    raise ValueError("frequency must be positive")
            else:
                raise TypeError("frequency must be int or float type")

        if [time_min, time_max].count(None) == 1:
            raise ValueError("time_min and time_max must be both entered or none of them in order to work")

        if time_bimapping:
            if not isinstance(time_bimapping, bool):
                raise TypeError("time_bimapping must be bool type")

        if isinstance(model, DingModelPulseDurationFrequency | DingModelPulseDurationFrequencyWithFatigue):
            if fixed_pulse_duration is None and [pulse_duration_min, pulse_duration_max].count(None) != 0:
                raise ValueError("pulse duration or pulse duration min max bounds need to be set for this model")
            if all([fixed_pulse_duration, pulse_duration_min, pulse_duration_max]):
                raise ValueError("Either pulse duration or pulse duration min max bounds need to be set for this model")

            minimum_pulse_duration = (
                0 if model.pd0 is None else model.pd0
            )  # Set it to 0 if used for the identification process

            if fixed_pulse_duration is not None:
                if isinstance(fixed_pulse_duration, int | float):
                    if fixed_pulse_duration < minimum_pulse_duration:
                        raise ValueError(
                            f"The pulse duration set ({fixed_pulse_duration})"
                            f" is lower than minimum duration required."
                            f" Set a value above {minimum_pulse_duration} seconds "
                        )
                elif isinstance(fixed_pulse_duration, list):
                    if not all(isinstance(x, int | float) for x in fixed_pulse_duration):
                        raise TypeError("pulse_duration must be int or float type")
                    if not all(x >= minimum_pulse_duration for x in fixed_pulse_duration):
                        raise ValueError(
                            f"The pulse duration set ({fixed_pulse_duration})"
                            f" is lower than minimum duration required."
                            f" Set a value above {minimum_pulse_duration} seconds "
                        )
                else:
                    raise TypeError("Wrong pulse_duration type, only int or float accepted")

            elif pulse_duration_min is not None and pulse_duration_max is not None:
                if not isinstance(pulse_duration_min, int | float) or not isinstance(pulse_duration_max, int | float):
                    raise TypeError("pulse_duration_min and pulse_duration_max must be int or float type")
                if pulse_duration_max < pulse_duration_min:
                    raise ValueError("The set minimum pulse duration is higher than maximum pulse duration.")
                if pulse_duration_min < minimum_pulse_duration:
                    raise ValueError(
                        f"The pulse duration set ({pulse_duration_min})"
                        f" is lower than minimum duration required."
                        f" Set a value above {minimum_pulse_duration} seconds "
                    )

            if not isinstance(pulse_duration_bimapping, None | bool):
                raise NotImplementedError("If added, pulse duration parameter mapping must be a bool type")

        if isinstance(model, DingModelIntensityFrequency | DingModelIntensityFrequencyWithFatigue):
            if fixed_pulse_intensity is None and [pulse_intensity_min, pulse_intensity_max].count(None) != 0:
                raise ValueError("Pulse intensity or pulse intensity min max bounds need to be set for this model")
            if all([fixed_pulse_intensity, pulse_intensity_min, pulse_intensity_max]):
                raise ValueError(
                    "Either pulse intensity or pulse intensity min max bounds need to be set for this model"
                )

            check_for_none_type = [model.cr, model.bs, model.Is]
            minimum_pulse_intensity = (
                0 if None in check_for_none_type else model.min_pulse_intensity()
            )  # Set it to 0 if used for the identification process

            if fixed_pulse_intensity is not None:
                if isinstance(fixed_pulse_intensity, int | float):
                    if fixed_pulse_intensity < minimum_pulse_intensity:
                        raise ValueError(
                            f"The pulse intensity set ({fixed_pulse_intensity})"
                            f" is lower than minimum intensity required."
                            f" Set a value above {minimum_pulse_intensity} mA "
                        )
                elif isinstance(fixed_pulse_intensity, list):
                    if not all(isinstance(x, int | float) for x in fixed_pulse_intensity):
                        raise TypeError("pulse_intensity must be int or float type")
                    if not all(x >= minimum_pulse_intensity for x in fixed_pulse_intensity):
                        raise ValueError(
                            f"The pulse intensity set ({fixed_pulse_intensity})"
                            f" is lower than minimum intensity required."
                            f" Set a value above {minimum_pulse_intensity} mA "
                        )
                else:
                    raise TypeError("pulse_intensity must be int or float type")

            elif pulse_intensity_min is not None and pulse_intensity_max is not None:
                if not isinstance(pulse_intensity_min, int | float) or not isinstance(pulse_intensity_max, int | float):
                    raise TypeError("pulse_intensity_min and pulse_intensity_max must be int or float type")
                if pulse_intensity_max < pulse_intensity_min:
                    raise ValueError("The set minimum pulse intensity is higher than maximum pulse intensity.")
                if pulse_intensity_min < minimum_pulse_intensity:
                    raise ValueError(
                        f"The pulse intensity set ({pulse_intensity_min})"
                        f" is lower than minimum intensity required."
                        f" Set a value above {minimum_pulse_intensity} mA "
                    )

            if not isinstance(pulse_intensity_bimapping, None | bool):
                raise NotImplementedError("If added, pulse intensity parameter mapping must be a bool type")

        if force_tracking is not None:
            if not isinstance(force_tracking, list):
                raise TypeError("force_tracking must be list type")
            # Check the list size before indexing so a short list raises ValueError, not IndexError
            if len(force_tracking) != 2:
                raise ValueError(
                    "force_tracking time and force argument must be same length and force_tracking "
                    "list size 2"
                )
            if not (isinstance(force_tracking[0], np.ndarray) and isinstance(force_tracking[1], np.ndarray)):
                raise TypeError("force_tracking argument must be np.ndarray type")
            if len(force_tracking[0]) != len(force_tracking[1]):
                raise ValueError(
                    "force_tracking time and force argument must be same length and force_tracking "
                    "list size 2"
                )

        if end_node_tracking:
            if not isinstance(end_node_tracking, int | float):
                raise TypeError("end_node_tracking must be int or float type")

        if custom_objective:
            if not isinstance(custom_objective, ObjectiveList):
                raise TypeError("custom_objective must be a ObjectiveList type")
            # custom_objective[phase] is the collection of objectives of that phase: check all phases
            if not all(
                isinstance(x, Objective)
                for phase_idx in range(len(custom_objective))
                for x in custom_objective[phase_idx]
            ):
                raise TypeError("All elements in ObjectiveList must be an Objective type")

        if not isinstance(ode_solver, (OdeSolver.RK1, OdeSolver.RK2, OdeSolver.RK4, OdeSolver.COLLOCATION)):
            raise TypeError("ode_solver must be a OdeSolver type")

        if not isinstance(use_sx, bool):
            raise TypeError("use_sx must be a bool type")

        if not isinstance(n_threads, int):
            raise TypeError("n_thread must be a int type")

    @staticmethod
    def _sanity_check_frequency(n_stim, final_time, frequency, round_down):
        """
        Check that n_stim, final_time and frequency are consistent.

        At least two of the three must be set. When all three are set, n_stim must equal
        final_time * frequency (up to floating point noise).

        Raises
        ------
        ValueError
            When fewer than two values are set, or when the three values disagree.
        TypeError
            When round_down is truthy but not a bool.
        """
        # >= 2 also rejects the case where none of the three values is provided
        if [n_stim, final_time, frequency].count(None) >= 2:
            raise ValueError("At least two variable must be set from n_stim, final_time or frequency")

        if n_stim and final_time and frequency:
            # Tolerant comparison: final_time * frequency is a float product (e.g. 1.1 * 10 != 11 exactly)
            if not np.isclose(n_stim, final_time * frequency, rtol=1e-9, atol=0.0):
                raise ValueError(
                    "Can not satisfy n_stim equal to final_time * frequency with the given parameters."
                    "Consider setting only two of the three parameters"
                )

        if round_down:
            if not isinstance(round_down, bool):
                raise TypeError("round_down must be bool type")

    @staticmethod
    def _build_fourier_coefficient(force_tracking):
        """Return the Fourier coefficients (50 terms) of the force signal [time, force] to track."""
        return FourierSeries().compute_real_fourier_coeffs(force_tracking[0], force_tracking[1], 50)

    @staticmethod
    def _build_phase_time(final_time, n_stim, pulse_mode, time_min, time_max):
        """
        Return the duration of each phase: final_time split evenly across the n_stim phases.

        A tuple is returned when no time bounds are given and the pulse mode is "single".
        A list is returned when time bounds are given. None is returned for any other
        pulse mode without time bounds.
        """
        final_time_phase = None
        if time_min is None and time_max is None:
            if pulse_mode == "single":
                final_time_phase = (final_time / n_stim,) * n_stim
        else:
            final_time_phase = [final_time / n_stim] * n_stim

        return final_time_phase

    @staticmethod
    def _build_parameters(
        model,
        n_stim,
        time_min,
        time_max,
        time_bimapping,
        fixed_pulse_duration,
        pulse_duration_min,
        pulse_duration_max,
        pulse_duration_bimapping,
        fixed_pulse_intensity,
        pulse_intensity_min,
        pulse_intensity_max,
        pulse_intensity_bimapping,
        use_sx,
    ):
        """
        Build the optimized parameters of the problem, together with their bounds, initial guesses and constraints.

        - pulse_apparition_time: added when time_min is set.
        - pulse_duration: added for DingModelPulseDurationFrequency models, either fixed (min bound == max bound)
          or free within [pulse_duration_min, pulse_duration_max].
        - pulse_intensity: added for DingModelIntensityFrequency models, following the same logic.
        The bimapping flags add constraints that tie every phase to the first one.

        Returns
        -------
        parameters, parameters_bounds, parameters_init, parameter_objectives, constraints
        """
        parameters = ParameterList(use_sx=use_sx)
        parameters_bounds = BoundsList()
        parameters_init = InitialGuessList()
        parameter_objectives = ParameterObjectiveList()
        constraints = ConstraintList()

        if time_min:
            parameters.add(
                name="pulse_apparition_time",
                function=DingModelFrequency.set_pulse_apparition_time,
                size=n_stim,
                scaling=VariableScaling("pulse_apparition_time", [1] * n_stim),
            )

            if time_min and time_max:
                # The n-th pulse can appear between n * time_min and n * time_max
                time_min_list = [time_min * n for n in range(n_stim)]
                time_max_list = [time_max * n for n in range(n_stim)]
            else:
                time_min_list = [0] * n_stim
                time_max_list = [100] * n_stim
            parameters_bounds.add(
                "pulse_apparition_time",
                min_bound=np.array(time_min_list),
                max_bound=np.array(time_max_list),
                interpolation=InterpolationType.CONSTANT,
            )

            parameters_init["pulse_apparition_time"] = np.array([0] * n_stim)

            for i in range(n_stim):
                constraints.add(CustomConstraint.pulse_time_apparition_as_phase, node=Node.START, phase=i, target=0)

        if time_bimapping and time_min and time_max:
            for i in range(n_stim):
                constraints.add(CustomConstraint.equal_to_first_pulse_interval_time, node=Node.START, target=0, phase=i)

        if isinstance(model, DingModelPulseDurationFrequency):
            if fixed_pulse_duration:
                parameters.add(
                    name="pulse_duration",
                    function=DingModelPulseDurationFrequency.set_impulse_duration,
                    size=n_stim,
                    scaling=VariableScaling("pulse_duration", [1] * n_stim),
                )
                # A fixed value is enforced by setting min bound == max bound
                if isinstance(fixed_pulse_duration, list):
                    parameters_bounds.add(
                        "pulse_duration",
                        min_bound=np.array(fixed_pulse_duration),
                        max_bound=np.array(fixed_pulse_duration),
                        interpolation=InterpolationType.CONSTANT,
                    )
                    parameters_init.add(key="pulse_duration", initial_guess=np.array(fixed_pulse_duration))
                else:
                    parameters_bounds.add(
                        "pulse_duration",
                        min_bound=np.array([fixed_pulse_duration] * n_stim),
                        max_bound=np.array([fixed_pulse_duration] * n_stim),
                        interpolation=InterpolationType.CONSTANT,
                    )
                    parameters_init["pulse_duration"] = np.array([fixed_pulse_duration] * n_stim)

            elif pulse_duration_min is not None and pulse_duration_max is not None:
                parameters_bounds.add(
                    "pulse_duration",
                    min_bound=[pulse_duration_min],
                    max_bound=[pulse_duration_max],
                    interpolation=InterpolationType.CONSTANT,
                )
                # NOTE(review): this initial guess of 0 lies below pulse_duration_min whenever min > 0
                # (the intensity branch uses the bounds' midpoint instead) — confirm this is intended.
                parameters_init["pulse_duration"] = np.array([0] * n_stim)
                parameters.add(
                    name="pulse_duration",
                    function=DingModelPulseDurationFrequency.set_impulse_duration,
                    size=n_stim,
                    scaling=VariableScaling("pulse_duration", [1] * n_stim),
                )

            if pulse_duration_bimapping is True:
                for i in range(1, n_stim):
                    constraints.add(CustomConstraint.equal_to_first_pulse_duration, node=Node.START, target=0, phase=i)

        if isinstance(model, DingModelIntensityFrequency):
            if fixed_pulse_intensity:
                parameters.add(
                    name="pulse_intensity",
                    function=DingModelIntensityFrequency.set_impulse_intensity,
                    size=n_stim,
                    scaling=VariableScaling("pulse_intensity", [1] * n_stim),
                )
                # A fixed value is enforced by setting min bound == max bound
                if isinstance(fixed_pulse_intensity, list):
                    parameters_bounds.add(
                        "pulse_intensity",
                        min_bound=np.array(fixed_pulse_intensity),
                        max_bound=np.array(fixed_pulse_intensity),
                        interpolation=InterpolationType.CONSTANT,
                    )
                    parameters_init.add(key="pulse_intensity", initial_guess=np.array(fixed_pulse_intensity))
                else:
                    parameters_bounds.add(
                        "pulse_intensity",
                        min_bound=np.array([fixed_pulse_intensity] * n_stim),
                        max_bound=np.array([fixed_pulse_intensity] * n_stim),
                        interpolation=InterpolationType.CONSTANT,
                    )
                    parameters_init["pulse_intensity"] = np.array([fixed_pulse_intensity] * n_stim)

            elif pulse_intensity_min is not None and pulse_intensity_max is not None:
                parameters_bounds.add(
                    "pulse_intensity",
                    min_bound=[pulse_intensity_min],
                    max_bound=[pulse_intensity_max],
                    interpolation=InterpolationType.CONSTANT,
                )
                # Start from the middle of the allowed intensity range
                intensity_avg = (pulse_intensity_min + pulse_intensity_max) / 2
                parameters_init["pulse_intensity"] = np.array([intensity_avg] * n_stim)
                parameters.add(
                    name="pulse_intensity",
                    function=DingModelIntensityFrequency.set_impulse_intensity,
                    size=n_stim,
                    scaling=VariableScaling("pulse_intensity", [1] * n_stim),
                )

            if pulse_intensity_bimapping is True:
                for i in range(1, n_stim):
                    constraints.add(CustomConstraint.equal_to_first_pulse_intensity, node=Node.START, target=0, phase=i)

        return parameters, parameters_bounds, parameters_init, parameter_objectives, constraints

    @staticmethod
    def _declare_dynamics(models, n_stim):
        """Declare one dynamics per phase, using the phase model's variables and dynamics function."""
        dynamics = DynamicsList()
        for i in range(n_stim):
            dynamics.add(
                models[i].declare_ding_variables,
                dynamic_function=models[i].dynamics,
                expand_dynamics=True,
                expand_continuity=False,
                phase=i,
                phase_dynamics=PhaseDynamics.SHARED_DURING_THE_PHASE,
            )

        return dynamics

    @staticmethod
    def _set_bounds(model, n_stim):
        """
        Build the state bounds and initial guesses of every phase.

        The first node of the first phase is pinned to the model's rest values. Everywhere else,
        the states lie between min_bounds and max_bounds, which are the rest values widened as follows:
        Cn/F up to 1000, Tau1/Km up to 1, and A down to 0.

        Returns
        -------
        x_bounds : BoundsList
        x_init : InitialGuessList
            The model's rest values for every state of every phase.
        """
        # ---- STATE BOUNDS REPRESENTATION ---- #
        #
        #                    |‾‾‾‾‾‾‾‾‾‾x_max_middle‾‾‾‾‾‾‾‾‾‾‾‾x_max_end‾
        #                    |          max_bounds              max_bounds
        #    x_max_start     |
        #   _starting_bounds_|
        #   ‾starting_bounds‾|
        #    x_min_start     |
        #                    |          min_bounds              min_bounds
        #                     ‾‾‾‾‾‾‾‾‾‾x_min_middle‾‾‾‾‾‾‾‾‾‾‾‾x_min_end‾

        # Sets the bound for all the phases
        x_bounds = BoundsList()
        variable_bound_list = model.name_dof
        # Three separate calls so that modifying one array does not alter the others
        starting_bounds, min_bounds, max_bounds = (
            model.standard_rest_values(),
            model.standard_rest_values(),
            model.standard_rest_values(),
        )

        for i, variable_name in enumerate(variable_bound_list):
            if variable_name in ("Cn", "F"):
                max_bounds[i] = 1000
            elif variable_name in ("Tau1", "Km"):
                max_bounds[i] = 1
            elif variable_name == "A":
                min_bounds[i] = 0

        # Columns are (first node, intermediate nodes, last node) for CONSTANT_WITH_FIRST_AND_LAST_DIFFERENT
        starting_bounds_min = np.concatenate((starting_bounds, min_bounds, min_bounds), axis=1)
        starting_bounds_max = np.concatenate((starting_bounds, max_bounds, max_bounds), axis=1)
        middle_bound_min = np.concatenate((min_bounds, min_bounds, min_bounds), axis=1)
        middle_bound_max = np.concatenate((max_bounds, max_bounds, max_bounds), axis=1)

        for i in range(n_stim):
            # Only the very first phase starts from the rest values
            phase_min, phase_max = (
                (starting_bounds_min, starting_bounds_max) if i == 0 else (middle_bound_min, middle_bound_max)
            )
            for j, variable_name in enumerate(variable_bound_list):
                x_bounds.add(
                    variable_name,
                    min_bound=np.array([phase_min[j]]),
                    max_bound=np.array([phase_max[j]]),
                    phase=i,
                    interpolation=InterpolationType.CONSTANT_WITH_FIRST_AND_LAST_DIFFERENT,
                )

        x_init = InitialGuessList()
        for i in range(n_stim):
            for j in range(len(variable_bound_list)):
                x_init.add(variable_bound_list[j], model.standard_rest_values()[j], phase=i)

        return x_bounds, x_init

    @staticmethod
    def _set_objective(
        n_stim, n_shooting, force_fourier_coefficient, end_node_tracking, custom_objective, time_min, time_max
    ):
        """
        Build the ObjectiveList of the problem.

        The list contains:
        - every objective of custom_objective, whatever its phase;
        - the force "F" tracked at each shooting node of each phase against the Fourier
          approximation, if force_fourier_coefficient is given;
        - the final force "F" tracked to end_node_tracking at the end of the last phase, if set;
        - a lightly weighted phase-time minimization for every phase, if time_min and time_max are set.
        """
        objective_functions = ObjectiveList()
        if custom_objective:
            # custom_objective[phase] holds the objectives of that phase: copy every objective of every phase
            for phase_idx in range(len(custom_objective)):
                for custom in custom_objective[phase_idx]:
                    objective_functions.add(custom)

        if force_fourier_coefficient is not None:
            for phase in range(n_stim):
                for i in range(n_shooting[phase]):
                    objective_functions.add(
                        CustomObjective.track_state_from_time,
                        custom_type=ObjectiveFcn.Mayer,
                        node=i,
                        fourier_coeff=force_fourier_coefficient,
                        key="F",
                        quadratic=True,
                        weight=1,
                        phase=phase,
                    )

        if end_node_tracking:
            if isinstance(end_node_tracking, int | float):
                objective_functions.add(
                    ObjectiveFcn.Mayer.MINIMIZE_STATE,
                    node=Node.END,
                    key="F",
                    quadratic=True,
                    weight=1,
                    target=end_node_tracking,
                    phase=n_stim - 1,
                )

        if time_min and time_max:
            for i in range(n_stim):
                objective_functions.add(
                    ObjectiveFcn.Mayer.MINIMIZE_TIME,
                    weight=0.001 / n_shooting[i],
                    min_bound=time_min,
                    max_bound=time_max,
                    quadratic=True,
                    phase=i,
                )

        return objective_functions

    @staticmethod
    def _build_phase_parameter(n_stim, final_time, frequency=None, pulse_mode="single", round_down=False):
        """
        Complete n_stim / final_time from the stimulation frequency.

        If n_stim and frequency are known, final_time is derived from them and n_stim is kept as given.
        Otherwise, if final_time and frequency are known, n_stim is derived from them. Floating-point noise
        (e.g. 1.1 * 10 == 11.000000000000002) is absorbed. A genuinely fractional count is truncated when
        round_down is True.

        Raises
        ------
        ValueError
            When the derived number of stimulations is not an integer and round_down is False.
        """
        pulse_mode_multiplier = 1 if pulse_mode == "single" else 2 if pulse_mode == "doublet" else 3
        if n_stim and frequency:
            # Keep the user's n_stim: recomputing it from the float final_time could lose
            # integrality (e.g. 1 / 49 * 49 == 0.9999999999999999).
            final_time = n_stim / frequency / pulse_mode_multiplier

        elif final_time and frequency:
            n_stim = final_time * frequency * pulse_mode_multiplier
            nearest_integer = round(n_stim)
            # Works for int and float products alike (int.is_integer only exists from Python 3.12)
            if np.isclose(n_stim, nearest_integer, rtol=1e-9, atol=1e-9):
                n_stim = int(nearest_integer)
            elif round_down:
                n_stim = int(n_stim)
            else:
                raise ValueError(
                    "The number of stimulation needs to be integer within the final time t, set round down"
                    "to True or set final_time * frequency to make the result a integer."
                )

        return n_stim, final_time