LiberTEM/LiberTEM

View on GitHub
src/libertem/common/numba/cache.py

Summary

Maintainability
A
55 mins
Test Coverage
import time
import hashlib
import logging

logger = logging.getLogger(__name__)

import numba  # NOQA: E402
try:
    from numba.core.serialize import dumps
    from numba.core.caching import FunctionCache
    from numba.core.registry import dispatcher_registry, CPUDispatcher
    old_dispatch = True
except ImportError:
    try:
        # numba 0.54 changes the internal structure, so we have to adapt our imports:
        from numba.core.serialize import dumps
        from numba.core.caching import FunctionCache
        from numba.core.registry import CPUDispatcher
        from numba.core.target_extension import (
            dispatcher_registry, jit_registry, target_registry, Generic
        )
        old_dispatch = False
    except ImportError as e:
        CPUDispatcher = None
        dispatcher_registry = None
        logger.warning(
            "could not register custom numba dispatcher, disabling custom cache (%s)" % str(e)
        )
        logger.warning(
            "numba version %s" % str(numba.__version__)
        )


_cached_njit_reg = []


def cached_njit(*args, **kwargs):
    """
    Replacement for numba.njit with custom caching. Only supports usage
    with parameters, i.e.

    @cached_njit()
    def fn():
        ...

    Positional arguments (e.g. explicit signatures) and keyword arguments are
    forwarded to :func:`numba.njit`; ``cache=True`` is always enforced.
    If the custom dispatcher was registered at import time, the function is
    compiled for the ``custom_cpu`` target and recorded in ``_cached_njit_reg``;
    otherwise numba's standard ``cache=True`` caching is used.
    """
    def wrapper(fn):
        # Build a fresh options dict for each decorated function instead of
        # mutating the ``kwargs`` dict shared by every use of this decorator.
        options = dict(kwargs, cache=True)
        # only register with the custom target if we manage to import the right
        # structures from numba:
        if dispatcher_registry is not None:
            options['_target'] = 'custom_cpu'
            _cached_njit_reg.append(fn)
        # ``numba.njit(*args, **options)`` returns a decorator, so positional
        # arguments (signatures) stay in njit's first slot rather than being
        # shifted behind ``fn`` (where they would be taken as ``locals``).
        return numba.njit(*args, **options)(fn)
    return wrapper


def hasher(x):
    """
    Return the SHA-256 digest of the bytes-like ``x`` as a lowercase hex string.
    """
    digest = hashlib.sha256(x)
    return digest.hexdigest()


if dispatcher_registry is not None:
    class MyFunctionCache(FunctionCache):
        """
        ``FunctionCache`` with a custom index key.

        Functions and numba dispatchers found in the closure are represented
        by their bytecode (recursively including their own closures) instead
        of being pickled as-is, so that a regenerated dispatcher uuid does
        not bust the cache.
        """

        def _get_dependencies(self, cvar, _active=None):
            """
            Collect the values the cache key of ``cvar`` depends on.

            - Objects with a ``py_func`` attribute (numba dispatchers) and
              objects with a ``__closure__`` attribute (Python functions)
              contribute the ``co_code`` of their Python function, followed by
              the dependencies of every value in that function's closure.
            - Any other object is returned as-is.

            ``_active`` holds the ids of the functions currently being
            expanded on the recursion path. A function that (directly or
            indirectly) closes over itself, e.g. a recursive nested function,
            contributes its bytecode again but is not expanded a second time;
            previously this recursed until ``RecursionError``. Results for
            closures without such cycles are unchanged.
            """
            if _active is None:
                _active = set()
            if hasattr(cvar, 'py_func'):
                # TODO: does the cache key need to depend on any other
                # attributes of the Dispatcher?
                func = cvar.py_func
            elif hasattr(cvar, '__closure__'):
                # if cvar is a function and closes over a Dispatcher, the
                # cache will be busted because of the uuid that is regenerated
                func = cvar
            else:
                return [cvar]
            deps = [func.__code__.co_code]
            closure = func.__closure__
            if closure is None or id(func) in _active:
                return deps
            _active.add(id(func))
            try:
                for cell in closure:
                    deps.extend(self._get_dependencies(cell.cell_contents, _active))
            finally:
                # only guard the current path, so shared (non-cyclic)
                # dependencies are still expanded every time they appear
                _active.discard(id(func))
            return deps

        def _index_key(self, sig, codegen):
            """
            Compute index key for the given signature and codegen.

            It includes a tag for this custom cache, a description of the OS
            and target architecture (``codegen.magic_tuple()``), a hash of the
            bytecode of the function, and a hash of the pickled dependencies
            returned by ``_get_dependencies`` (bytecode of the function and of
            closed-over functions / dispatchers, plus other closure values).
            """
            codebytes = self._py_func.__code__.co_code
            cvars = self._get_dependencies(self._py_func)
            # _get_dependencies always returns at least the function's own
            # bytecode; the empty case is only a defensive fallback.
            cvarbytes = dumps(cvars) if cvars else b''
            return (
                sig,
                "libertem-numba-cache",
                codegen.magic_tuple(),
                (hasher(codebytes), hasher(cvarbytes)),
            )

        def load_overload(self, *args, **kwargs):
            """
            Load a cached overload via the base class and log whether it was
            a miss or a hit (including how long the load took).

            Returns whatever the base class returns (``None`` on a miss).
            """
            # perf_counter is monotonic and meant for measuring intervals
            t0 = time.perf_counter()
            data = super().load_overload(*args, **kwargs)
            elapsed = time.perf_counter() - t0
            if data is None:
                logger.info("numba cache miss %s %s", self._name, self._py_func)
            else:
                logger.info("cache hit for %s, load took %.3fs", self._name, elapsed)
            return data

    # if we can hack in our custom caching, do it:
    class MyCPUDispatcher(CPUDispatcher):
        """CPU dispatcher that caches through ``MyFunctionCache``."""

        def enable_caching(self):
            self._cache = MyFunctionCache(self.py_func)

    if old_dispatch:
        # older numba: the dispatcher registry is keyed by target name
        dispatcher_registry['custom_cpu'] = MyCPUDispatcher
    else:
        # numba >= 0.54: targets are classes registered in target_extension.
        # NOTE(review): numba's own CPU target class derives from Generic;
        # deriving MyCPU from Generic directly may mean CPU-only overloads are
        # not selected for this target - confirm.
        class MyCPU(Generic):
            pass

        dispatcher_registry[MyCPU] = MyCPUDispatcher
        # presumably used by numba when it has to jit helper functions (e.g.
        # overload implementations) for this target - confirm.
        jit_registry[MyCPU] = cached_njit  # FIXME: is this needed?
        # makes `_target='custom_cpu'` (used by cached_njit) resolve to MyCPU
        target_registry["custom_cpu"] = MyCPU