LiberTEM/LiberTEM

View on GitHub
src/libertem/web/state.py

Summary

Maintainability
B
4 hrs
Test Coverage
import os
import copy
import typing
import itertools
import logging
import asyncio
import time
import contextlib

import psutil

import libertem
from libertem.api import Context
from libertem.analysis.base import AnalysisResultSet
from libertem.common.executor import JobExecutor
from libertem.common.async_utils import sync_to_async
from libertem.executor.base import AsyncAdapter
from libertem.executor.dask import DaskJobExecutor
from libertem.io.dataset.base import DataSetException, DataSet
from libertem.io.writers.results.base import ResultFormatRegistry
from libertem.io.writers.results import formats  # NOQA
from libertem.udf.base import UDFResults
from libertem.utils import devices
from .models import (
    AnalysisDetails, AnalysisInfo, AnalysisResultInfo, CompoundAnalysisInfo, AnalysisParameters,
    DatasetInfo, JobInfo, SerializedJobInfo,
)
from .helpers import create_executor
from .event_bus import EventBus

log = logging.getLogger(__name__)


class ExecutorState:
    """
    Executor management for the web API. This class is used by `SharedState` to
    manage executor instances and related state.  Executors passed into
    `set_executors` will be taken over, and managed directly by this class.

    Be sure to explicitly call the `shutdown` method to clean up, otherwise
    there may be dangling resources preventing a clean shutdown.
    """
    def __init__(
        self,
        event_bus: EventBus,
        loop: typing.Optional[asyncio.AbstractEventLoop] = None,
        snooze_timeout: typing.Optional[float] = None,
    ):
        self.executor = None
        self.cluster_params = {}
        self.cluster_details: typing.Optional[list] = None
        self.context: typing.Optional[Context] = None
        self._event_bus = event_bus
        self._snooze_timeout = snooze_timeout
        self._pool = AsyncAdapter.make_pool()
        self._is_snoozing = False
        self._last_activity = time.monotonic()
        self._snooze_check_interval = min(
            30.0,
            snooze_timeout and (snooze_timeout * 0.1) or 30.0,
        )
        self._snooze_task = asyncio.ensure_future(self._snooze_check_task())
        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop
        self._keep_alive = 0
        self.local_directory = "dask-worker-space"
        self.preload: tuple[str, ...] = ()

    def set_preload(self, preload: tuple[str, ...]) -> None:
        self.preload = preload

    def get_preload(self) -> tuple[str, ...]:
        return self.preload

    def set_local_directory(self, local_directory: str) -> None:
        if local_directory is not None:
            self.local_directory = local_directory

    def get_local_directory(self):
        return self.local_directory

    def _update_last_activity(self):
        self._last_activity = time.monotonic()
        log.debug("_update_last_activity")

    @contextlib.contextmanager
    def keep_alive(self):
        self._update_last_activity()
        self._keep_alive += 1
        try:
            yield
        finally:
            self._keep_alive -= 1
            self._update_last_activity()

    async def snooze(self):
        if self.executor is None:
            log.debug("not snoozing: no executor")
            return
        if self._keep_alive > 0:
            log.debug("not snoozing: _keep_alive=%d", self._keep_alive)
            return
        log.info("Snoozing...")
        from .messages import Message
        msg = Message().snooze("snoozing")
        self._event_bus.send(msg)
        self.executor.ensure_sync().close()
        self.context = None
        self.executor = None
        self._is_snoozing = True

    async def unsnooze(self):
        if not self._is_snoozing:
            return
        log.info("Unsnoozing...")
        from .messages import Message
        msg = Message().unsnooze("unsnoozing")
        self._event_bus.send(msg)
        executor = await self.make_executor(self.cluster_params, self._pool)
        self._set_executor(executor, self.cluster_params)
        self._is_snoozing = False
        msg = Message().unsnooze_done("unsnooze done")
        self._event_bus.send(msg)

    def is_snoozing(self):
        return self._is_snoozing

    async def _snooze_check_task(self):
        """
        Periodically check if we need to snooze the executor
        """
        try:
            while True:
                await asyncio.sleep(self._snooze_check_interval)
                if not self._is_snoozing and self._snooze_timeout is not None:
                    since_last_activity = time.monotonic() - self._last_activity
                    if since_last_activity > self._snooze_timeout:
                        await self.snooze()
        except asyncio.CancelledError:
            log.debug("shutting down snooze check task")
            return

    async def make_executor(self, params, pool) -> AsyncAdapter:
        connection = params['connection']
        if connection["type"].lower() == "tcp":
            sync_executor = await sync_to_async(
                DaskJobExecutor.connect,
                pool=self._pool,
                scheduler_uri=connection['address'],
            )
        elif connection["type"].lower() == "local":
            sync_executor = await sync_to_async(
                create_executor,
                pool=self._pool,
                connection=connection,
                local_directory=self.get_local_directory(),
                preload=self.get_preload(),
            )
        else:
            raise ValueError("unknown connection type")
        executor = AsyncAdapter(wrapped=sync_executor, pool=pool)
        return executor

    async def get_executor(self):
        self._update_last_activity()
        await self.unsnooze()
        if self.executor is None:
            # TODO: exception type, conversion into 400 response
            raise RuntimeError("wrong state: executor is None")
        return self.executor

    def have_executor(self):
        return self.executor is not None or self._is_snoozing

    async def get_resource_details(self):
        # memoize the cluster details, if ever we support
        # dynamic resources this will need to change
        if self.cluster_details is None:
            executor = await self.get_executor()
            self.cluster_details = await executor.get_resource_details()
        return self.cluster_details

    async def get_context(self) -> Context:
        await self.unsnooze()
        self._update_last_activity()
        if self.context is None:
            raise RuntimeError("cannot get context, please call `set_executor` before")
        return self.context

    def shutdown(self):
        self._loop.call_soon_threadsafe(self._snooze_task.cancel)
        if self.context is not None:
            self.context.close()

    async def set_executor(self, executor: JobExecutor, params):
        """
        Set the new executor used to run jobs, and the parameters used
        to create/connect.

        After the executor is set, we take "ownership" of it, and ensure
        that it is properly cleaned up in the `shutdown` method.
        """
        self._update_last_activity()
        if self.executor is not None:
            await self.executor.close()
            self.executor = None
        self._set_executor(executor, params)

    def _set_executor(self, executor: JobExecutor, params):
        if self.executor is not None:
            self.executor.ensure_sync().close()
        self.executor = executor
        self.cluster_params = params
        self.context = Context(executor=executor.ensure_sync())
        self._is_snoozing = False  # if we were snoozing before, we no longer are
        try:
            # Exposing the scheduler address allows use of
            # libertem-server to spin up a LT-compatible
            # persistent dask cluster
            log.info(f'Scheduler at {self.context.executor.client.scheduler.address}')
        except AttributeError:
            pass

    def get_cluster_params(self):
        self._update_last_activity()
        return self.cluster_params


