cryptic-game/python-daemon

View on GitHub
daemon/database/database.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
from asyncio import Event
from contextvars import ContextVar
from typing import TypeVar, Optional, Type

# noinspection PyProtectedMember
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select as sa_select, Select
from sqlalchemy.orm import selectinload
from sqlalchemy.sql import Executable
from sqlalchemy.sql.expression import exists as sa_exists, delete as sa_delete, Delete
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.selectable import Exists

from ..environment import (
    DB_DRIVER,
    DB_HOST,
    DB_PORT,
    DB_DATABASE,
    DB_USERNAME,
    DB_PASSWORD,
    SQL_SHOW_STATEMENTS,
    POOL_RECYCLE,
    POOL_SIZE,
    MAX_OVERFLOW,
)
from ..logger import get_logger

T = TypeVar("T")

logger = get_logger(__name__)


def select(entity, *args) -> Select:
    """
    Build a ``SELECT`` for ``entity``, optionally eager-loading relationships.

    Shortcut for :func:`sqlalchemy.future.select`. Every extra positional
    argument becomes a :func:`sqlalchemy.orm.selectinload` loader option:

    * a single attribute ``A`` becomes ``selectinload(A)``
    * a tuple/list ``(A, B, C)`` becomes the chained loader path
      ``selectinload(A).selectinload(B).selectinload(C)``

    :param entity: what to select (e.g. a mapped class)
    :param args: relationship attributes, or tuples/lists of them, to eager-load
    :return: the select statement
    """

    statement = sa_select(entity)
    if not args:
        return statement

    loaders = []
    for relationship in args:
        if not isinstance(relationship, (tuple, list)):
            loaders.append(selectinload(relationship))
            continue

        # walk the relationship path, chaining one selectinload per hop
        first, *rest = relationship
        loader = selectinload(first)
        for attribute in rest:
            loader = loader.selectinload(attribute)
        loaders.append(loader)

    return statement.options(*loaders)


def filter_by(cls, *args, **kwargs) -> Select:
    """
    Build a ``SELECT`` for ``cls`` restricted by keyword equality criteria.

    Shortcut for :meth:`sqlalchemy.future.Select.filter_by` applied to the
    statement produced by :func:`select`.

    :param cls: what to select (e.g. a mapped class)
    :param args: relationships to eager-load, as accepted by :func:`select`
    :param kwargs: ``column_name=value`` equality filters
    :return: the filtered select statement
    """

    statement = select(cls, *args)
    return statement.filter_by(**kwargs)


def exists(*entities, **kwargs) -> Exists:
    """
    Build an ``EXISTS`` clause.

    Shortcut for :func:`sqlalchemy.sql.expression.exists`; all arguments are
    forwarded unchanged.

    :param entities: what the ``EXISTS`` subquery should select
    :param kwargs: forwarded to :func:`sqlalchemy.sql.expression.exists`
    :return: the ``EXISTS`` construct (call ``.select()`` on it to obtain an
        executable ``SELECT EXISTS(...)`` statement)
    """

    return sa_exists(*entities, **kwargs)


def delete(table) -> Delete:
    """
    Build a ``DELETE`` statement targeting ``table``.

    Shortcut for :func:`sqlalchemy.sql.expression.delete`.

    :param table: the table (or mapped class) to delete from
    :return: the delete statement
    """

    statement = sa_delete(table)
    return statement


