localstack/localstack

View on GitHub
localstack/services/stepfunctions/backend/test_state/execution.py

Summary

Maintainability
A
1 hr
Test Coverage
from __future__ import annotations

import logging
import threading
from typing import Optional

from localstack.aws.api.stepfunctions import (
    Arn,
    ExecutionStatus,
    InspectionLevel,
    TestExecutionStatus,
    TestStateOutput,
    Timestamp,
)
from localstack.services.stepfunctions.asl.eval.program_state import (
    ProgramEnded,
    ProgramError,
    ProgramState,
)
from localstack.services.stepfunctions.asl.eval.test_state.program_state import (
    ProgramChoiceSelected,
)
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.services.stepfunctions.backend.activity import Activity
from localstack.services.stepfunctions.backend.execution import (
    BaseExecutionWorkerCommunication,
    Execution,
)
from localstack.services.stepfunctions.backend.state_machine import StateMachineInstance
from localstack.services.stepfunctions.backend.test_state.execution_worker import (
    TestStateExecutionWorker,
)

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)


class TestStateExecution(Execution):
    exec_worker: Optional[TestStateExecutionWorker]
    next_state: Optional[str]

    class TestCaseExecutionWorkerCommunication(BaseExecutionWorkerCommunication):
        _execution: TestStateExecution

        def terminated(self) -> None:
            exit_program_state: ProgramState = self.execution.exec_worker.env.program_state()
            if isinstance(exit_program_state, ProgramChoiceSelected):
                self.execution.exec_status = ExecutionStatus.SUCCEEDED
                self.execution.output = self.execution.exec_worker.env.inp
                self.execution.next_state = exit_program_state.next_state_name
            else:
                self._reflect_execution_status()

    def __init__(
        self,
        name: str,
        role_arn: Arn,
        exec_arn: Arn,
        account_id: str,
        region_name: str,
        state_machine: StateMachineInstance,
        start_date: Timestamp,
        activity_store: dict[Arn, Activity],
        input_data: Optional[dict] = None,
    ):
        super().__init__(
            name=name,
            role_arn=role_arn,
            exec_arn=exec_arn,
            account_id=account_id,
            region_name=region_name,
            state_machine=state_machine,
            start_date=start_date,
            activity_store=activity_store,
            input_data=input_data,
            trace_header=None,
        )
        self._execution_terminated_event = threading.Event()
        self.next_state = None

    def _get_start_execution_worker_comm(self) -> BaseExecutionWorkerCommunication:
        return self.TestCaseExecutionWorkerCommunication(self)

    def _get_start_execution_worker(self) -> TestStateExecutionWorker:
        return TestStateExecutionWorker(
            definition=self.state_machine.definition,
            input_data=self.input_data,
            exec_comm=self._get_start_execution_worker_comm(),
            context_object_init=self._get_start_context_object_init_data(),
            aws_execution_details=self._get_start_aws_execution_details(),
            activity_store=self._activity_store,
        )

    def publish_execution_status_change_event(self):
        # Do not publish execution status change events during test state execution.
        pass

    def to_test_state_output(self, inspection_level: InspectionLevel) -> TestStateOutput:
        exit_program_state: ProgramState = self.exec_worker.env.program_state()
        if isinstance(exit_program_state, ProgramEnded):
            output_str = to_json_str(self.output)
            test_state_output = TestStateOutput(
                status=TestExecutionStatus.SUCCEEDED, output=output_str
            )
        elif isinstance(exit_program_state, ProgramError):
            test_state_output = TestStateOutput(
                status=TestExecutionStatus.FAILED,
                error=exit_program_state.error["error"],
                cause=exit_program_state.error["cause"],
            )
        elif isinstance(exit_program_state, ProgramChoiceSelected):
            output_str = to_json_str(self.output)
            test_state_output = TestStateOutput(
                status=TestExecutionStatus.SUCCEEDED, nextState=self.next_state, output=output_str
            )
        else:
            # TODO: handle other statuses
            LOG.warning(
                f"Unsupported StateMachine exit type for TestState '{type(exit_program_state)}'"
            )
            output_str = to_json_str(self.output)
            test_state_output = TestStateOutput(
                status=TestExecutionStatus.FAILED, output=output_str
            )

        match inspection_level:
            case InspectionLevel.TRACE:
                test_state_output["inspectionData"] = self.exec_worker.env.inspection_data
            case InspectionLevel.DEBUG:
                test_state_output["inspectionData"] = self.exec_worker.env.inspection_data

        return test_state_output