class AnalysisState:
    def __init__(self, executor_state: ExecutorState, job_state: 'JobState'):
        self.analyses: dict[str, AnalysisInfo] = {}
        self.results: dict[str, AnalysisResultInfo] = {}
        self.job_state = job_state

    def create(
        self, uuid: str, dataset_uuid: str, analysis_type: str, parameters: AnalysisParameters
    ) -> None:
        assert uuid not in self.analyses
        self.analyses[uuid] = {
            "dataset": dataset_uuid,
            "analysis": uuid,
            "jobs": [],
            "details": {
                "analysisType": analysis_type,
                "parameters": parameters,
            },
        }

    def add_job(self, analysis_id: str, job_id: str) -> None:
        jobs = self.analyses[analysis_id]["jobs"]
        jobs.append(job_id)

    def update(self, uuid: str, analysis_type: str, parameters: AnalysisParameters) -> None:
        self.analyses[uuid]["details"]["parameters"] = parameters
        self.analyses[uuid]["details"]["analysisType"] = analysis_type

    def get(
        self, uuid: str, default: typing.Optional[AnalysisInfo] = None
    ) -> typing.Optional[AnalysisInfo]:
        return self.analyses.get(uuid, default)

    def filter(self, predicate: typing.Callable[[AnalysisInfo], bool]) -> list[AnalysisInfo]:
        return [
            analysis
            for analysis in self.analyses.values()
            if predicate(analysis)
        ]

    async def remove(self, uuid: str) -> bool:
        if uuid not in self.analyses:
            return False
        if uuid in self.results:
            self.remove_results(uuid)
        await self.remove_jobs(uuid)
        del self.analyses[uuid]
        return True

    async def remove_jobs(self, uuid: str) -> None:
        jobs = copy.copy(self.job_state.get_for_analysis_id(uuid))
        for job_id in jobs:
            await self.job_state.remove(job_id)

    def remove_results(self, uuid: str) -> None:
        del self.results[uuid]

    def set_results(
        self,
        analysis_id: str,
        details: AnalysisDetails,
        results: AnalysisResultSet,
        job_id: str,
        udf_results: typing.Optional[UDFResults],
    ) -> None:
        """
        set results: create or update
        """
        self.results[analysis_id] = AnalysisResultInfo(
            copy.deepcopy(details), results, job_id, udf_results
        )

    def have_results(self, analysis_id: str) -> bool:
        return analysis_id in self.results

    def get_results(self, analysis_id: str) -> AnalysisResultInfo:
        return self.results[analysis_id]

    def get_all_results(self) -> typing.ItemsView[str, AnalysisResultInfo]:
        return self.results.items()

    def __getitem__(self, analysis_id: str) -> AnalysisInfo:
        return self.analyses[analysis_id]

    def serialize(self, analysis_id: str) -> AnalysisInfo:
        result = copy.copy(self[analysis_id])
        result["jobs"] = [
            job_id
            for job_id in result["jobs"]
            if not self.job_state.is_cancelled(job_id)
        ]
        return result

    def serialize_all(self) -> list[AnalysisInfo]:
        return [
            self.serialize(analysis_id)
            for analysis_id in self.analyses
        ]


