Erdnaxela3/bioptim_gui

View on GitHub
api/bioptim_gui_api/acrobatics_ocp/endpoints/acrobatics_phases_modifiers.py

Summary

Maintainability
A
25 mins
Test Coverage
from fastapi import APIRouter, HTTPException

from bioptim_gui_api.acrobatics_ocp.endpoints.acrobatics_requests import (
    CollisionConstraintRequest,
    NbHalfTwistsRequest,
    NbSomersaultsRequest,
    PositionRequest,
    VisualCriteriaRequest,
    WithSpineRequest,
)
from bioptim_gui_api.acrobatics_ocp.endpoints.acrobatics_responses import (
    NbSomersaultsResponse,
    PositionResponse,
    VisualCriteriaResponse,
    CollisionConstraintResponse,
    WithSpineResponse,
)
from bioptim_gui_api.acrobatics_ocp.misc.phase_updating import update_phase_info


class AcrobaticsPhaseModifiers:
    """
    Endpoints that heavily modify the phases of the acrobatics OCP
    (number of phases, objectives, constraints, variables).

    Every endpoint stores the new setting, then calls ``update_phase_info``
    (presumably rebuilding the stored ``phases_info``; the responses read it
    back right after the call).
    """

    # Largest value accepted by PUT /nb_somersaults.
    MAX_NB_SOMERSAULTS = 5

    def __init__(self, data):
        """
        Parameters
        ----------
        data
            Storage object exposing ``read_data([key])`` and
            ``update_data(key, value)`` (the only methods used here).
        """
        self.data = data
        self.router = None  # assigned by register()

    def register(self, route: APIRouter) -> None:
        """Attach every phase-modifying endpoint to ``route``."""
        self.router = route
        self.register_put_nb_somersaults()
        self.register_put_nb_half_twists()
        self.register_put_position()
        self.register_put_with_visual_criteria()
        self.register_put_collision_constraint()
        self.register_put_with_spine()

    @staticmethod
    def _resize_half_twists(half_twists: list, nb_somersaults: int) -> list:
        """Return a copy of ``half_twists`` truncated or zero-padded to ``nb_somersaults`` entries."""
        resized = list(half_twists[:nb_somersaults])
        # pad based on the actual list length so the result is always the requested size
        return resized + [0] * (nb_somersaults - len(resized))

    def _ensure_changed(self, key: str, new_value):
        """
        Return the stored value of ``key``.

        Raises
        ------
        HTTPException
            304 if the stored value already equals ``new_value``.
            Note that an HTTP 304 response carries no body, so ``detail``
            is mostly informative.
        """
        old_value = self.data.read_data(key)
        if old_value == new_value:
            raise HTTPException(
                status_code=304,
                detail=f"{key} is already {old_value}",
            )
        return old_value

    def register_put_nb_somersaults(self):
        @self.router.put("/nb_somersaults", response_model=NbSomersaultsResponse)
        def update_nb_somersaults(nb_somersaults: NbSomersaultsRequest):
            """
            Update the number of somersaults of the acrobatics ocp.

            The half_twists list is truncated or zero-padded to the new number
            of somersaults, the position is forced to "straight" for a single
            somersault, then the phase info is updated accordingly.

            Raises a 400 HTTPException if the value is not in [1, MAX_NB_SOMERSAULTS].
            """
            max_nb_somersaults = self.MAX_NB_SOMERSAULTS
            new_nb_somersaults = nb_somersaults.nb_somersaults

            # error handling
            if not 1 <= new_nb_somersaults <= max_nb_somersaults:
                raise HTTPException(
                    status_code=400,
                    detail=f"nb_somersaults must be between 1 and {max_nb_somersaults}",
                )

            # updating data
            data = self.data.read_data()
            updated_half_twists = self._resize_half_twists(data["nb_half_twists"], new_nb_somersaults)

            self.data.update_data("nb_somersaults", new_nb_somersaults)
            self.data.update_data("nb_half_twists", updated_half_twists)

            # 1 somersault tuck/pike are not allowed, set the position to straight
            if new_nb_somersaults == 1 and data["position"] != "straight":
                self.data.update_data("position", "straight")

            update_phase_info()
            data = self.data.read_data()
            return NbSomersaultsResponse(
                nb_somersaults=data["nb_somersaults"],
                nb_half_twists=data["nb_half_twists"],
                position=data["position"],
                phases_info=data["phases_info"],
                dof_names=data["dof_names"],
            )

    def register_put_nb_half_twists(self):
        @self.router.put("/nb_half_twists/{somersault_index}", response_model=list)
        def put_nb_half_twist(somersault_index: int, half_twists_request: NbHalfTwistsRequest):
            """
            Set the number of half twists of one somersault, then update the phase info.

            Returns whatever ``update_phase_info`` returns.
            Raises a 400 HTTPException if nb_half_twists is negative or if
            somersault_index does not address an existing somersault.
            """
            # error handling
            if half_twists_request.nb_half_twists < 0:
                raise HTTPException(status_code=400, detail="nb_half_twists must be positive or zero")

            half_twists = self.data.read_data("nb_half_twists")
            # reject out-of-range (would be a 500) and negative (would wrap around) indices
            if not 0 <= somersault_index < len(half_twists):
                raise HTTPException(
                    status_code=400,
                    detail=f"somersault_index must be between 0 and {len(half_twists) - 1}",
                )

            # updating data
            half_twists[somersault_index] = half_twists_request.nb_half_twists
            self.data.update_data("nb_half_twists", half_twists)

            phases = update_phase_info()
            return phases

    def register_put_position(self):
        @self.router.put("/position", response_model=PositionResponse)
        def put_position(position: PositionRequest):
            """
            Change the body position, then update the phase info.

            Tuck/pike are not allowed with a single somersault, so in that case
            nb_somersaults is raised to 2 (and nb_half_twists padded with 0).
            Raises a 304 HTTPException if the position is unchanged.
            """
            new_value = position.position.value

            # error handling
            self._ensure_changed("position", new_value)

            # updating data
            self.data.update_data("position", new_value)
            nb_somersaults = self.data.read_data("nb_somersaults")

            # 1 somersault tuck/pike are not allowed, set the nb_somersault to 2
            if new_value != "straight" and nb_somersaults == 1:
                half_twists = self._resize_half_twists(self.data.read_data("nb_half_twists"), 2)
                self.data.update_data("nb_somersaults", 2)
                self.data.update_data("nb_half_twists", half_twists)

            update_phase_info()
            data = self.data.read_data()
            return PositionResponse(
                nb_somersaults=data["nb_somersaults"],
                nb_half_twists=data["nb_half_twists"],
                position=data["position"],
                phases_info=data["phases_info"],
                dof_names=data["dof_names"],
            )

    def register_put_with_visual_criteria(self):
        @self.router.put("/with_visual_criteria", response_model=VisualCriteriaResponse)
        def put_with_visual_criteria(visual_criteria: VisualCriteriaRequest):
            """
            Toggle the visual criteria, then update the phase info.

            Raises a 304 HTTPException if the value is unchanged.
            """
            new_value = visual_criteria.with_visual_criteria

            # error handling
            self._ensure_changed("with_visual_criteria", new_value)

            # updating data
            self.data.update_data("with_visual_criteria", new_value)

            update_phase_info()
            data = self.data.read_data()
            return VisualCriteriaResponse(
                with_visual_criteria=data["with_visual_criteria"],
                phases_info=data["phases_info"],
                dof_names=data["dof_names"],
            )

    def register_put_collision_constraint(self):
        @self.router.put("/collision_constraint", response_model=CollisionConstraintResponse)
        def put_collision_constraint(collision_constraint: CollisionConstraintRequest):
            """
            Toggle the collision constraint, then update the phase info.

            Raises a 304 HTTPException if the value is unchanged.
            """
            new_value = collision_constraint.collision_constraint

            # error handling
            self._ensure_changed("collision_constraint", new_value)

            # updating data
            self.data.update_data("collision_constraint", new_value)

            update_phase_info()
            data = self.data.read_data()
            return CollisionConstraintResponse(
                collision_constraint=data["collision_constraint"],
                phases_info=data["phases_info"],
            )

    def register_put_with_spine(self):
        @self.router.put("/with_spine", response_model=WithSpineResponse)
        def put_with_spine(with_spine: WithSpineRequest):
            """
            Toggle the spine, then update the phase info.

            The dynamics follow the setting: JOINTS_ACCELERATION_DRIVEN with
            the spine, TORQUE_DRIVEN without.
            Raises a 304 HTTPException if the value is unchanged.
            """
            new_value = with_spine.with_spine

            # error handling
            self._ensure_changed("with_spine", new_value)

            # updating data
            self.data.update_data("with_spine", new_value)
            self.data.update_data(
                "dynamics",
                "JOINTS_ACCELERATION_DRIVEN" if new_value else "TORQUE_DRIVEN",
            )

            update_phase_info()
            data = self.data.read_data()
            return WithSpineResponse(
                with_spine=data["with_spine"],
                dynamics=data["dynamics"],
                phases_info=data["phases_info"],
                dof_names=data["dof_names"],
            )