# Cog-Creators/Red-DiscordBot
#
# View on GitHub
# redbot/cogs/audio/apis/local_db.py
#
# Summary
#
# Maintainability
# A
# 0 mins
# Test Coverage
import concurrent
import concurrent.futures
import contextlib
import datetime
import random
import time
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union

from red_commons.logging import getLogger

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..sql_statements import (
    LAVALINK_CREATE_INDEX,
    LAVALINK_CREATE_TABLE,
    LAVALINK_DELETE_OLD_ENTRIES,
    LAVALINK_FETCH_ALL_ENTRIES_GLOBAL,
    LAVALINK_QUERY,
    LAVALINK_QUERY_ALL,
    LAVALINK_QUERY_LAST_FETCHED_RANDOM,
    LAVALINK_UPDATE,
    LAVALINK_UPSERT,
    SPOTIFY_CREATE_INDEX,
    SPOTIFY_CREATE_TABLE,
    SPOTIFY_DELETE_OLD_ENTRIES,
    SPOTIFY_QUERY,
    SPOTIFY_QUERY_ALL,
    SPOTIFY_QUERY_LAST_FETCHED_RANDOM,
    SPOTIFY_UPDATE,
    SPOTIFY_UPSERT,
    YOUTUBE_CREATE_INDEX,
    YOUTUBE_CREATE_TABLE,
    YOUTUBE_DELETE_OLD_ENTRIES,
    YOUTUBE_QUERY,
    YOUTUBE_QUERY_ALL,
    YOUTUBE_QUERY_LAST_FETCHED_RANDOM,
    YOUTUBE_UPDATE,
    YOUTUBE_UPSERT,
    PRAGMA_FETCH_user_version,
    PRAGMA_SET_journal_mode,
    PRAGMA_SET_read_uncommitted,
    PRAGMA_SET_temp_store,
    PRAGMA_SET_user_version,
)
from .api_utils import (
    LavalinkCacheFetchForGlobalResult,
    LavalinkCacheFetchResult,
    SpotifyCacheFetchResult,
    YouTubeCacheFetchResult,
)

if TYPE_CHECKING:
    from .. import Audio


# Logger shared by every wrapper in this module.
log = getLogger("red.cogs.Audio.api.LocalDB")
# Red i18n translator for the Audio cog, bound to ``_`` by convention.
_ = Translator("Audio", Path(__file__))
# Schema version that ``BaseWrapper.maybe_migrate`` compares against and stores via the
# ``user_version`` PRAGMA statements.
_SCHEMA_VERSION = 3


