JrGoodle/clowder

View on GitHub
clowder/util/tasks/task_pool.py

Summary

Maintainability
A
0 mins
Test Coverage
"""task pool

.. codeauthor:: Joe DeCapo <joe@polka.cat>

"""

from threading import Lock
from typing import Any, List, Optional

import trio

from clowder.util.console import disable_output
from clowder.util.util import values_sorted_by_key


class Task:

    def __init__(self, name: str):
        self.name: str = name
        self._pool: Optional[TaskPool] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._pool = None
        pass

    @property
    def cancelled(self) -> bool:
        return self._pool.cancelled

    def in_pool(self, pool: 'TaskPool') -> 'Task':
        self._pool = pool
        return self

    def before_task(self) -> None:
        pass

    def after_task(self) -> None:
        pass

    def run(self) -> None:
        raise NotImplementedError


class TaskPool:

    def __init__(self, jobs: Optional[int] = None):
        self._jobs: Optional[int] = jobs
        self._lock: Lock = Lock()
        self._results: Optional[List[Any]] = None
        self.cancelled: bool = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

    def before_tasks(self, tasks: List[Task]) -> None:
        pass

    def before_task(self, task: Task) -> None:
        pass

    def after_task(self, task: Task) -> None:
        pass

    def after_tasks(self, tasks: List[Task]) -> None:
        pass

    @disable_output
    def run(self, tasks: List[Task]) -> List[Any]:
        return trio.run(self._run, tasks)

    async def _run(self, tasks: List[Task]) -> List[Any]:
        try:
            async with trio.open_nursery() as nursery:
                limit = None if self._jobs is None else trio.CapacityLimiter(self._jobs)
                self.before_tasks(tasks)
                self._results = {}
                index = 0
                try:
                    for task in tasks:
                        if limit is not None:
                            await limit.acquire_on_behalf_of(task.name)
                        nursery.start_soon(self._run_task, index, task, limit, nursery)
                        index += 1
                except BaseException:
                    nursery.cancel_scope.cancel()
                    self.cancelled = True
                    raise
            return values_sorted_by_key(self._results)
        except BaseException:
            self.cancelled = True
            raise
        finally:
            self.after_tasks(tasks)

    async def _run_task(self, index: int, task: Task, limit: Optional[trio.CapacityLimiter],
                        nursery: trio.Nursery) -> Any:
        with task.in_pool(self):
            try:
                self.before_task(task)
                task.before_task()
                result = await trio.to_thread.run_sync(task.run)
                with self._lock:
                    self._results[index] = result
            except BaseException:
                self.cancelled = True
                nursery.cancel_scope.cancel()
                raise
            finally:
                task.after_task()
                self.after_task(task)
                if limit is not None:
                    limit.release_on_behalf_of(task.name)