endremborza/atqo

View on GitHub
atqo/simplified_functions.py

Summary

Maintainability
A
0 mins
Test Coverage
from functools import partial
from itertools import islice
from multiprocessing import cpu_count

from .bases import ActorBase
from .core import Scheduler, SchedulerTask
from .distributed_apis import DEFAULT_MULTI_API
from .resource_handling import Capability, CapabilitySet

# Name of the single resource type used by the simplified helpers below; it is
# the key of both the capability and the scheduler's resource limits.
_RES = "CPU"
# Capability built from one unit of ``_RES``.
# NOTE(review): presumably this means "needs one CPU slot" - confirm against
# ``Capability`` in ``resource_handling``.
_CAP = Capability({_RES: 1})
# Task factory: ``SchedulerTask`` with ``requirements`` pre-bound to ``[_CAP]``.
# NOTE(review): the same list object is shared by every task created through
# this partial - confirm ``SchedulerTask`` never mutates ``requirements``.
_Task = partial(SchedulerTask, requirements=[_CAP])


class BatchProd:
    """Callable that hands out an iterable in mapped, fixed-size batches.

    Each call consumes up to ``batch_size`` further items from the wrapped
    iterable, applies ``mapper`` to every one of them and returns the results
    as a list. Once the iterable is exhausted, calls return an empty list.
    """

    def __init__(self, iterable, batch_size, mapper=_Task) -> None:
        """Store the batch size, the mapper and a single shared iterator."""
        self._size = batch_size
        # One iterator is kept for the object's lifetime so that consecutive
        # calls continue where the previous batch stopped.
        self._it = iter(iterable)
        self._mapper = mapper

    def __call__(self):
        """Return the next batch as a list of mapped items."""
        next_items = islice(self._it, self._size)
        return [self._mapper(item) for item in next_items]


class ActWrap(ActorBase):
    """Actor that delegates every task to a wrapped callable."""

    def __init__(self, fun) -> None:
        """Remember ``fun``, the callable applied to each task argument."""
        self._f = fun

    def consume(self, task_arg):
        """Call the wrapped callable with ``task_arg`` and return its result."""
        outcome = self._f(task_arg)
        return outcome


def get_simp_scheduler(n, Actor, dist_sys, verbose) -> Scheduler:
    """Build a ``Scheduler`` with one actor type and one limited resource.

    Args:
        n: limit set for the ``_RES`` resource.
        Actor: actor factory registered under the capability set ``[_CAP]``.
        dist_sys: passed through as the scheduler's ``distributed_system``.
        verbose: passed through as the scheduler's ``verbose`` flag.

    Returns:
        The newly constructed ``Scheduler``.
    """
    actors_by_capability = {CapabilitySet([_CAP]): Actor}
    limits = {_RES: n}
    return Scheduler(
        actor_dict=actors_by_capability,
        resource_limits=limits,
        distributed_system=dist_sys,
        verbose=verbose,
    )


def parallel_consume(
    Actor: type[ActorBase],
    iterable,
    dist_api=DEFAULT_MULTI_API,
    batch_size=None,
    min_queue_size=None,
    workers=None,
    raise_errors=True,
    verbose=False,
    pbar=False,
    allowed_fail_count=0,
):
    """Lazily feed the items of ``iterable`` to ``Actor`` through a scheduler.

    This is a generator function: the scheduler is only created, and work only
    starts, once the first result is requested.

    Args:
        Actor: actor factory handed to the scheduler.
        iterable: source of task arguments; each item is wrapped into a task.
        dist_api: distributed system passed on to the scheduler.
        batch_size: number of tasks produced per batch; a falsy value means
            five times the worker count.
        min_queue_size: passed to ``scheduler.process``; a falsy value means
            half of the batch size (integer division).
        workers: resource limit for the scheduler; a falsy value means
            ``cpu_count()``.
        raise_errors: when truthy, a result that is an ``Exception`` instance
            is raised instead of being yielded.
        verbose: passed on to the scheduler.
        pbar: falsy for no progress bar; a string is used as the bar's
            description, any other truthy value gives the default description.
        allowed_fail_count: forwarded to every created task.
            NOTE(review): presumably the number of tolerated failures per
            task - confirm against ``SchedulerTask``.

    Yields:
        The results produced by ``scheduler.process``, in the order received.

    Raises:
        Exception: the first result that is an exception instance, if
            ``raise_errors`` is truthy.
    """
    n_workers = workers if workers else cpu_count()
    if not batch_size:
        batch_size = n_workers * 5
    if not min_queue_size:
        min_queue_size = batch_size // 2

    pinger = get_pinger(iterable, pbar)
    scheduler = get_simp_scheduler(n_workers, Actor, dist_api, verbose)

    make_task = partial(_Task, allowed_fail_count=allowed_fail_count)
    producer = BatchProd(iterable, batch_size, make_task)
    results = scheduler.process(
        batch_producer=producer,
        min_queue_size=min_queue_size,
    )
    try:
        for result in results:
            if raise_errors and isinstance(result, Exception):
                raise result
            # Progress advances only for results that are actually yielded.
            pinger()
            yield result
    finally:
        # Runs on exhaustion, on a raised error and when the generator is
        # closed early, so the scheduler is always joined.
        scheduler.join()


def parallel_map(
    fun,
    iterable,
    dist_api=DEFAULT_MULTI_API,
    batch_size=None,
    min_queue_size=None,
    workers=None,
    raise_errors=True,
    verbose=False,
    pbar=False,
    allowed_fail_count=0,
):
    """Apply ``fun`` to the items of ``iterable`` via ``parallel_consume``.

    ``fun`` is wrapped in an ``ActWrap`` actor factory; every other argument
    is forwarded unchanged to ``parallel_consume``.

    Returns:
        The generator returned by ``parallel_consume``.
    """
    actor_factory = partial(ActWrap, fun=fun)
    return parallel_consume(
        actor_factory,
        iterable,
        dist_api=dist_api,
        batch_size=batch_size,
        min_queue_size=min_queue_size,
        workers=workers,
        raise_errors=raise_errors,
        verbose=verbose,
        pbar=pbar,
        allowed_fail_count=allowed_fail_count,
    )


def get_pinger(iterable, pbar):
    """Return a zero-argument callable to invoke once per finished item.

    Args:
        iterable: only used to size the progress bar via ``len``; if taking
            its length fails, the bar is created without a total.
        pbar: falsy to disable progress reporting; a string becomes the bar's
            description, any other truthy value uses ``"parallel"``.

    Returns:
        A do-nothing callable when ``pbar`` is falsy, otherwise the ``update``
        method of a newly created ``tqdm`` bar.
    """
    if pbar:
        # Imported lazily so tqdm is only needed when a bar is requested.
        from tqdm import tqdm

        try:
            total = len(iterable)
        except Exception:
            # Best effort: iterables without a usable length get no total.
            total = None
        description = pbar if isinstance(pbar, str) else "parallel"
        return tqdm(total=total, desc=description).update

    def _noop():
        return None

    return _noop