Kev1CO/cocofest

View on GitHub
cocofest/identification/ding2007_force_parameter_identification.py

Summary

Maintainability
B
5 hrs
Test Coverage
import time as time_package
import numpy as np

from bioptim import Solver, Objective, OdeSolver
from ..models.ding2007 import DingModelPulseDurationFrequency
from ..identification.ding2003_force_parameter_identification import DingModelFrequencyForceParameterIdentification
from ..optimization.fes_identification_ocp import OcpFesId
from .identification_method import (
    full_data_extraction,
    average_data_extraction,
    sparse_data_extraction,
    node_shooting_list_creation,
    force_at_node_in_ocp,
)


class DingModelPulseDurationFrequencyForceParameterIdentification(DingModelFrequencyForceParameterIdentification):
    """
    This class extends the DingModelFrequencyForceParameterIdentification class and is used to define an optimal control problem (OCP).
    It prepares the full program and provides all the necessary parameters to solve a functional electrical stimulation OCP.

    The identifiable parameters are "tau1_rest", "tau2", "km_rest", "a_scale", "pd0" and "pdt". The pulse
    durations stored in the data files are forwarded to the OCP.
    """

    def __init__(
        self,
        model: DingModelPulseDurationFrequency,
        data_path: str | list[str] = None,
        identification_method: str = "full",
        double_step_identification: bool = False,
        key_parameter_to_identify: list = None,
        additional_key_settings: dict = None,
        n_shooting: int = 5,
        custom_objective: list[Objective] = None,
        use_sx: bool = True,
        ode_solver: OdeSolver = OdeSolver.RK4(n_integration_steps=1),
        n_threads: int = 1,
        **kwargs,
    ):
        """
        Initializes the DingModelPulseDurationFrequencyForceParameterIdentification class.

        Parameters
        ----------
        model: DingModelPulseDurationFrequency
            The model to use for the OCP.
        data_path: str | list[str]
            The path to the force model data.
        identification_method: str
            The method to use for the force model identification. Options are "full" for objective function on all data,
            "average" for objective function on average data, and "sparse" for objective function at the beginning and end of the data.
        double_step_identification: bool
            If True, the identification process will be performed in two steps.
        key_parameter_to_identify: list
            List of keys of the parameters to identify.
        additional_key_settings: dict
            Additional settings for the keys.
        n_shooting: int
            The number of shooting points for the OCP.
        custom_objective: list[Objective]
            List of custom objectives.
        use_sx: bool
            The nature of the CasADi variables. MX are used if False.
        ode_solver: OdeSolver
            The ODE solver to use.
        n_threads: int
            The number of threads to use while solving (multi-threading if > 1).
        **kwargs: dict
            Additional keyword arguments (e.g. "force_curve_number" for the "sparse" method).
        """
        super().__init__(
            model=model,
            data_path=data_path,
            identification_method=identification_method,
            double_step_identification=double_step_identification,
            key_parameter_to_identify=key_parameter_to_identify,
            additional_key_settings=additional_key_settings,
            n_shooting=n_shooting,
            custom_objective=custom_objective,
            use_sx=use_sx,
            ode_solver=ode_solver,
            n_threads=n_threads,
            **kwargs,
        )

    def _set_default_values(self, model):
        """
        Sets the default values for the identified parameters (initial guesses, bounds, scaling and function).
        If the user does not provide additional_key_settings for a specific parameter, the default value will be used.

        Parameters
        ----------
        model: FesModel
            The model to use for the OCP.

        Returns
        -------
        dict
            A dictionary keyed by parameter name. Each value holds "initial_guess", "min_bound",
            "max_bound", "function" (the model setter) and "scaling".
        """
        return {
            "tau1_rest": {
                "initial_guess": 0.5,
                "min_bound": 0.0001,
                "max_bound": 1,
                "function": model.set_tau1_rest,
                "scaling": 1,  # 10000
            },
            "tau2": {
                "initial_guess": 0.5,
                "min_bound": 0.0001,
                "max_bound": 1,
                "function": model.set_tau2,
                "scaling": 1,  # 10000
            },
            "km_rest": {
                "initial_guess": 0.5,
                "min_bound": 0.001,
                "max_bound": 1,
                "function": model.set_km_rest,
                "scaling": 1,  # 10000
            },
            "a_scale": {
                "initial_guess": 5000,
                "min_bound": 1,
                "max_bound": 10000,
                "function": model.set_a_scale,
                "scaling": 1,
            },
            "pd0": {
                "initial_guess": 1e-4,
                "min_bound": 1e-4,
                "max_bound": 6e-4,
                "function": model.set_pd0,
                "scaling": 1,  # 1000
            },
            "pdt": {
                "initial_guess": 1e-4,
                "min_bound": 1e-4,
                "max_bound": 6e-4,
                "function": model.set_pdt,
                "scaling": 1,  # 1000
            },
        }

    def _set_default_parameters_list(self):
        """
        Sets the default parameters list for the model.

        numeric_parameters[i] holds the current model value of the parameter named key_parameters[i].
        """
        self.numeric_parameters = [
            self.model.tau1_rest,
            self.model.tau2,
            self.model.km_rest,
            self.model.a_scale,
            self.model.pd0,
            self.model.pdt,
        ]
        self.key_parameters = ["tau1_rest", "tau2", "km_rest", "a_scale", "pd0", "pdt"]

    @staticmethod
    def _set_model_parameters(model, model_parameters_value):
        """
        Sets the model parameters.

        Parameters
        ----------
        model: FesModel
            The model to use for the OCP.
        model_parameters_value: list
            Values in the order [a_scale, km_rest, tau1_rest, tau2, pd0, pdt].

        Returns
        -------
        FesModel
            The same model instance, updated in place.
        """
        # NOTE(review): this order differs from key_parameters ("tau1_rest", "tau2", "km_rest", ...);
        # presumably it mirrors the parent class convention - confirm against the callers.
        model.a_scale = model_parameters_value[0]
        model.km_rest = model_parameters_value[1]
        model.tau1_rest = model_parameters_value[2]
        model.tau2 = model_parameters_value[3]
        model.pd0 = model_parameters_value[4]
        model.pdt = model_parameters_value[5]
        return model

    @staticmethod
    def pulse_duration_extraction(data_path: str | list[str]) -> list[float]:
        """
        Extracts the pulse durations from one or several pickled data files.

        Parameters
        ----------
        data_path: str | list[str]
            The path, or list of paths, to the pickled data files. Each file must contain a
            "pulse_duration" entry holding an iterable of pulse durations.

        Returns
        -------
        list[float]
            The pulse durations of every file, concatenated in file order.
        """
        import pickle

        # A single path string used to be iterated character by character; wrap it so both forms work.
        paths = [data_path] if isinstance(data_path, str) else data_path

        pulse_duration = []
        for path in paths:
            # NOTE: pickle.load can execute arbitrary code - only load trusted data files.
            with open(path, "rb") as f:
                data = pickle.load(f)
            pulse_duration.extend(data["pulse_duration"])
        return pulse_duration

    def _force_model_identification_for_initial_guess(self):
        """
        Performs the force model identification for the initial guess.

        This step always uses the averaged data (average_data_extraction), whatever the identification method.

        Returns
        -------
        dict
            A dictionary of initial guesses for the parameters, keyed by parameter name.
        """
        self.input_sanity(
            self.model,
            self.data_path,
            self.force_model_identification_method,
            self.double_step_identification,
            self.key_parameter_to_identify,
            self.additional_key_settings,
            self.n_shooting,
        )
        self.check_experiment_force_format(self.data_path)

        # --- Data extraction --- #
        stimulated_n_shooting = self.n_shooting
        force_curve_number = None

        time, stim, force, discontinuity = average_data_extraction(self.data_path)
        # NOTE(review): the "average" branch of force_model_identification passes the mean pulse duration
        # with averaged data, whereas here the full flattened list is passed - confirm this is intended.
        pulse_duration = self.pulse_duration_extraction(self.data_path)
        n_shooting, final_time_phase = node_shooting_list_creation(stim, stimulated_n_shooting)
        force_at_node = force_at_node_in_ocp(time, force, n_shooting, final_time_phase, force_curve_number)

        # --- Building force ocp --- #
        # NOTE(review): key_parameter_to_identify / additional_key_settings are not forwarded here, yet the
        # result is read by those keys below; self.km_rest / self.tau1_rest / self.tau2 are presumably set by
        # the parent class - verify against OcpFesId.prepare_ocp.
        self.force_ocp = OcpFesId.prepare_ocp(
            model=self.model,
            n_shooting=n_shooting,
            final_time_phase=final_time_phase,
            force_tracking=force_at_node,
            custom_objective=self.custom_objective,
            discontinuity_in_ocp=discontinuity,
            pulse_duration=pulse_duration,
            km_rest=self.km_rest,
            tau1_rest=self.tau1_rest,
            tau2=self.tau2,
            use_sx=self.use_sx,
            ode_solver=self.ode_solver,
            n_threads=self.n_threads,
        )

        self.force_identification_result = self.force_ocp.solve(Solver.IPOPT())

        return {
            key: self.force_identification_result.parameters[key][0][0] for key in self.key_parameter_to_identify
        }

    def force_model_identification(self) -> dict[str, np.ndarray]:
        """
        Performs the force model identification.

        Returns
        -------
        dict[str, np.ndarray]
            A dictionary of identified parameters.

        Raises
        ------
        ValueError
            If the identification method is not "full", "average" or "sparse".
        """
        if self.double_step_identification:
            # The first step runs the sanity checks itself, so it is done before any data extraction.
            # Its guesses then seed the initial values of the second (final) identification.
            initial_guess = self._force_model_identification_for_initial_guess()
            for key in self.key_parameter_to_identify:
                self.additional_key_settings[key]["initial_guess"] = initial_guess[key]
        else:
            self.input_sanity(
                self.model,
                self.data_path,
                self.force_model_identification_method,
                self.double_step_identification,
                self.key_parameter_to_identify,
                self.additional_key_settings,
                self.n_shooting,
            )
            self.check_experiment_force_format(self.data_path)

        # --- Data extraction --- #
        stimulated_n_shooting = self.n_shooting
        force_curve_number = None
        method = self.force_model_identification_method

        if method == "full":
            time, stim, force, discontinuity = full_data_extraction(self.data_path)
            pulse_duration = self.pulse_duration_extraction(self.data_path)

        elif method == "average":
            time, stim, force, discontinuity = average_data_extraction(self.data_path)
            # Averaged data is a single train, so a single mean pulse duration is used.
            pulse_duration = np.mean(np.array(self.pulse_duration_extraction(self.data_path)))

        elif method == "sparse":
            force_curve_number = self.kwargs.get("force_curve_number", 5)
            time, stim, force, discontinuity = sparse_data_extraction(self.data_path, force_curve_number)
            pulse_duration = self.pulse_duration_extraction(self.data_path)  # TODO : adapt this for sparse data

        else:
            # Without this, an unknown method would leave time/stim/force/pulse_duration undefined below.
            raise ValueError(f'identification_method must be "full", "average" or "sparse", got "{method}".')

        n_shooting, final_time_phase = node_shooting_list_creation(stim, stimulated_n_shooting)
        force_at_node = force_at_node_in_ocp(time, force, n_shooting, final_time_phase, force_curve_number)

        # --- Building force ocp --- #
        start_time = time_package.time()
        self.force_ocp = OcpFesId.prepare_ocp(
            model=self.model,
            n_shooting=n_shooting,
            final_time_phase=final_time_phase,
            force_tracking=force_at_node,
            key_parameter_to_identify=self.key_parameter_to_identify,
            additional_key_settings=self.additional_key_settings,
            custom_objective=self.custom_objective,
            discontinuity_in_ocp=discontinuity,
            pulse_duration=pulse_duration,
            use_sx=self.use_sx,
            ode_solver=self.ode_solver,
            n_threads=self.n_threads,
        )

        print(f"OCP creation time : {time_package.time() - start_time} seconds")

        self.force_identification_result = self.force_ocp.solve(Solver.IPOPT())

        return {key: self.force_identification_result.parameters[key][0] for key in self.key_parameter_to_identify}