serum/_context.py

Summary

Maintainability
A
3 hrs
Test Coverage
from copy import copy, deepcopy
from unittest.mock import create_autospec, MagicMock
from functools import wraps
from typing import Type, Set, Union, Dict, TypeVar

from ._dependency_configuration import DependencyConfiguration
from ._key import Key
from .exceptions import NoNamedDependency, InjectionError
from .exceptions import (
    AmbiguousDependencies,
    CircularDependency)
import threading
import inspect
from collections import Counter


# Generic type variable used in this module's annotations to tie a requested
# dependency type to the instance produced for it (e.g. Type[T] -> T).
T = TypeVar('T')


class _LocalStorage(threading.local):
    """
    Thread-local holder for the context that is current in a thread.

    Because this subclasses threading.local, __init__ is run separately for
    every thread that touches an instance, so each thread starts out with
    current_env set to None and assignments made in one thread are not seen
    by another.
    """

    def __init__(self):
        # None means "no context has been made current in this thread yet".
        self.current_env: 'Context' = None


class _ContextState(threading.local):
    """
    Mutable bookkeeping attached to a Context.

    Subclasses threading.local, so every thread that touches an instance
    gets its own, freshly initialised set of attributes.
    """

    def __init__(self):
        # Dependency types whose instantiation is currently in progress;
        # Context.provide consults this to detect circular dependencies.
        self.pending: Set[Type[object]] = set()
        # NOTE(review): not read anywhere else in this module as far as
        # visible here; Context keeps its own reference to the previous
        # context.
        self.old_current: Context = None
        # Mocks keyed by dependency name (str) or dependency type.
        self.mocks: Dict[Union[str, Type[object]], MagicMock] = dict()
        # Cached singleton instances keyed by their type.
        self.singletons: Dict[Type[T], T] = dict()

    def __deepcopy__(self, memodict):
        """
        Build an independent state whose containers are shallow copies

        The containers are copied, but the mocks and singleton instances
        inside them are shared with the original state.
        :param memodict: deepcopy memo dictionary (unused)
        :return: New _ContextState with copied pending/mocks/singletons
        """
        new = _ContextState()
        new.pending = copy(self.pending)
        new.old_current = copy(self.old_current)
        new.mocks = copy(self.mocks)
        # Copy the singleton cache itself; copying self.mocks here would
        # drop all singletons and put mocks in their place.
        new.singletons = copy(self.singletons)
        return new


