endremborza/atqo

View on GitHub
atqo/distributed_apis.py

Summary

Maintainability
A
0 mins
Test Coverage
import multiprocessing as mp
from asyncio import Future, wrap_future
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.managers import SyncManager
from threading import Thread
from typing import TYPE_CHECKING

from structlog import get_logger
from tblib import Traceback

from .bases import ActorBase, DistAPIBase
from .exceptions import DistantException
from .lock_stores import MpLockStore, ThreadLockStore

if TYPE_CHECKING:
    from .core import SchedulerTask  # pragma: no cover

# Module-wide structlog logger; ``get_dist_api`` uses it to warn about unknown keys.
logger = get_logger()


class SyncAPI(DistAPIBase):
    """In-process distributed API that adds nothing to ``DistAPIBase``.

    Every behaviour comes from the base class defaults; no method is overridden.
    """


class MultiProcAPI(DistAPIBase):
    """Distributed API that runs every actor in its own spawned process.

    A ``multiprocessing`` manager owns the shared queues, locks and dict. A
    daemon "filler" thread keeps a small queue (capacity 2) topped up with
    fresh manager locks, and that queue is handed to the ``MpLockStore``.
    """

    def __init__(self) -> None:
        # force=True overrides any start method chosen earlier in this interpreter
        mp.set_start_method("spawn", force=True)
        manager = mp.Manager()
        self.man = manager
        lock_q = manager.Queue(2)
        self._lock_queue = lock_q
        # seed one lock before the filler thread starts producing more
        lock_q.put(manager.Lock())
        filler = Thread(target=self._fill_lock_q, name="filler", daemon=True)
        self._filler = filler
        filler.start()
        self._store = MpLockStore(manager.Lock(), manager.dict(), lock_q)

    def get_running_actor(self, actor_creator) -> "ActorBase":
        """Start ``actor_creator`` in a new process and return the wrapper around it."""
        return MPActorWrap(actor_creator=actor_creator, man=self.man, store=self._store)

    @staticmethod
    def get_future(actor: ActorBase, next_task: "SchedulerTask") -> Future:
        """Submit the task's argument to ``actor`` and expose the result as an asyncio future."""
        concurrent_future = actor.consume(next_task.argument)
        return wrap_future(concurrent_future)

    def join(self):
        """Shut the manager down, then make a non-blocking join attempt on the filler."""
        self.man.shutdown()
        # timeout=0: the filler is a daemon and is expected to stop by itself
        # once its manager connection breaks (see ``_fill_lock_q``)
        self._filler.join(timeout=0)

    def _fill_lock_q(self):
        """Keep pushing new manager locks into ``_lock_queue`` until the manager is unreachable."""
        disconnected = (EOFError, ConnectionResetError, BrokenPipeError, FileNotFoundError)
        try:
            while True:
                # put() blocks while the queue already holds two locks
                self._lock_queue.put(self.man.Lock())
        except disconnected:
            # connection to the manager is gone (e.g. after ``join``); stop quietly
            return


class MPActorWrap(ActorBase):
    """Host one actor in a dedicated child process and relay tasks to it.

    The child (``_work_mp_actor``) builds the actor by calling
    ``actor_creator()`` and then answers one task at a time. A task argument
    goes in through ``_in_q`` and the reply comes back through ``_out_q``;
    both are manager queues of size 1.
    """

    # Seconds between liveness checks of the child while waiting for it to
    # report whether ``actor_creator()`` succeeded.
    _setup_poll_seconds = 0.5

    def __init__(self, actor_creator, man: SyncManager, store):
        """Start the child process and wait until its actor is constructed.

        Args:
            actor_creator: zero-argument callable that builds the actor inside
                the child; it is sent to the child as a process argument.
            man: manager that owns the task/result queues.
            store: lock store the child installs as its global lock store.

        Raises:
            Exception: whatever ``actor_creator()`` raised in the child.
            RuntimeError: the child exited without reporting a setup result.
        """
        self._inner_actor = actor_creator
        self._in_q = man.Queue(maxsize=1)
        self._out_q = man.Queue(maxsize=1)
        # one worker process; each submitted task forwards its argument to the
        # child and blocks there until the reply arrives (see ``_add_task_mp``)
        self.pool = ProcessPoolExecutor(1)
        setup_q = mp.Queue()
        args = (actor_creator, self._in_q, self._out_q, setup_q, store)
        self.proc = mp.Process(target=_work_mp_actor, args=args)
        self.proc.start()
        try:
            res = self._await_setup(setup_q)
            if isinstance(res, Exception):
                raise res
        except BaseException:
            # don't leave the child process / executor behind when construction fails
            self.stop()
            raise
        finally:
            setup_q.close()

    def _await_setup(self, setup_q):
        """Return the child's setup report, failing instead of hanging if it died.

        A plain blocking ``get()`` would wait forever when the child exits
        without reporting, e.g. because it crashed while unpickling its
        arguments or could not pickle the exception it tried to send.
        """
        from queue import Empty

        while True:
            try:
                return setup_q.get(timeout=self._setup_poll_seconds)
            except Empty:
                if not self.proc.is_alive():
                    break
        # The child may have reported and exited between the timeout and the
        # liveness check, so look at the queue one final time.
        try:
            return setup_q.get(timeout=self._setup_poll_seconds)
        except Empty:
            raise RuntimeError(
                f"actor process exited (code {self.proc.exitcode}) "
                "before finishing setup"
            ) from None

    def consume(self, task_arg):
        """Send ``task_arg`` to the child; return a ``concurrent.futures.Future`` of the reply."""
        return self.pool.submit(_add_task_mp, task_arg, self._in_q, self._out_q)

    def stop(self):
        """Kill and reap the child process, then shut down the executor."""
        self.proc.kill()
        self.proc.join()
        self.pool.shutdown()