class BaseWrapper:
    """Common access layer for one table of the Audio cog's local cache database.

    Subclasses register their table-specific SQL on ``self.statement``
    (``upsert``, ``update``, ``get_one``, ``get_all``, ``get_random``) and set
    ``self.fetch_result`` to the factory each fetched row is unpacked into.

    Database work is submitted to short-lived single-worker thread pools. Leaving
    each pool's ``with`` block waits for the submitted jobs, so these coroutines
    block the calling thread until the database work has finished.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        self.bot = bot
        self.config = config
        self.database = conn
        # SQL used by this wrapper; subclasses add their table-specific statements.
        self.statement = SimpleNamespace()
        self.statement.pragma_temp_store = PRAGMA_SET_temp_store
        self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
        self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
        self.statement.set_user_version = PRAGMA_SET_user_version
        self.statement.get_user_version = PRAGMA_FETCH_user_version
        # Factory building a result object from a row; while ``None`` the fetch
        # helpers return an empty result.
        self.fetch_result: Optional[Callable] = None
        self.cog = cog

    @staticmethod
    def _log_failures(futures: List["concurrent.futures.Future"], message: str) -> None:
        """Log, at verbose level, the exception of every completed future that failed.

        Exceptions raised inside jobs submitted to an executor are never re-raised
        unless someone asks the future for them; this makes them visible.
        """
        for future in futures:
            exc = future.exception()
            if exc is not None:
                log.verbose(message, exc_info=exc)

    async def _max_age_timestamp(self) -> int:
        """Return the POSIX timestamp of the oldest entry that is still fresh.

        The cut-off is "now" minus the configured ``cache_age`` in days.
        ``timestamp()`` is called on the timezone-aware datetime on purpose:
        ``time.mktime(dt.timetuple())`` would read the UTC wall-clock fields as
        *local* time and shift the cut-off by the host's UTC offset, making it
        disagree with the ``last_fetched`` values written by :meth:`update`.
        """
        max_age = await self.config.cache_age()
        cutoff = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
        return int(cutoff.timestamp())

    async def init(self) -> None:
        """Initialize the local cache.

        Applies the connection PRAGMAs, runs :meth:`maybe_migrate`, creates the
        Lavalink, YouTube and Spotify tables and their indexes, then prunes stale
        entries with :meth:`clean_up_old_entries`.
        """
        pragmas = (
            self.statement.pragma_temp_store,
            self.statement.pragma_journal_mode,
            self.statement.pragma_read_uncommitted,
        )
        schema = (
            LAVALINK_CREATE_TABLE,
            LAVALINK_CREATE_INDEX,
            YOUTUBE_CREATE_TABLE,
            YOUTUBE_CREATE_INDEX,
            SPOTIFY_CREATE_TABLE,
            SPOTIFY_CREATE_INDEX,
        )
        # One worker runs the jobs in submission order: PRAGMAs, migration, schema.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            futures = [executor.submit(self.database.cursor().execute, sql) for sql in pragmas]
            futures.append(executor.submit(self.maybe_migrate))
            futures.extend(executor.submit(self.database.cursor().execute, sql) for sql in schema)
        # Leaving the ``with`` block waited for every job above, so the tables exist
        # before the clean-up below tries to delete from them.
        self._log_failures(futures, "Failed to initialize the local cache")
        await self.clean_up_old_entries()

    def close(self) -> None:
        """Close the connection with the local cache, ignoring any error."""
        with contextlib.suppress(Exception):
            self.database.close()

    async def clean_up_old_entries(self) -> None:
        """Delete entries older than ``cache_age`` days from every cache table."""
        values = {"maxage": await self._max_age_timestamp()}
        statements = (
            LAVALINK_DELETE_OLD_ENTRIES,
            YOUTUBE_DELETE_OLD_ENTRIES,
            SPOTIFY_DELETE_OLD_ENTRIES,
        )
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            futures = [
                executor.submit(self.database.cursor().execute, sql, values) for sql in statements
            ]
        self._log_failures(futures, "Failed to delete old entries from the local cache")

    def maybe_migrate(self) -> None:
        """Maybe migrate Database schema for the local cache.

        Reads the database's ``user_version``. If it differs from
        ``_SCHEMA_VERSION``, it is overwritten with ``_SCHEMA_VERSION``; no table
        changes are made here. A failed read is treated as version ``0``.
        Called from :meth:`init` inside that method's executor worker thread.
        """
        current_version = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_user_version)]
            ):
                try:
                    row_result = future.result()
                    current_version = row_result.fetchone()
                    break
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
            # ``fetchone`` yields a 1-tuple (or ``None`` when there is no row).
            if isinstance(current_version, tuple):
                current_version = current_version[0]
            if current_version == _SCHEMA_VERSION:
                return
            version_future = executor.submit(
                self.database.cursor().execute,
                self.statement.set_user_version,
                {"version": _SCHEMA_VERSION},
            )
        self._log_failures([version_future], "Failed to set the local cache schema version")

    async def insert(self, values: List[MutableMapping]) -> None:
        """Upsert ``values`` into this table inside a single transaction.

        Errors are logged at trace level and otherwise ignored (best-effort cache).
        """
        try:
            with self.database.transaction() as transaction:
                transaction.executemany(self.statement.upsert, values)
        except Exception as exc:
            log.trace("Error during table insert", exc_info=exc)

    async def update(self, values: MutableMapping) -> None:
        """Run this table's update statement with ``values``.

        ``values["last_fetched"]`` is set (in place) to the current POSIX time
        before the statement runs. Errors, including those raised in the worker
        thread, are logged at verbose level and otherwise ignored.
        """
        try:
            values["last_fetched"] = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(
                    self.database.cursor().execute, self.statement.update, values
                )
            # Re-raise any worker-thread failure so the handler below logs it.
            future.result()
        except Exception as exc:
            log.verbose("Error during table update", exc_info=exc)

    async def _fetch_one(
        self, values: MutableMapping
    ) -> Optional[
        Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
    ]:
        """Fetch one fresh row from this table.

        ``values["maxage"]`` is set (in place) to the freshness cut-off before
        ``self.statement.get_one`` runs. Returns ``None`` when no row matched,
        the query failed, or no ``fetch_result`` factory is configured.
        """
        values.update({"maxage": await self._max_age_timestamp()})
        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_one, values)]
            ):
                try:
                    row_result = future.result()
                    row = row_result.fetchone()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        if not row:
            return None
        if self.fetch_result is None:
            return None
        return self.fetch_result(*row)

    async def _fetch_all(
        self, values: MutableMapping
    ) -> List[Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]]:
        """Fetch every row ``self.statement.get_all`` returns for ``values``.

        Returns an empty list when the query fails or no ``fetch_result``
        factory is configured.
        """
        output = []
        row_result = []
        if self.fetch_result is None:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_all, values)]
            ):
                try:
                    row_result = future.result()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        async for row in AsyncIter(row_result):
            output.append(self.fetch_result(*row))
        return output

    async def _fetch_random(
        self, values: MutableMapping
    ) -> Optional[
        Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
    ]:
        """Pick one row at random from those ``self.statement.get_random`` returns.

        Returns ``None`` when there are no rows, the query failed, or no
        ``fetch_result`` factory is configured.
        """
        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [
                    executor.submit(
                        self.database.cursor().execute, self.statement.get_random, values
                    )
                ]
            ):
                try:
                    row_result = future.result()
                    rows = row_result.fetchall()
                    if rows:
                        row = random.choice(rows)
                    else:
                        row = None
                except Exception as exc:
                    log.verbose("Failed to completed random fetch from database", exc_info=exc)
        if not row:
            return None
        if self.fetch_result is None:
            return None
        return self.fetch_result(*row)


class YouTubeTableWrapper(BaseWrapper):
    """Typed accessors for the YouTube table of the local cache.

    Entries in this table carry their query as a plain string.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        sql = self.statement
        sql.upsert = YOUTUBE_UPSERT
        sql.update = YOUTUBE_UPDATE
        sql.get_one = YOUTUBE_QUERY
        sql.get_all = YOUTUBE_QUERY_ALL
        sql.get_random = YOUTUBE_QUERY_LAST_FETCHED_RANDOM
        self.fetch_result = YouTubeCacheFetchResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[str], Optional[datetime.datetime]]:
        """Return ``(query, updated_on)`` for a cached YouTube entry, or ``(None, None)``."""
        entry = await self._fetch_one(values)
        if entry and isinstance(entry.query, str):
            return entry.query, entry.updated_on
        return None, None

    async def fetch_all(self, values: MutableMapping) -> List[YouTubeCacheFetchResult]:
        """Return every matching YouTube entry, or an empty list."""
        entries = await self._fetch_all(values)
        if not entries or not isinstance(entries[0], YouTubeCacheFetchResult):
            return []
        return entries

    async def fetch_random(self, values: MutableMapping) -> Optional[str]:
        """Return the query string of a random YouTube entry, or ``None``."""
        entry = await self._fetch_random(values)
        if entry and isinstance(entry.query, str):
            return entry.query
        return None


