piotrmaslanka/satella

View on GitHub
satella/instrumentation/metrics/structures/threadpool.py

Summary

Maintainability
A
0 mins
Test Coverage
D
63%
import itertools
import queue
import threading
import weakref
from concurrent.futures import _base
from concurrent.futures import thread
from concurrent.futures.thread import ThreadPoolExecutor, _WorkItem

try:
    from concurrent.futures.thread import BrokenThreadPool
except ImportError:
    BrokenThreadPool = RuntimeError
import typing as tp

from satella.time import measure

from satella.instrumentation.metrics.metric_types import EmptyMetric, MetricLevel, CallableMetric


def _worker(executor_reference, work_queue, initializer, initargs):
    if initializer is not None:
        try:
            initializer(*initargs)
        except BaseException:       # pylint: disable=broad-exception-caught
            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
            executor = executor_reference()
            if executor is not None:
                executor._initializer_failed()          # pylint: disable=protected-access
            return
    try:
        while True:
            work_item = work_queue.get(block=True)
            if work_item is not None:
                # Measure the time spent in waiting
                executor = executor_reference()
                work_item.measure.stop()
                executor.waiting_time_metric.handle(executor.metric_level, work_item.measure())
                del executor

                with measure() as measurement:
                    work_item.run()

                executor = executor_reference()
                executor.executing_time_metric.handle(executor.metric_level, measurement())
                # Delete references to object. See issue16284
                del work_item

                # attempt to increment idle count

                if executor is not None:
                    executor._idle_semaphore.release()      # pylint: disable=protected-access
                del executor
                continue

            executor = executor_reference()
            # Exit if:
            #   - The interpreter is shutting down OR
            #   - The executor that owns the worker has been collected OR
            #   - The executor that owns the worker has been shutdown.
            if thread._shutdown or executor is None or executor._shutdown:  # pylint: disable=protected-access
                # Flag the executor as shutting down as early as possible if it
                # is not gc-ed yet.
                if executor is not None:
                    executor._shutdown = True       # pylint: disable=protected-access
                # Notice other workers
                work_queue.put(None)
                return
            del executor
    except BaseException as e:
        _base.LOGGER.critical('Exception in worker', exc_info=e)


class MetrifiedThreadPoolExecutor(ThreadPoolExecutor):
    """
    A thread pool executor that provides execution statistics as metrics.

    This class will also backport some of Python 3.8's characteristics of the thread pool executor
    to earlier Pythons, thread name prefix, initializer, initargs and BrokenThreadPool behaviour.

    :param time_spent_waiting: a metric (can be aggregate) to which times spent waiting in the
        queue will be deposited
    :param time_spent_executing: a metric (can be aggregate) to which times spent executing will
        be deposited
    :param waiting_tasks: a fresh CallableMetric that will be patched to yield the number of
        currently waiting tasks
    :param metric_level: a level with which to log to these two metrics
    """

    _counter = itertools.count().__next__

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=(),
                 time_spent_waiting=None,
                 time_spent_executing=None,
                 waiting_tasks: tp.Optional[CallableMetric] = None,
                 metric_level: MetricLevel = MetricLevel.RUNTIME):
        super().__init__(max_workers)
        self._initializer = initializer
        self._initargs = initargs
        self._idle_semaphore = threading.Semaphore(0)
        self._broken = False
        if not hasattr(self, '_thread_name_prefix'):
            self._thread_name_prefix = thread_name_prefix or ("ThreadPoolExecutor-%d" %
                                                              self._counter())
        self.waiting_time_metric = time_spent_waiting or EmptyMetric('')
        self.executing_time_metric = time_spent_executing or EmptyMetric('')
        self.metric_level = metric_level
        if waiting_tasks is not None:
            waiting_tasks.callable = lambda: self.get_queue_length()

    def get_queue_length(self) -> int:
        """
        Return the amount of tasks currently in the queue
        """
        return self._work_queue.qsize()

    def submit(*args, **kwargs):
        if len(args) >= 2:
            self, fn, *args = args
        elif not args:
            raise TypeError("descriptor 'submit' of 'ThreadPoolExecutor' object "
                            "needs an argument")
        elif 'fn' in kwargs:
            fn = kwargs.pop('fn')
            self, *args = args
            import warnings
            warnings.warn("Passing 'fn' as keyword argument is deprecated",
                          DeprecationWarning, stacklevel=2)
        else:
            raise TypeError('submit expected at least 1 positional argument, '
                            'got %d' % (len(args) - 1))

        with self._shutdown_lock:
            if self._broken:
                raise BrokenThreadPool(self._broken)

            if self._shutdown:
                raise RuntimeError('cannot schedule new futures after shutdown')
            if thread._shutdown:            # pylint: disable=protected-access
                raise RuntimeError('cannot schedule new futures after '
                                   'interpreter shutdown')

            f = _base.Future()
            w = _WorkItem(f, fn, args, kwargs)
            w.measure = measure()
            self._work_queue.put(w)
            self._adjust_thread_count()
            return f

    def _adjust_thread_count(self):
        # if idle threads are available, don't spin new threads
        if self._idle_semaphore.acquire(timeout=0):
            return

        # When the executor gets lost, the weakref callback will wake up
        # the worker threads.
        def weakref_cb(_, q=self._work_queue):
            q.put(None)

        num_threads = len(self._threads)
        if num_threads < self._max_workers:
            thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                     num_threads)
            t = threading.Thread(name=thread_name, target=_worker,
                                 args=(weakref.ref(self, weakref_cb),
                                       self._work_queue,
                                       self._initializer,
                                       self._initargs))
            t.daemon = True
            t.start()
            self._threads.add(t)
            thread._threads_queues[t] = self._work_queue

    def _initializer_failed(self):
        with self._shutdown_lock:
            self._broken = ('A thread initializer failed, the thread pool '
                            'is not usable anymore')
            # Drain work queue and mark pending futures failed
            while True:
                try:
                    work_item = self._work_queue.get_nowait()
                except queue.Empty:
                    break
                if work_item is not None:
                    work_item.future.set_exception(BrokenThreadPool(self._broken))