Kev1CO/cocofest

View on GitHub
cocofest/identification/ding2003_force_parameter_identification.py

Summary

Maintainability
C
7 hrs
Test Coverage
import time as time_package

from bioptim import Solver, Objective, OdeSolver

from ..models.fes_model import FesModel
from ..models.ding2003 import DingModelFrequency
from ..optimization.fes_identification_ocp import OcpFesId
from .identification_method import (
    full_data_extraction,
    average_data_extraction,
    sparse_data_extraction,
    node_shooting_list_creation,
    force_at_node_in_ocp,
)
from .identification_abstract_class import ParameterIdentification


class DingModelFrequencyForceParameterIdentification(ParameterIdentification):
    """
    This class is responsible for identifying parameters of the Ding model using force data.
    It supports identification on full data and average data (work in progress : sparse data).
    """

    def __init__(
        self,
        model: DingModelFrequency,
        data_path: str | list[str] = None,
        identification_method: str = "full",
        double_step_identification: bool = False,
        key_parameter_to_identify: list = None,
        additional_key_settings: dict = None,
        n_shooting: int = 5,
        custom_objective: list[Objective] = None,
        use_sx: bool = True,
        ode_solver: OdeSolver = OdeSolver.RK4(n_integration_steps=1),
        n_threads: int = 1,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model: DingModelFrequency,
            The model for identification
        data_path: str | list[str],
            The path to the force model data
        identification_method: str,
            The method to use for the force model identification,
             "full" for objective function on all data,
             "average" for objective function on average data,
             "sparse" for objective function at the beginning and end of the data
        double_step_identification: bool,
            If True, the identification will be done in two steps, the first step will be used to set the initial guess
        key_parameter_to_identify: list,
            The list of parameters to identify
        additional_key_settings: dict,
            additional_key_settings will enable to modify identified parameters default parameters such as initial guess,
            min_bound, max_bound, function and scaling
        n_shooting: int,
            The number of shooting points for the ocp
        custom_objective: list[Objective],
            The custom objective to use for the identification
        use_sx: bool
            The nature of the casadi variables. MX are used if False.
        ode_solver: OdeSolver,
            The ode solver to use for the identification
        n_threads: int,
            The number of threads to use for the identification
        """

        self.default_values = self._set_default_values(model=model)

        dict_parameter_to_configure = model.identifiable_parameters
        model_parameters_value = [
            None if key in key_parameter_to_identify else dict_parameter_to_configure[key]
            for key in dict_parameter_to_configure
        ]
        self.model = self._set_model_parameters(model, model_parameters_value)

        self.input_sanity(
            model,
            data_path,
            identification_method,
            double_step_identification,
            key_parameter_to_identify,
            additional_key_settings,
            n_shooting,
        )

        self.key_parameter_to_identify = key_parameter_to_identify
        self.additional_key_settings = self.key_setting_to_dictionary(key_settings=additional_key_settings)

        self.data_path = data_path
        self.force_model_identification_method = identification_method
        self.double_step_identification = double_step_identification

        self.force_ocp = None
        self.force_identification_result = None
        self.n_shooting = n_shooting
        self.custom_objective = custom_objective
        self.use_sx = use_sx
        self.ode_solver = ode_solver
        self.n_threads = n_threads
        self.kwargs = kwargs

    def _set_default_values(self, model):
        """
        This method is used to set the default values for the identified parameters (initial guesses, bounds, scaling and
        function).
        If the user does not provide additional_key_settings for a specific parameter, the default value will be used.

        Parameters
        ----------
        model

        Returns
        -------

        """
        return {
            "a_rest": {
                "initial_guess": 1000,
                "min_bound": 1,
                "max_bound": 10000,
                "function": model.set_a_rest,
                "scaling": 1,
            },
            "km_rest": {
                "initial_guess": 0.5,
                "min_bound": 0.001,
                "max_bound": 1,
                "function": model.set_km_rest,
                "scaling": 1,  # 1000
            },
            "tau1_rest": {
                "initial_guess": 0.5,
                "min_bound": 0.0001,
                "max_bound": 1,
                "function": model.set_tau1_rest,
                "scaling": 1,  # 1000
            },
            "tau2": {
                "initial_guess": 0.5,
                "min_bound": 0.0001,
                "max_bound": 1,
                "function": model.set_tau2,
                "scaling": 1,  # 1000
            },
        }

    def _set_default_parameters_list(self):
        """
        This method is used to set the default parameters list for the model.
        """
        self.numeric_parameters = [self.model.a_rest, self.model.km_rest, self.model.tau1_rest, self.model.tau2]
        self.key_parameters = ["a_rest", "km_rest", "tau1_rest", "tau2"]

    def input_sanity(
        self,
        model: FesModel = None,
        data_path: str | list[str] = None,
        identification_method: str = None,
        double_step_identification: bool = None,
        key_parameter_to_identify: list = None,
        additional_key_settings: dict = None,
        n_shooting: int = None,
    ):
        """
        This method is used to check the input sanity entered from the user.

        Parameters
        ----------
        model: FesModel,
            The model to use for the identification process
        data_path: str | list[str],
            The path to the force model data
        identification_method: str,
            The method to use for the force model identification,
             "full" for objective function on all data,
             "average" for objective function on average data,
             "sparse" for objective function at the beginning and end of the data
        double_step_identification: bool,
            If True, the identification will be done in two steps, the first step will be used to set the initial guess
        key_parameter_to_identify: list,
            The list of parameters to identify
        additional_key_settings: dict,
            additional_key_settings will enable to modify identified parameters default parameters such as initial guess,
            min_bound, max_bound, function and scaling
        n_shooting: int,
            The number of shooting points for the ocp
        """

        if model._with_fatigue:
            raise ValueError(
                f"The given model is not valid and should not be including the fatigue equation in the model"
            )
        self.check_experiment_force_format(data_path)

        if identification_method not in ["full", "average", "sparse"]:
            raise ValueError(
                f"The given model identification method is not valid,"
                f"only 'full', 'average' and 'sparse' are available,"
                f" the given value is {identification_method}"
            )

        if not isinstance(double_step_identification, bool):
            raise TypeError(
                f"The given double_step_identification must be bool type,"
                f" the given value is {type(double_step_identification)} type"
            )

        if isinstance(key_parameter_to_identify, list):
            for key in key_parameter_to_identify:
                if key not in self.default_values:
                    raise ValueError(
                        f"The given key_parameter_to_identify is not valid,"
                        f" the given value is {key},"
                        f" the available values are {list(self.default_values.keys())}"
                    )
        else:
            raise TypeError(
                f"The given key_parameter_to_identify must be list type,"
                f" the given value is {type(key_parameter_to_identify)} type"
            )

        if isinstance(additional_key_settings, dict):
            for key in additional_key_settings:
                if key not in self.default_values:
                    raise ValueError(
                        f"The given additional_key_settings is not valid,"
                        f" the given value is {key},"
                        f" the available values are {list(self.default_values.keys())}"
                    )
                for setting_name in additional_key_settings[key]:
                    if setting_name not in self.default_values[key]:
                        raise ValueError(
                            f"The given additional_key_settings is not valid,"
                            f" the given value is {setting_name},"
                            f" the available values are {list(self.default_values[key].keys())}"
                        )
                    if not isinstance(
                        additional_key_settings[key][setting_name], type(self.default_values[key][setting_name])
                    ):
                        raise TypeError(
                            f"The given additional_key_settings value is not valid,"
                            f" the given value is {type(additional_key_settings[key][setting_name])},"
                            f" the available type is {type(self.default_values[key][setting_name])}"
                        )
        else:
            raise TypeError(
                f"The given additional_key_settings must be dict type,"
                f" the given value is {type(additional_key_settings)} type"
            )

        if not isinstance(n_shooting, int):
            raise TypeError(f"The given n_shooting must be int type," f" the given value is {type(n_shooting)} type")

        self._set_default_parameters_list()
        if not all(isinstance(param, None | int | float) for param in self.numeric_parameters):
            raise ValueError(f"The given model parameters are not valid, only None, int and float are accepted")

    def key_setting_to_dictionary(self, key_settings):
        """
        This method is used to set the identified parameter optimization values (initial guesses, bounds,
        scaling and function). The default values can be modified by the user when sending the "additional_key_settings"
        input into class. If the user does not provide a value for a specific parameter, the default value will be used.

        Parameters
        ----------
        key_settings: dict,
            The settings attributed from user for parameter to identify

        Returns
        -------
        This function will return a dictionary of dictionaries which contains the identified keys with its associated settings.

        """
        settings_dict = {}
        for key in self.key_parameter_to_identify:
            settings_dict[key] = {}
            for setting_name in self.default_values[key]:
                settings_dict[key][setting_name] = (
                    key_settings[key][setting_name]
                    if (key in key_settings and setting_name in key_settings[key])
                    else self.default_values[key][setting_name]
                )
        return settings_dict

    @staticmethod
    def check_experiment_force_format(data_path):
        if isinstance(data_path, list):
            for i in range(len(data_path)):
                if not isinstance(data_path[i], str):
                    raise TypeError(
                        f"In the given list, all model_data_path must be str type," f" path index n°{i} is not str type"
                    )
                if not data_path[i].endswith(".pkl"):
                    raise TypeError(
                        f"In the given list, all model_data_path must be pickle type and end with .pkl,"
                        f" path index n°{i} is not ending with .pkl"
                    )
        elif isinstance(data_path, str):
            data_path = [data_path]
            if not data_path[0].endswith(".pkl"):
                raise TypeError(
                    f"In the given list, all model_data_path must be pickle type and end with .pkl,"
                    f" path index is not ending with .pkl"
                )
        else:
            raise TypeError(
                f"In the given path, model_data_path must be str or list[str] type, the input is {type(data_path)} type"
            )

    @staticmethod
    def _set_model_parameters(model, model_parameters_value):
        model.a_rest = model_parameters_value[0]
        model.km_rest = model_parameters_value[1]
        model.tau1_rest = model_parameters_value[2]
        model.tau2 = model_parameters_value[3]
        return model

    def _force_model_identification_for_initial_guess(self):
        self.input_sanity(
            self.model,
            self.data_path,
            self.force_model_identification_method,
            self.double_step_identification,
            self.key_parameter_to_identify,
            self.additional_key_settings,
            self.n_shooting,
        )
        self.check_experiment_force_format(self.data_path)
        # --- Data extraction --- #
        # --- Force model --- #
        stimulated_n_shooting = self.n_shooting
        force_curve_number = None

        time, stim, force, discontinuity = average_data_extraction(self.data_path)
        n_shooting, final_time_phase = node_shooting_list_creation(stim, stimulated_n_shooting)
        force_at_node = force_at_node_in_ocp(time, force, n_shooting, final_time_phase, force_curve_number)

        # --- Building force ocp --- #
        self.force_ocp = OcpFesId.prepare_ocp(
            model=self.model,
            n_shooting=n_shooting,
            final_time_phase=final_time_phase,
            force_tracking=force_at_node,
            key_parameter_to_identify=self.key_parameter_to_identify,
            additional_key_settings=self.additional_key_settings,
            custom_objective=self.custom_objective,
            discontinuity_in_ocp=discontinuity,
            use_sx=self.use_sx,
            ode_solver=self.ode_solver,
            n_threads=self.n_threads,
        )

        self.force_identification_result = self.force_ocp.solve(
            Solver.IPOPT()
        )  # _hessian_approximation="limited-memory"

        initial_guess = {}
        for key in self.key_parameter_to_identify:
            initial_guess[key] = self.force_identification_result.parameters[key][0][0]

        return initial_guess

    def force_model_identification(self):
        if not self.double_step_identification:
            self.input_sanity(
                self.model,
                self.data_path,
                self.force_model_identification_method,
                self.double_step_identification,
                self.key_parameter_to_identify,
                self.additional_key_settings,
                self.n_shooting,
            )
            self.check_experiment_force_format(self.data_path)

        # --- Data extraction --- #
        # --- Force model --- #
        stimulated_n_shooting = self.n_shooting
        force_curve_number = None
        time, stim, force, discontinuity = None, None, None, None

        if self.force_model_identification_method == "full":
            time, stim, force, discontinuity = full_data_extraction(self.data_path)

        elif self.force_model_identification_method == "average":
            time, stim, force, discontinuity = average_data_extraction(self.data_path)

        elif self.force_model_identification_method == "sparse":
            force_curve_number = self.kwargs["force_curve_number"] if "force_curve_number" in self.kwargs else 5
            time, stim, force, discontinuity = sparse_data_extraction(self.data_path, force_curve_number)

        n_shooting, final_time_phase = node_shooting_list_creation(stim, stimulated_n_shooting)
        force_at_node = force_at_node_in_ocp(time, force, n_shooting, final_time_phase, force_curve_number)

        if self.double_step_identification:
            initial_guess = self._force_model_identification_for_initial_guess()

            for key in self.key_parameter_to_identify:
                self.additional_key_settings[key]["initial_guess"] = initial_guess[key]

        # --- Building force ocp --- #
        start_time = time_package.time()
        self.force_ocp = OcpFesId.prepare_ocp(
            model=self.model,
            n_shooting=n_shooting,
            final_time_phase=final_time_phase,
            force_tracking=force_at_node,
            key_parameter_to_identify=self.key_parameter_to_identify,
            additional_key_settings=self.additional_key_settings,
            custom_objective=self.custom_objective,
            discontinuity_in_ocp=discontinuity,
            use_sx=self.use_sx,
            ode_solver=self.ode_solver,
            n_threads=self.n_threads,
        )

        print(f"OCP creation time : {time_package.time() - start_time} seconds")

        self.force_identification_result = self.force_ocp.solve(Solver.IPOPT(_max_iter=1000))

        identified_parameters = {}
        for key in self.key_parameter_to_identify:
            identified_parameters[key] = self.force_identification_result.parameters[key][0]

        return identified_parameters