def _work_mp_actor(actor_cls, in_q, out_q, setup_q, store):  # pragma: no cover
    """Child-process main loop: build the actor, then serve tasks forever.

    Puts ``0`` on ``setup_q`` once ``actor_cls()`` succeeds. If construction
    fails, it puts the raised exception on ``setup_q`` and returns.
    Afterwards every argument read from ``in_q`` gets one reply on ``out_q``:
    either the actor's result, or a ``DistantException`` wrapping the error
    raised while computing or sending it.
    """
    global _GLOBAL_LOCK_STORE
    # make get_lock / acquire_lock in this process use the store handed in
    _GLOBAL_LOCK_STORE = store
    try:
        actor = actor_cls()
    except Exception as e:
        setup_q.put(e)
        return
    setup_q.put(0)
    while True:
        arg = in_q.get()
        try:
            res = actor.consume(arg)
        except Exception as e:
            res = DistantException(e, Traceback(e.__traceback__))
        try:
            out_q.put(res)
        except Exception as e:
            # e.g. the result (or the wrapped exception) cannot be pickled.
            # Report that failure instead: otherwise this loop dies and the
            # caller waits on out_q forever.
            out_q.put(DistantException(e, Traceback(e.__traceback__)))


def _add_task_mp(task_arg, in_q: mp.Queue, out_q: mp.Queue):
    """Hand ``task_arg`` to the actor process and block until its reply arrives.

    Submitted to the executor by ``MPActorWrap.consume``.
    NOTE(review): callers pass manager queue proxies rather than ``mp.Queue``
    objects; the annotations are kept as they were.
    """
    in_q.put(task_arg)
    reply = out_q.get()
    return reply


DEFAULT_DIST_API_KEY = "sync"
DEFAULT_MULTI_API = "mp"
# Registry consulted by ``get_dist_api``: key -> API class (not an instance).
DIST_API_MAP = {
    DEFAULT_DIST_API_KEY: SyncAPI,
    DEFAULT_MULTI_API: MultiProcAPI,
}

# Process-wide lock store read by ``get_lock``. ``_work_mp_actor`` rebinds it
# to the store passed in by the parent. In any process where that has not
# happened, a ThreadLockStore is used.
# The NameError guard only creates the default when the name is not already
# bound, presumably so that re-executing the module body (e.g. importlib.reload)
# keeps an already-installed store — TODO confirm intent.
try:
    _GLOBAL_LOCK_STORE
except NameError:
    _GLOBAL_LOCK_STORE = ThreadLockStore()


def get_lock(lock_id):
    """Look up ``lock_id`` in this process's active lock store and return the result."""
    store = _GLOBAL_LOCK_STORE
    return store.get(lock_id)


def acquire_lock(lock_id):
    """Fetch the lock for ``lock_id``, block until it is held, and hand it back.

    The caller is responsible for releasing the returned lock.
    """
    held = get_lock(lock_id)
    held.acquire()
    return held


def get_dist_api(key) -> "type[DistAPIBase]":
    """Return the distributed-API class registered under ``key``.

    The class itself is returned, not an instance (``DIST_API_MAP`` stores
    classes). An unknown ``key`` logs a warning and falls back to the class
    registered under ``DEFAULT_DIST_API_KEY``.
    """
    try:
        return DIST_API_MAP[key]
    except KeyError:
        default = DIST_API_MAP[DEFAULT_DIST_API_KEY]
        err = f"unknown distributed system: {key}, defaulting to {default}"
        logger.warning(err)
        return default