class CompoundAnalysisState:
    def __init__(self, analysis_state: AnalysisState):
        self.analysis_state = analysis_state
        self.analyses: dict[str, CompoundAnalysisInfo] = {}

    def create_or_update(
        self, uuid: str, main_type: str, dataset_id: str, analyses: list[str]
    ) -> bool:
        created = uuid not in self.analyses
        self.analyses[uuid] = {
            "dataset": dataset_id,
            "compoundAnalysis": uuid,
            "details": {
                "mainType": main_type,
                "analyses": analyses,
            }
        }
        return created

    def remove(self, uuid: str) -> None:
        del self.analyses[uuid]

    def __getitem__(self, uuid: str) -> CompoundAnalysisInfo:
        return self.analyses[uuid]

    def filter(
        self, predicate: typing.Callable[[CompoundAnalysisInfo], bool]
    ) -> list[CompoundAnalysisInfo]:
        return [
            ca
            for ca in self.analyses.values()
            if predicate(ca)
        ]

    def serialize(self, uuid: str) -> CompoundAnalysisInfo:
        return self[uuid]

    def serialize_all(self) -> list[CompoundAnalysisInfo]:
        return [
            self.serialize(uuid)
            for uuid in self.analyses
        ]


class DatasetState:
    def __init__(self, executor_state: ExecutorState, analysis_state: AnalysisState,
                 compound_analysis_state: CompoundAnalysisState):
        self.datasets: dict[str, DatasetInfo] = {}
        self.dataset_to_id: dict[DataSet, str] = {}
        self.executor_state = executor_state
        self.analysis_state = analysis_state
        self.compound_analysis_state = compound_analysis_state

    def register(
        self, uuid: str, dataset: DataSet, params: dict, converted: dict
    ):
        assert uuid not in self.datasets
        self.datasets[uuid] = {
            "dataset": dataset,
            "params": params,
            "converted": converted,
        }
        self.dataset_to_id[dataset] = uuid
        return self

    async def serialize(self, dataset_id):
        executor = await self.executor_state.get_executor()
        dataset = self.datasets[dataset_id]
        diag = await executor.run_function(lambda: dataset["dataset"].diagnostics)
        return {
            "id": dataset_id,
            "params": {
                **dataset["params"]["params"],
                "shape": tuple(dataset["dataset"].shape),
            },
            "diagnostics": diag,
        }

    async def serialize_all(self):
        return [
            await self.serialize(dataset_id)
            for dataset_id in self.datasets.keys()
        ]

    def id_for_dataset(self, dataset):
        return self.dataset_to_id[dataset]

    def __getitem__(self, uuid):
        return self.datasets[uuid]["dataset"]

    def __contains__(self, uuid):
        return uuid in self.datasets

    async def verify(self):
        executor = await self.executor_state.get_executor()
        for uuid, params in self.datasets.items():
            dataset = params["dataset"]
            try:
                await executor.run_function(dataset.check_valid)
            except DataSetException:
                await self.remove_dataset(uuid)

    async def remove(self, uuid):
        """
        Stop all jobs and remove dataset state for the dataset identified by `uuid`
        """
        ds = self.datasets[uuid]["dataset"]
        analyses = self.analysis_state.filter(lambda a: a["dataset"] == uuid)
        compound_analyses = self.compound_analysis_state.filter(lambda ca: ca["dataset"] == uuid)
        del self.datasets[uuid]
        del self.dataset_to_id[ds]
        for analysis in analyses:
            await self.analysis_state.remove(analysis["analysis"])
        for ca in compound_analyses:
            self.compound_analysis_state.remove(ca["compoundAnalysis"])