class DB:
    """
    Async database connection with a per-context session.

    The current session and its close event are stored in
    :class:`contextvars.ContextVar` objects, so code running in different
    asyncio tasks (each with its own context copy) sees its own session.

    Attributes
    ----------
    engine: :class:`sqlalchemy.ext.asyncio.AsyncEngine`
    Base: declarative base class created by
        :func:`sqlalchemy.ext.declarative.declarative_base`; models must
        subclass it for :meth:`create_tables` to pick them up
    """

    def __init__(
        self,
        driver: str,
        host: str,
        port: int,
        database: str,
        username: str,
        password: str,
        pool_recycle: int = 300,
        pool_size: int = 20,
        max_overflow: int = 20,
        echo: bool = False,
    ):
        """
        :param driver: name of the sql connection driver
        :param host: host of the sql server
        :param port: port of the sql server
        :param database: name of the database
        :param username: name of the sql user
        :param password: password of the sql user
        :param pool_recycle: seconds after which a pooled connection is recycled
        :param pool_size: number of connections kept open in the pool
        :param max_overflow: number of extra connections allowed beyond pool_size
        :param echo: whether sql queries should be logged
        """

        self.engine: AsyncEngine = create_async_engine(
            URL.create(
                drivername=driver,
                username=username,
                password=password,
                host=host,
                port=port,
                database=database,
            ),
            # test connections on checkout so stale ones are replaced transparently
            pool_pre_ping=True,
            pool_recycle=pool_recycle,
            pool_size=pool_size,
            max_overflow=max_overflow,
            echo=echo,
        )

        self.Base = declarative_base()

        # per-context state; populated by create_session()
        self._session: ContextVar[Optional[AsyncSession]] = ContextVar("session", default=None)
        self._close_event: ContextVar[Optional[Event]] = ContextVar("close_event", default=None)

    async def create_tables(self):
        """Create all tables registered on ``self.Base.metadata``."""

        logger.debug("creating tables")
        async with self.engine.begin() as conn:
            await conn.run_sync(self.Base.metadata.create_all)

    async def add(self, obj: T) -> T:
        """
        Add a new row to the current session (written on flush/commit).

        :param obj: the row to insert
        :return: the same row
        """

        self.session.add(obj)
        return obj

    async def delete(self, obj: T) -> T:
        """
        Mark a row for deletion in the current session.

        :param obj: the row to remove
        :return: the same row
        """

        await self.session.delete(obj)
        return obj

    async def exec(self, statement: Executable, *args, **kwargs):
        """Execute an sql statement in the current session and return the result."""

        return await self.session.execute(statement, *args, **kwargs)

    async def stream(self, statement: Executable, *args, **kwargs):
        """Execute an sql statement and return an async stream of scalar results."""

        return (await self.session.stream(statement, *args, **kwargs)).scalars()

    async def all(self, statement: Executable, *args, **kwargs) -> list[T]:
        """Execute an sql statement and return all scalar results as a list."""

        return [x async for x in await self.stream(statement, *args, **kwargs)]

    async def first(self, *args, **kwargs):
        """Execute an sql statement and return the first column of the first row, or ``None``."""

        return (await self.exec(*args, **kwargs)).scalar()

    async def exists(self, *args, **kwargs):
        """Return whether ``SELECT EXISTS(...)`` built from the arguments is true."""

        return await self.first(exists(*args, **kwargs).select())

    async def count(self, *args, **kwargs):
        """Return the number of rows in the given FROM clause(s) (``SELECT count(*) FROM ...``)."""

        return await self.first(select(count()).select_from(*args, **kwargs))

    async def get(self, cls: Type[T], *args, **kwargs) -> Optional[T]:
        """Shortcut for first(filter_by(...)): the first matching ``cls`` row, or ``None``."""

        return await self.first(filter_by(cls, *args, **kwargs))

    async def commit(self):
        """Commit the current session; no-op when this context has no session."""

        if session := self._session.get():
            await session.commit()

    async def close(self):
        """
        Close the current session, if any, and signal :meth:`wait_for_close_event`.

        The close event is set even if closing the session raises, so waiters
        are never left blocked forever.
        """

        if not (session := self._session.get()):
            return
        try:
            await session.close()
        finally:
            self._close_event.get().set()

    def create_session(self) -> AsyncSession:
        """Create a new async session (plus its close event) and store both in the context variables."""

        self._session.set(session := AsyncSession(self.engine))
        self._close_event.set(Event())
        return session

    @property
    def session(self) -> Optional[AsyncSession]:
        """Get the session object for the current task (``None`` before create_session())."""

        return self._session.get()

    async def wait_for_close_event(self):
        """
        Wait until :meth:`close` is called for the current context's session.

        create_session() must have been called in this context first; otherwise
        the stored event is ``None`` and this raises ``AttributeError``.
        """

        await self._close_event.get().wait()


def get_database() -> DB:
    """
    Build a :class:`DB` configured from the environment settings.

    Connection parameters, pool tuning and statement echoing are all taken
    from the values exported by the ``environment`` module.

    :return: a new, not yet connected-to DB object
    """

    settings = dict(
        driver=DB_DRIVER,
        host=DB_HOST,
        port=DB_PORT,
        database=DB_DATABASE,
        username=DB_USERNAME,
        password=DB_PASSWORD,
        pool_recycle=POOL_RECYCLE,
        pool_size=POOL_SIZE,
        max_overflow=MAX_OVERFLOW,
        echo=SQL_SHOW_STATEMENTS,
    )
    return DB(**settings)