class SpotifyTableWrapper(BaseWrapper):
    """Typed accessors for the Spotify table of the local cache.

    Entries in this table carry their query as a plain string.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        sql = self.statement
        sql.upsert = SPOTIFY_UPSERT
        sql.update = SPOTIFY_UPDATE
        sql.get_one = SPOTIFY_QUERY
        sql.get_all = SPOTIFY_QUERY_ALL
        sql.get_random = SPOTIFY_QUERY_LAST_FETCHED_RANDOM
        self.fetch_result = SpotifyCacheFetchResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[str], Optional[datetime.datetime]]:
        """Return ``(query, updated_on)`` for a cached Spotify entry, or ``(None, None)``."""
        entry = await self._fetch_one(values)
        if entry and isinstance(entry.query, str):
            return entry.query, entry.updated_on
        return None, None

    async def fetch_all(self, values: MutableMapping) -> List[SpotifyCacheFetchResult]:
        """Return every matching Spotify entry, or an empty list."""
        entries = await self._fetch_all(values)
        if not entries or not isinstance(entries[0], SpotifyCacheFetchResult):
            return []
        return entries

    async def fetch_random(self, values: MutableMapping) -> Optional[str]:
        """Return the query string of a random Spotify entry, or ``None``."""
        entry = await self._fetch_random(values)
        if entry and isinstance(entry.query, str):
            return entry.query
        return None


class LavalinkTableWrapper(BaseWrapper):
    """Typed accessors for the Lavalink table of the local cache.

    Entries in this table carry their query as a mapping rather than a string.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        sql = self.statement
        sql.upsert = LAVALINK_UPSERT
        sql.update = LAVALINK_UPDATE
        sql.get_one = LAVALINK_QUERY
        sql.get_all = LAVALINK_QUERY_ALL
        sql.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
        sql.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
        self.fetch_result = LavalinkCacheFetchResult
        # Row factory used only by ``fetch_all_for_global``.
        self.fetch_for_global: Optional[Callable] = LavalinkCacheFetchForGlobalResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[MutableMapping], Optional[datetime.datetime]]:
        """Return ``(query, updated_on)`` for a cached Lavalink entry, or ``(None, None)``."""
        entry = await self._fetch_one(values)
        if entry and isinstance(entry.query, dict):
            return entry.query, entry.updated_on
        return None, None

    async def fetch_all(self, values: MutableMapping) -> List[LavalinkCacheFetchResult]:
        """Return every matching Lavalink entry, or an empty list."""
        entries = await self._fetch_all(values)
        if not entries or not isinstance(entries[0], LavalinkCacheFetchResult):
            return []
        return entries

    async def fetch_random(self, values: MutableMapping) -> Optional[MutableMapping]:
        """Return the query mapping of a random Lavalink entry, or ``None``."""
        entry = await self._fetch_random(values)
        if entry and isinstance(entry.query, dict):
            return entry.query
        return None

    async def fetch_all_for_global(self) -> List[LavalinkCacheFetchForGlobalResult]:
        """Return every row of the ``get_all_global`` query as global-result objects.

        A failed query is logged at verbose level and yields an empty list.
        """
        if self.fetch_for_global is None:
            return []
        cursor = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            job = pool.submit(self.database.cursor().execute, self.statement.get_all_global)
            for done in concurrent.futures.as_completed([job]):
                try:
                    cursor = done.result()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        return [self.fetch_for_global(*row) async for row in AsyncIter(cursor)]


class LocalCacheWrapper:
    """Wraps all table apis into 1 object representing the local cache.

    The Lavalink, Spotify and YouTube table wrappers are all built from the same
    bot, config, connection and cog.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        self.bot, self.config, self.database, self.cog = bot, config, conn, cog
        shared = (bot, config, conn, self.cog)
        self.lavalink: LavalinkTableWrapper = LavalinkTableWrapper(*shared)
        self.spotify: SpotifyTableWrapper = SpotifyTableWrapper(*shared)
        self.youtube: YouTubeTableWrapper = YouTubeTableWrapper(*shared)