# Cog-Creators/Red-DiscordBot (View on GitHub)
# redbot/core/_drivers/postgres/postgres.py
#
# Summary: Maintainability A (0 mins); Test Coverage
import getpass
import json
import sys
from pathlib import Path
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List

try:
    # pylint: disable=import-error
    import asyncpg
except ModuleNotFoundError:
    asyncpg = None

from ... import data_manager, errors
from ..base import BaseDriver, IdentifierData, ConfigCategory
from ..log import log

__all__ = ["PostgresDriver"]

_PKG_PATH = Path(__file__).parent
DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"


def encode_identifier_data(
    id_data: IdentifierData,
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
    """Flatten ``id_data`` into the tuple sent as the ``$1`` query argument.

    The tuple fields are, in order:
    cog name, uuid, category, primary key, identifiers, primary key length,
    and the ``is_custom`` flag. The field order presumably mirrors a composite
    type declared in ``ddl.sql`` (NOTE(review): confirm there).

    For the GLOBAL category, the real primary key is ignored. A fixed
    single-element key ``["0"]`` with length ``1`` is sent instead.
    """
    is_global = id_data.category == ConfigCategory.GLOBAL
    if is_global:
        pkey, pkey_len = ["0"], 1
    else:
        pkey, pkey_len = list(id_data.primary_key), id_data.primary_key_len

    return (
        id_data.cog_name,
        id_data.uuid,
        id_data.category,
        pkey,
        list(id_data.identifiers),
        pkey_len,
        id_data.is_custom,
    )


class PostgresDriver(BaseDriver):
    """Config driver that stores Red's data in PostgreSQL through ``asyncpg``.

    All instances share one class-level connection pool, which is created by
    :meth:`initialize` and closed by :meth:`teardown`. Data access is
    delegated to SQL functions in the ``red_config`` schema. These functions
    are expected to be installed by the ``ddl.sql`` script that
    :meth:`initialize` executes.
    """

    # Shared connection pool; ``None`` until initialize() runs (and again after teardown()).
    _pool: Optional["asyncpg.pool.Pool"] = None

    @classmethod
    async def initialize(cls, **storage_details) -> None:
        """Create the shared connection pool and run the DDL script.

        Parameters
        ----------
        **storage_details
            Keyword arguments forwarded unchanged to ``asyncpg.create_pool``
            (presumably the mapping produced by :meth:`get_config_details`).

        Raises
        ------
        errors.MissingExtraRequirements
            If ``asyncpg`` is not installed.
        """
        if asyncpg is None:
            raise errors.MissingExtraRequirements(
                "Red must be installed with the [postgres] extra to use the PostgreSQL driver"
            )
        cls._pool = await asyncpg.create_pool(**storage_details)
        # Decode explicitly as UTF-8 so reading the script does not depend on
        # the platform's locale encoding (e.g. cp1252 on Windows).
        ddl_script = DDL_SCRIPT_PATH.read_text(encoding="utf-8")
        await cls._pool.execute(ddl_script)

    @classmethod
    async def teardown(cls) -> None:
        """Close the shared connection pool, if one was created."""
        if cls._pool is not None:
            await cls._pool.close()
            # Forget the closed pool so that ``_pool is None`` once again
            # means "not initialised" and a repeated teardown is a no-op.
            cls._pool = None

    @staticmethod
    def get_config_details():
        """Interactively prompt for PostgreSQL connection settings.

        A blank answer (or ``NONE`` for the password) stores ``None`` for that
        setting. The fallback is then left to the defaults listed in each prompt.

        Returns
        -------
        dict
            Mapping with the keys ``host``, ``port`` (``int`` or ``None``),
            ``user``, ``password`` and ``database``.
        """
        unixmsg = (
            ""
            if sys.platform == "win32"
            else (
                " - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
                "/var/run/postgresql, /var/pgsql_socket, /private/tmp, and /tmp),\n"
            )
        )
        host = (
            input(
                "Enter the PostgreSQL server's address.\n"
                "If left blank, Red will try the following, in order:\n"
                f" - The PGHOST environment variable,\n{unixmsg}"
                " - localhost.\n"
                "> "
            )
            or None
        )

        print(
            "Enter the PostgreSQL server port.\n"
            "If left blank, this will default to either:\n"
            " - The PGPORT environment variable,\n"
            " - 5432."
        )
        # Re-prompt until the answer is blank or a valid TCP port number.
        while True:
            port = input("> ") or None
            if port is None:
                break

            try:
                port = int(port)
            except ValueError:
                print("Port must be a number")
                continue

            if 1 <= port <= 65535:
                break
            print("Port must be between 1 and 65535")

        user = (
            input(
                "Enter the PostgreSQL server username.\n"
                "If left blank, this will default to either:\n"
                " - The PGUSER environment variable,\n"
                " - The OS name of the user running Red (ident/peer authentication).\n"
                "> "
            )
            or None
        )

        passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform == "win32" else "~/.pgpass"
        password = getpass.getpass(
            "Enter the PostgreSQL server password. The input will be hidden.\n"
            "  NOTE: If using ident/peer authentication (no password), enter NONE.\n"
            "When NONE is entered, this will default to:\n"
            " - The PGPASSWORD environment variable,\n"
            f" - Looking up the password in the {passfile} passfile,\n"
            " - No password.\n"
            "> "
        )
        # A hidden prompt cannot distinguish "blank" from "empty password",
        # so the literal word NONE is used to request "no password".
        if password == "NONE":
            password = None

        database = (
            input(
                "Enter the PostgreSQL database's name.\n"
                "If left blank, this will default to either:\n"
                " - The PGDATABASE environment variable,\n"
                " - The OS name of the user running Red.\n"
                "> "
            )
            or None
        )

        return {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
        }

    async def get(self, identifier_data: IdentifierData):
        """Return the JSON-decoded value stored at ``identifier_data``.

        Raises
        ------
        KeyError
            If nothing is stored at that location.
        """
        result = await self._execute(
            "SELECT red_config.get($1)",
            encode_identifier_data(identifier_data),
            method=self._pool.fetchval,
        )

        if result is None:
            # The result is None both when postgres yields no results, or when it yields a NULL row
            # A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
            raise KeyError
        return json.loads(result)

    async def set(self, identifier_data: IdentifierData, value=None):
        """JSON-encode ``value`` and store it at ``identifier_data``.

        Raises
        ------
        errors.CannotSetSubfield
            If the database reports ``ErrorInAssignmentError``. This is
            presumably raised when writing beneath a value that is not a JSON
            object; confirm against ``ddl.sql``.
        """
        try:
            await self._execute(
                "SELECT red_config.set($1, $2::jsonb)",
                encode_identifier_data(identifier_data),
                json.dumps(value),
            )
        except asyncpg.ErrorInAssignmentError as exc:
            raise errors.CannotSetSubfield from exc

    async def clear(self, identifier_data: IdentifierData):
        """Remove whatever is stored at ``identifier_data``."""
        await self._execute("SELECT red_config.clear($1)", encode_identifier_data(identifier_data))

    async def inc(
        self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
    ) -> Union[int, float]:
        """Add ``value`` to the number at ``identifier_data`` in the database.

        ``default`` is passed along as the starting value for the SQL
        function. The return value is whatever ``red_config.inc`` yields,
        presumably the updated number.

        Raises
        ------
        errors.StoredTypeError
            If the database reports ``WrongObjectTypeError``.
        """
        try:
            return await self._execute(
                "SELECT red_config.inc($1, $2, $3)",
                encode_identifier_data(identifier_data),
                value,
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args) from exc

    async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
        """Flip the boolean at ``identifier_data`` in the database.

        ``default`` is passed along as the starting value for the SQL
        function. The return value is whatever ``red_config.toggle`` yields,
        presumably the new boolean.

        Raises
        ------
        errors.StoredTypeError
            If the database reports ``WrongObjectTypeError``.
        """
        # red_config.inc takes (identifier_data, amount, default); see inc().
        # It cannot be called with just two arguments to flip a boolean.
        # NOTE(review): assumes ddl.sql defines red_config.toggle(identifier_data, boolean).
        try:
            return await self._execute(
                "SELECT red_config.toggle($1, $2)",
                encode_identifier_data(identifier_data),
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args) from exc

    @classmethod
    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        """Yield ``(cog_name, cog_id)`` for each row of ``red_config.red_cogs``."""
        query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
        log.invisible(query)
        # asyncpg cursors can only be iterated inside a transaction.
        async with cls._pool.acquire() as conn, conn.transaction():
            async for row in conn.cursor(query):
                yield row["cog_name"], row["cog_id"]

    @classmethod
    async def delete_all_data(cls, *, drop_db: Optional[bool] = None, **kwargs) -> None:
        """Delete all data being stored by this driver.

        Schemas within the database which
        store bot data will be dropped, as well as functions,
        aggregates, event triggers, and meta-tables.

        Parameters
        ----------
        drop_db : Optional[bool]
            If set to ``True``, function will print information
            about not being able to drop the entire database.

        """
        if drop_db is True:
            print(
                "Dropping the entire database is not possible in PostgreSQL driver."
                " We will delete all of Red's data within this database,"
                " without dropping the database itself."
            )
        # Decode explicitly as UTF-8, independent of the platform locale.
        drop_script = DROP_DDL_SCRIPT_PATH.read_text(encoding="utf-8")
        await cls._pool.execute(drop_script)

    @classmethod
    async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
        """Log ``query`` (and its arguments) and run it through the pool.

        Parameters
        ----------
        query : str
            SQL text with ``$n`` placeholders.
        *args
            Values bound to the placeholders.
        method : Optional[Callable]
            Pool coroutine used to run the query, e.g. ``fetchval``. If this
            is not given, ``Pool.execute`` is used.

        Returns
        -------
        Any
            Whatever ``method`` returns.
        """
        if method is None:
            method = cls._pool.execute
        log.invisible("Query: %s", query)
        if args:
            log.invisible("Args: %s", args)
        return await method(query, *args)