class Context:
    """
    Context manager/decorator for providing dependencies:

    with Context(MyDependency):
        NeedsDependency()

    or

    @Context(MyDependency)
    def fun():
        NeedsDependency()

    The context that is current in a thread is kept in thread-local storage
    and is the one consulted by the static helpers (current_context, mock,
    provide and find_subtype).
    """
    __local_storage = _LocalStorage()

    @staticmethod
    def current_context() -> 'Context':
        """
        Get the context that is current in this thread

        If no context has been made current yet, an empty one is created
        and installed as the current context.
        :return: The current context
        """
        env = Context.__local_storage.current_env
        if env is None:
            env = Context()
            Context._set_current_env(env)
        return env

    @staticmethod
    def mock(dependency: Union[str, Type[object]]):
        """
        Replace a dependency with an autospec mock in the current context
        :param dependency: Name of a named dependency, or a dependency type
        :return: The mock that is now registered for the dependency
        """
        current_env = Context.current_context()
        if isinstance(dependency, str):
            # Named dependencies are values, so spec the mock on the value
            # currently registered under that name.
            value = current_env[dependency]
            replacement = create_autospec(value)
        else:
            # Types are mocked as instances, since instances are injected.
            replacement = create_autospec(dependency, instance=True)
        current_env.__state.mocks[dependency] = replacement
        return replacement

    @staticmethod
    def _set_current_env(env: 'Context'):
        """
        Make a context the current one in this thread
        :param env: Context to install as current
        """
        Context.__local_storage.current_env = env

    def __init__(self, *args: Type[object], **kwargs: object) -> None:
        """
        Construct a new context
        :param args: Dependency decorated types to provide in this context
        :param kwargs: Named dependencies to provide in this context
        """
        self.__registry: Set[Type[object]] = set()
        self.__state: _ContextState = _ContextState()
        self.__named_dependencies = kwargs
        # Context that was current before this one was entered; only set
        # between __enter__ and __exit__.
        self.__old_current = None
        for c in args:
            self.__use(c)

    def __getitem__(self, item: str):
        """
        Get named dependency in this environment

        A mock registered under the name takes precedence over the real
        value.
        :param item: Name of dependency
        :return: dependency
        :raises NoNamedDependency: If no dependency is registered as item
        """
        if self.is_mocked(item):
            return self.__state.mocks[item]
        try:
            return self.__named_dependencies[item]
        except KeyError:
            raise NoNamedDependency(
                f'Named dependency "{item}" not found in: {repr(self)}'
            )

    @property
    def pending(self) -> Set[Type[object]]:
        """
        Dependency types whose instantiation is currently in progress
        :return: The (mutable) set held in this context's state
        """
        return self.__state.pending

    def get_mock(self, component: Type[object]) -> MagicMock:
        """
        Get the mock registered for a dependency
        :param component: Dependency type (or name) that has been mocked
        :return: The registered mock
        :raises KeyError: If component has not been mocked
        """
        return self.__state.mocks[component]

    def is_mocked(self, component: Union[str, Type[object]]) -> bool:
        """
        Test if a dependency has been mocked in this context
        :param component: Dependency name or type
        :return: True if a mock is registered for component else False
        """
        return component in self.__state.mocks

    def __use(self, component: Type[object]) -> 'Context':
        """
        Register a dependency type in this context
        :param component: Type to register
        :return: This context
        """
        self.__registry.add(component)
        return self

    def __contains__(self, component: Type[object]) -> bool:
        """
        Test if a Dependency is registered in this context
        :param component: Dependency to test; a str is looked up among the
                          named dependencies, anything else among the
                          registered types
        :return: True if component is registered in this environment else False
        """
        if isinstance(component, str):
            return component in self.__named_dependencies
        return component in self.__registry

    def __call__(self, f):
        """
        Decorate a function to run in this context
        :param f: function to decorate
        :return: decorated function, which returns whatever f returns
        """
        @wraps(f)
        def run_in(*args, **kwargs):
            with self:
                # Hand the result back to the caller; without the return
                # every decorated function would evaluate to None.
                return f(*args, **kwargs)
        return run_in

    def __iter__(self):
        """
        Iterate over the dependencies registered in this context
        :return: Iterator of components
        """
        return iter(self.__registry)

    def __or__(self, other: 'Context') -> 'Context':
        """
        Combine this environment with another, such that the new context
        can provide all dependencies in both contexts
        :param other: context to combine with this one
        :return: New context with components and named dependencies from
                 this and the other environment. If both contexts define
                 the same name, the value from other is used.
        """
        new_registry = self.__registry | other.__registry
        # Named dependencies must be carried over as well, otherwise the
        # combined context could not provide them.
        new_named = {
            **self.__named_dependencies,
            **other.__named_dependencies,
        }
        return Context(*new_registry, **new_named)

    def __enter__(self):
        """
        Register this context as the current environment in this thread

        The previously current context is remembered so __exit__ can
        restore it, and a copy of its state is adopted so that e.g. mocks
        registered in the enclosing context remain visible in this one.
        :return: This context
        """
        self.__old_current = Context.current_context()
        old_state = self.__old_current.__copy_state()
        self.__set_state(old_state)
        Context._set_current_env(self)
        return self

    def __set_state(self, state):
        """
        Replace the state object of this context
        :param state: _ContextState to use from now on
        """
        self.__state = state

    def __copy_state(self):
        """
        Copy the state of this context
        :return: deepcopy of this context's _ContextState
        """
        return deepcopy(self.__state)

    def has_singleton_instance(self, singleton_type):
        """
        Test if a singleton instance has been cached in this context
        :param singleton_type: Singleton type to look for
        :return: True if an instance of singleton_type is cached else False
        """
        return singleton_type in self.__state.singletons

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        De-register this context as the current environment in this thread

        Returns None, so an exception raised inside the with-block is not
        suppressed.
        :param exc_type: Type of the exception raised in the block, if any
        :param exc_val: The exception raised in the block, if any
        :param exc_tb: Traceback of the exception, if any
        :return:
        """
        # Re-running __init__ resets pending, mocks and singletons, so
        # nothing accumulated while this context was active leaks into a
        # later use of the same Context object.
        self.__state.__init__()
        Context._set_current_env(self.__old_current)
        self.__old_current = None

    @staticmethod
    def provide(configuration: DependencyConfiguration) -> Union[T, MagicMock]:
        """
        Provide a dependency in this context
        :param configuration: The type to provide
        :return: Instance of the most specific subtype of component
                 in this environment, or its mock if it has been mocked
        :raises CircularDependency: If the type is already being
                                    instantiated further up the call chain
        :raises InjectionError: If instantiating the type raises
        :raises AmbiguousDependencies: If several registered subtypes are
                                       equally specific
        """
        context = Context.current_context()

        def singleton(singleton_type: Type[T]) -> T:
            # Reuse the cached instance if there is one, otherwise create
            # and cache it.
            if context.has_singleton_instance(singleton_type):
                return context.get_singleton(singleton_type)
            singleton_instance = singleton_type()
            context.add_singleton(singleton_type, singleton_instance)
            return singleton_instance

        def instance(component_type: Type[T]) -> T:
            return component_type()

        def is_singleton(st):
            return hasattr(st, '__singleton__')

        def instantiate(dependency_type: Type[T]) -> T:
            # A type that is already pending is being constructed further
            # up the stack, i.e. it (transitively) depends on itself.
            if dependency_type in context.pending:
                raise CircularDependency(
                    f'Circular dependency encountered while injecting '
                    f'{dependency_type} as "{configuration.name}" in '
                    f'{configuration.owner}'
                )
            context.pending.add(dependency_type)
            try:
                if is_singleton(dependency_type):
                    return singleton(dependency_type)
                return instance(dependency_type)
            except CircularDependency:
                # Let the original circular dependency error through
                # instead of wrapping it in an InjectionError.
                raise
            except Exception as e:
                raise InjectionError(
                    f'Could not instantiate {dependency_type} when injecting '
                    f'"{configuration.name}" in {configuration.owner}.'
                ) from e
            finally:
                context.pending.remove(dependency_type)

        dependency: Type[T] = configuration.dependency
        if context.is_mocked(dependency):
            return context.get_mock(dependency)
        subtype = Context.find_subtype(dependency)
        if subtype is None:
            # Nothing registered for this type: fall back to the type itself
            return instantiate(dependency)
        return instantiate(subtype)

    @staticmethod
    def find_subtype(component: Type[T]) -> Union[Type[T], None]:
        """
        Find the most specific registered subtype of a type

        Specificity is the position of component in the subtype's method
        resolution order; a larger position means a more specific subtype.
        :param component: Type to find a subtype of
        :return: The registered subtype with the largest MRO distance to
                 component, or None if no registered type is a subclass
        :raises AmbiguousDependencies: If two or more registered subtypes
                                       have the same MRO distance
        """
        def mro_distance(subtype: Type[T]) -> int:
            mro = inspect.getmro(subtype)
            return mro.index(component)

        subtypes = [c for c in Context.current_context()
                    if issubclass(c, component)]
        if not subtypes:
            return None
        # Compute each distance once and reuse it for the ambiguity check
        distances = {subtype: mro_distance(subtype) for subtype in subtypes}
        counter = Counter(distances.values())
        ambiguous = [str(subtype) for subtype in subtypes
                     if counter[distances[subtype]] > 1]
        if ambiguous:
            message = ('Attempt to inject type {} with '
                       'equally specific provided subtypes: {}')
            message = message.format(
                str(component),
                ', '.join(ambiguous)
            )
            raise AmbiguousDependencies(message)
        return max(subtypes, key=distances.__getitem__)

    def get_singleton(self, singleton_type):
        """
        Get the cached instance of a singleton type
        :param singleton_type: Singleton type to look up
        :return: The cached instance
        :raises KeyError: If no instance of singleton_type is cached
        """
        return self.__state.singletons[singleton_type]

    def add_singleton(self, singleton_type, instance):
        """
        Cache an instance of a singleton type in this context
        :param singleton_type: Type to cache the instance under
        :param instance: Instance to cache
        """
        self.__state.singletons[singleton_type] = instance

    def __repr__(self):
        """
        :return: Constructor-like representation listing registered types
                 followed by named dependencies
        """
        dependencies = [repr(dependency) for dependency in self.__registry]
        named_dependencies = [f'{key}={repr(value)}' for key, value
                              in self.__named_dependencies.items()]
        args = ', '.join(dependencies + named_dependencies)
        return f'Context({args})'


# Names exported by a star-import of this module: only Context is listed;
# the module-level provide() and current_context() helpers defined below
# are not.
__all__ = ['Context']


def provide(configuration: DependencyConfiguration):
    """
    Resolve a dependency configuration against the current context
    :param configuration: Configuration describing the dependency to inject
    :return: The named dependency registered under configuration.name when
             configuration.dependency is a Key, otherwise whatever
             Context.provide returns for the configuration
    """
    requested = configuration.dependency
    if not isinstance(requested, Key):
        # Type-based dependency: let the context instantiate (or mock) it.
        return Context.provide(configuration)
    # Key-based dependency: look it up by name in the current context.
    active = Context.current_context()
    return active[configuration.name]


def current_context():
    """
    Module-level shortcut for Context.current_context
    :return: The context that is current in the calling thread
    """
    active = Context.current_context()
    return active