class JobState:
    def __init__(self, executor_state: ExecutorState):
        self.jobs: dict[str, JobInfo] = {}
        self.executor_state = executor_state
        self.jobs_for_dataset = typing.DefaultDict[str, set[str]](lambda: set())
        self.jobs_for_analyses = typing.DefaultDict[str, set[str]](lambda: set())

    def register(self, job_id: str, analysis_id, dataset_id):
        assert job_id not in self.jobs
        self.jobs[job_id] = {
            "id": job_id,
            "analysis": analysis_id,
            "dataset": dataset_id,
        }
        self.jobs_for_dataset[dataset_id].add(job_id)
        self.jobs_for_analyses[analysis_id].add(job_id)
        return self

    async def remove(self, uuid: str) -> bool:
        try:
            executor = await self.executor_state.get_executor()
            await executor.cancel(uuid)
            del self.jobs[uuid]
            for ds, jobs in itertools.chain(self.jobs_for_dataset.items(),
                                            self.jobs_for_analyses.items()):
                if uuid in jobs:
                    jobs.remove(uuid)
            return True
        except KeyError:
            return False

    def get_for_dataset_id(self, dataset_id: str) -> set[str]:
        return self.jobs_for_dataset[dataset_id]

    def get_for_analysis_id(self, analysis_id: str) -> set[str]:
        return self.jobs_for_analyses[analysis_id]

    def __getitem__(self, uuid: str) -> JobInfo:
        return self.jobs[uuid]

    def __contains__(self, uuid: str) -> bool:
        return uuid in self.jobs

    def is_cancelled(self, uuid: str) -> bool:
        return uuid not in self.jobs

    def serialize(self, job_id: str) -> SerializedJobInfo:
        job = self[job_id]
        return {
            "id": job["id"],
            "analysis": job["analysis"],
        }

    def serialize_all(self) -> list[SerializedJobInfo]:
        return [
            self.serialize(job_id)
            for job_id in self.jobs.keys()
        ]


class SharedState:
    def __init__(self, executor_state: "ExecutorState"):
        self.executor_state = executor_state
        self.job_state = JobState(self.executor_state)
        self.analysis_state = AnalysisState(self.executor_state, job_state=self.job_state)
        self.compound_analysis_state = CompoundAnalysisState(self.analysis_state)
        self.dataset_state = DatasetState(
            self.executor_state,
            analysis_state=self.analysis_state,
            compound_analysis_state=self.compound_analysis_state,
        )

    def get_local_cores(self, default: int = 2) -> int:
        cores: typing.Optional[int] = psutil.cpu_count(logical=False)
        if cores is None:
            cores = default
        return cores

    def get_ds_type_info(self, ds_type_id: str):
        from libertem.io.dataset import get_dataset_cls
        cls = get_dataset_cls(ds_type_id)
        ConverterCls = cls.get_msg_converter()
        converter = ConverterCls()
        schema = converter.SCHEMA
        supported_backends = cls.get_supported_io_backends()
        default_backend = cls.get_default_io_backend().id_
        if not supported_backends:
            default_backend = None
        return {
            "schema": schema,
            "default_io_backend": default_backend,
            "supported_io_backends": supported_backends,
        }

    def get_config(self):
        from libertem.io.dataset import filetypes
        detected_devices = devices.detect()
        ds_types = list(filetypes.keys())
        return {
            "version": libertem.__version__,
            "resultFileFormats": ResultFormatRegistry.get_available_formats(),
            "revision": libertem.revision,
            "localCores": self.get_local_cores(),
            "devices": detected_devices,
            "datasetTypes": {
                ds_type_id.upper(): self.get_ds_type_info(ds_type_id)
                for ds_type_id in ds_types
            },
            "cwd": os.getcwd(),
            # '/' works on Windows, too.
            "separator": '/'
        }

    async def create_and_set_executor(self, spec: dict[str, int]):
        """
        Create a new executor from spec, a dict[str, int]
        compatible with the main arguments of cluster_spec().
        Any values not in spec are filled from a call to detect()

        Any existing executor will first closed by the call
        to self.executor_state._set_executor
        """
        from .helpers import create_executor_external  # circular import
        executor, params = create_executor_external(
            spec,
            self.executor_state.get_local_directory(),
            self.executor_state.get_preload(),
        )
        self.executor_state._set_executor(executor, params)