# winter_sqlalchemy/repository.py
#
# Summary
#
# Maintainability: B (5 hrs)
# Test Coverage
from typing import Iterable
from typing import List
from typing import Optional
from typing import TypeVar

from injector import inject
from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import UnmappedClassError

from winter.core import get_injector
from winter.data import CRUDRepository
from winter.data.exceptions import NotFoundException
from winter_ddd import AggregateRoot
from winter_ddd import DomainEvent
from winter_ddd import DomainEventDispatcher

T = TypeVar('T')
K = TypeVar('K')


def sqla_crud(repository_cls):
    """
    Build a SQLAlchemy-backed implementation class for a CRUDRepository.

    The entity class is read from ``repository_cls.__entity_cls__``; it must be
    an ``AggregateRoot`` mapped by SQLAlchemy to exactly one table.

    :param repository_cls: repository class inherited from ``CRUDRepository``
    :return: a class combining the default CRUD implementation with every
        class currently returned by ``repository_cls.__subclasses__()``
    :raises TypeError: if the repository or the entity class does not satisfy
        the requirements above
    """
    if not issubclass(repository_cls, CRUDRepository):
        raise TypeError('Repository must be inherited from CRUDRepository before annotating with sqla_crud')

    entity_cls = repository_cls.__entity_cls__

    if not issubclass(entity_cls, AggregateRoot):
        raise TypeError('Entity class must be inherited from winter_ddd.AggregateRoot')

    try:
        mapper = class_mapper(entity_cls)
    except UnmappedClassError as error:  # pragma: no cover
        # Chain the original error so the mapping problem stays visible in the traceback.
        raise TypeError('Invalid SQLAlchemy entity class given') from error

    if len(mapper.tables) > 1:
        raise TypeError('sqla_crud does not support entities mapped to multiple tables')

    entity_table = mapper.tables[0]

    def _as_pk_tuple(id_):
        """Normalize a scalar or tuple identifier to the tuple form used as identity map key."""
        return id_ if isinstance(id_, tuple) else (id_,)

    def _pk_criteria(pk):
        """Build the ``col1 == v1 AND col2 == v2 ...`` filter matching the given primary key tuple."""
        return and_(*(column == value for column, value in zip(entity_table.primary_key.columns, pk)))

    class DefaultCRUDRepositoryImpl:
        """
        SQLAlchemy implementation for CRUDRepository
        This repository implementation is not thread-safe.

        Every tracked entity is kept in an identity map (primary key tuple -> entity)
        and owns a dedicated session (entity -> session).
        """
        class RepositoryException(Exception):
            pass

        def count(self) -> int:
            """Return the number of rows in the entity table."""
            return self.__engine.execute(select([func.count()]).select_from(entity_table)).scalar()

        def delete(self, entity: T):
            """
            Delete an entity tracked by this repository and commit the deletion.

            :raises RepositoryException: if the entity is not tracked by this repository
            """
            try:
                session = self.__sessions[entity]
            except KeyError:  # pragma: no cover
                raise self.RepositoryException('Entity must be fetched with repository before being deleted')
            # Read the identity before deleting: it is the key used in the identity map.
            pk = inspect(entity).identity
            try:
                session.delete(entity)
                session.commit()
            except Exception:
                # Leave the session usable and keep both maps untouched, so the
                # entity stays consistently tracked when the deletion fails.
                session.rollback()
                raise
            session.close()
            self.__identity_map.pop(pk, None)
            del self.__sessions[entity]

        def delete_many(self, entities: Iterable[T]):
            """Delete every given entity, one by one, see ``delete``."""
            for entity in entities:
                self.delete(entity)

        def delete_all(self):  # pragma: no cover
            """Delete all rows of the entity table and forget every tracked entity."""
            self.__engine.execute(entity_table.delete())
            self.__identity_map = {}
            for session in self.__sessions.values():
                session.close()
            self.__sessions = {}

        def delete_by_id(self, id_: K):  # pragma: no cover
            """Delete the entity with the given identifier (scalar or tuple for composite keys)."""
            pk = _as_pk_tuple(id_)
            if pk in self.__identity_map:  # pragma: no cover
                # Tracked entities are deleted through their own session.
                self.delete(self.__identity_map[pk])
            else:
                self.__engine.execute(entity_table.delete().where(_pk_criteria(pk)))

        def exists_by_id(self, id_: K) -> bool:  # pragma: no cover
            """Return True if the entity is tracked by this repository or a matching row exists."""
            pk = _as_pk_tuple(id_)
            if pk in self.__identity_map:
                return True
            return self.__engine.execute(select([exists().where(_pk_criteria(pk))])).scalar()

        def find_all(self) -> Iterable[T]:
            """Return all entities of the table, loading each one through ``find_by_id``."""
            rows = self.__engine.execute(select(entity_table.primary_key.columns))
            # Keep every primary key column of each row, so composite keys are
            # looked up correctly; ``find_by_id`` accepts the tuple form as is.
            ids = [tuple(row) for row in rows]
            return self.find_all_by_id(ids)

        def find_all_by_id(self, ids: Iterable[K]) -> Iterable[T]:
            """Return the entities found for the given identifiers; missing ones are skipped."""
            found = (self.find_by_id(id_) for id_ in ids)
            # Compare with None explicitly: an entity may define a falsy __bool__ / __len__.
            return [entity for entity in found if entity is not None]

        def find_by_id(self, id_: K) -> Optional[T]:  # pragma: no cover
            """
            Return the entity with the given identifier or None if it does not exist.

            An already tracked entity is returned from the identity map; otherwise
            it is loaded with a new session that stays attached to the entity.
            """
            pk = _as_pk_tuple(id_)
            if pk in self.__identity_map:
                return self.__identity_map[pk]

            session = self.__session_factory()
            try:
                instance = session.query(entity_cls).get(pk)
            except Exception:
                # Nothing will own this session, close it before propagating.
                session.close()
                raise
            if instance is None:
                # No entity to attach the session to, close it instead of leaking it.
                session.close()
                return None

            self.__identity_map[pk] = instance
            self.__sessions[instance] = session
            return instance

        def get_by_id(self, id_: K) -> T:  # pragma: no cover
            """
            Return the entity with the given identifier.

            :raises NotFoundException: if there is no such entity
            """
            entity = self.find_by_id(id_)
            if entity is None:
                raise NotFoundException(id_, entity_cls)
            return entity

        def save(self, entity: T) -> T:
            """Flush the entity and dispatch its pending domain events."""
            self._save(entity)
            self._process_domain_events([entity])
            return entity

        def save_many(self, entities: Iterable[T]) -> Iterable[T]:
            """Flush every entity, then dispatch all their pending domain events at once."""
            # Materialize first: the iterable is traversed twice.
            entities = list(entities)
            for entity in entities:
                self._save(entity)
            self._process_domain_events(entities)
            return entities

        def _save(self, entity: T):
            """Start tracking the entity if needed, flush its session and register it in the identity map."""
            if entity not in self.__sessions:
                self.__sessions[entity] = self.__session_factory()
                self.__sessions[entity].add(entity)
            # NOTE(review): only flush() is called here and no commit is visible
            # in this class for saved entities - confirm where the transaction is committed.
            self.__sessions[entity].flush()
            pk = inspect(entity).identity
            self.__identity_map[pk] = entity

        def _process_domain_events(self, aggregates: Iterable[AggregateRoot]):
            """Collect and clear the domain events of the aggregates, then dispatch them in one call."""
            domain_events: List[DomainEvent] = []
            for aggregate in aggregates:
                domain_events.extend(aggregate.domain_events)
                aggregate.clear_domain_events()
            self.__domain_event_dispatcher.dispatch(domain_events)

        @inject
        def __init__(self, engine: Engine, domain_event_dispatcher: DomainEventDispatcher):
            self.__engine = engine
            self.__session_factory = sessionmaker(bind=self.__engine)
            # primary key tuple -> tracked entity
            self.__identity_map = {}
            # tracked entity -> session owning it
            self.__sessions = {}
            self.__domain_event_dispatcher = domain_event_dispatcher

    repository_subclasses = repository_cls.__subclasses__()

    class RepositoryImpl(DefaultCRUDRepositoryImpl, *repository_subclasses):
        @inject
        def __init__(self):
            # Every base is initialized explicitly with its own injected
            # dependencies instead of relying on cooperative super() calls.
            injector = get_injector()
            injector.call_with_injection(DefaultCRUDRepositoryImpl.__init__, self)
            for subclass in repository_subclasses:
                injector.call_with_injection(subclass.__init__, self)

    return RepositoryImpl