Cog-Creators/Red-DiscordBot

View on GitHub
redbot/core/_settings_caches.py

Summary

Maintainability
A
0 mins
Test Coverage
from __future__ import annotations

from typing import Dict, List, Optional, Union, Set, Iterable, Tuple, overload
import asyncio
from argparse import Namespace
from collections import defaultdict

import discord

from .config import Config
from .utils import AsyncIter


class PrefixManager:
    """Cache of command prefixes keyed by guild ID (``None`` is the global scope)."""

    def __init__(self, config: Config, cli_flags: Namespace):
        self._config: Config = config
        # Reverse-sorted copy of the command-line prefixes; an empty list
        # collapses to ``None`` so the stored global prefixes are used instead.
        self._global_prefix_override: Optional[List[str]] = (
            sorted(cli_flags.prefix, reverse=True) or None
        )
        self._cached: Dict[Optional[int], List[str]] = {}

    async def get_prefixes(self, guild: Optional[discord.Guild] = None) -> List[str]:
        """Return the prefixes that apply in ``guild``, or globally if it is falsy.

        Cache hits are returned as copies. On a miss, a guild without stored
        prefixes falls back to the global prefixes; at global scope a non-empty
        command-line override takes precedence over the stored value. The
        result is then cached under the guild ID.

        The returned list is always safe for the caller to mutate.
        """
        gid: Optional[int] = guild.id if guild else None

        if gid in self._cached:
            return self._cached[gid].copy()

        ret: List[str]
        if gid is not None:
            ret = await self._config.guild_from_id(gid).prefix()
            if not ret:
                ret = await self.get_prefixes(None)
        elif self._global_prefix_override:
            # Hand out a copy: returning the override list itself would let a
            # caller's mutation silently change the override for later lookups.
            ret = self._global_prefix_override.copy()
        else:
            ret = await self._config.prefix()

        self._cached[gid] = ret.copy()
        return ret

    async def set_prefixes(
        self, guild: Optional[discord.Guild] = None, prefixes: Optional[List[str]] = None
    ):
        """Store ``prefixes`` for ``guild``, or globally when ``guild`` is falsy.

        ``None`` is treated as an empty list. The prefixes are stored
        reverse-sorted. An empty list stored for a guild makes `get_prefixes`
        fall back to the global prefixes for that guild.

        Note that `get_prefixes` keeps returning the command-line override at
        global scope when one was given, regardless of what is stored here.

        Raises
        ------
        TypeError
            If ``prefixes`` is not a list, or contains a non-string item.
        ValueError
            If any prefix starts with ``/``, or if an empty list is given
            for the global scope.
        """
        gid: Optional[int] = guild.id if guild else None
        prefixes = prefixes or []
        # Either condition alone makes the input invalid, hence ``or``.
        if not isinstance(prefixes, list) or not all(isinstance(pfx, str) for pfx in prefixes):
            raise TypeError("Prefixes must be a list of strings")
        if any(prefix.startswith("/") for prefix in prefixes):
            raise ValueError(
                "Prefixes cannot start with '/', as it conflicts with Discord's slash commands."
            )
        prefixes = sorted(prefixes, reverse=True)
        if gid is None:
            if not prefixes:
                raise ValueError("You must have at least one prefix.")
            # Guild entries may hold the global fallback, so drop everything.
            self._cached.clear()
            await self._config.prefix.set(prefixes)
        else:
            self._cached.pop(gid, None)
            await self._config.guild_from_id(gid).prefix.set(prefixes)


class I18nManager:
    """Serve locale and regional-format settings through in-memory caches.

    Each cache maps a guild ID to that guild's stored value, with the key
    ``None`` holding the global value. A guild entry of ``None`` means the
    guild has no value of its own, so the global one is returned for it.
    """

    def __init__(self, config: Config):
        self._config: Config = config
        self._guild_locale: Dict[Union[int, None], Union[str, None]] = {}
        self._guild_regional_format: Dict[Union[int, None], Union[str, None]] = {}

    async def get_locale(self, guild: Union[discord.Guild, None]) -> str:
        """Return the locale for ``guild``, or the global one if ``guild`` is ``None``."""
        locales = self._guild_locale
        # The global locale is the fallback for every path below: load it first.
        if None not in locales:
            locales[None] = await self._config.locale()

        if guild is None:
            return locales[None]

        if guild.id in locales:
            own_locale = locales[guild.id]
        else:
            # First lookup for this guild: remember the answer, even if ``None``.
            own_locale = await self._config.guild(guild).locale()
            locales[guild.id] = own_locale

        return locales[None] if own_locale is None else own_locale

    @overload
    async def set_locale(self, guild: None, locale: str):
        ...

    @overload
    async def set_locale(self, guild: discord.Guild, locale: Union[str, None]):
        ...

    async def set_locale(
        self, guild: Union[discord.Guild, None], locale: Union[str, None]
    ) -> None:
        """Update the cache, then the config, with a new locale.

        Raises
        ------
        ValueError
            If ``guild`` and ``locale`` are both ``None``; the global locale
            cannot be unset.
        """
        if guild is not None:
            self._guild_locale[guild.id] = locale
            await self._config.guild(guild).locale.set(locale)
            return

        if locale is None:
            # this method should never be called like this
            raise ValueError("Global locale can't be None!")
        self._guild_locale[None] = locale
        await self._config.locale.set(locale)

    async def get_regional_format(self, guild: Union[discord.Guild, None]) -> Optional[str]:
        """Return the regional format for ``guild``, or the global one if it is ``None``."""
        formats = self._guild_regional_format
        # The global regional format is the fallback below: load it first.
        if None not in formats:
            formats[None] = await self._config.regional_format()

        if guild is None:
            return formats[None]

        if guild.id in formats:
            own_format = formats[guild.id]
        else:
            # First lookup for this guild: remember the answer, even if ``None``.
            own_format = await self._config.guild(guild).regional_format()
            formats[guild.id] = own_format

        return formats[None] if own_format is None else own_format

    async def set_regional_format(
        self, guild: Union[discord.Guild, None], regional_format: Union[str, None]
    ) -> None:
        """Update the cache, then the config, with a new regional format."""
        if guild is not None:
            self._guild_regional_format[guild.id] = regional_format
            await self._config.guild(guild).regional_format.set(regional_format)
            return

        self._guild_regional_format[None] = regional_format
        await self._config.regional_format.set(regional_format)


class IgnoreManager:
    """Cache of the ``ignored`` flag for channels (including categories) and guilds."""

    def __init__(self, config: Config):
        self._config: Config = config
        # Channel and category IDs share one cache; guild IDs get their own.
        self._cached_channels: Dict[int, bool] = {}
        self._cached_guilds: Dict[int, bool] = {}

    async def get_ignored_channel(
        self,
        channel: Union[
            discord.TextChannel,
            discord.VoiceChannel,
            discord.StageChannel,
            discord.ForumChannel,
            discord.Thread,
        ],
        check_category: bool = True,
    ) -> bool:
        """Return whether ``channel`` is ignored.

        With ``check_category`` enabled, a channel that has a category also
        counts as ignored when that category is ignored. Values missing from
        the cache are read from config and cached.
        """
        channel_id: int = channel.id
        category_id: Optional[int] = None
        if check_category and channel.category:
            category_id = channel.category.id

        if channel_id in self._cached_channels:
            channel_ignored = self._cached_channels[channel_id]
        else:
            channel_ignored = await self._config.channel_from_id(channel_id).ignored()
            self._cached_channels[channel_id] = channel_ignored

        # The category is resolved even when the channel itself is ignored,
        # so its value ends up cached either way.
        if not category_id:
            category_ignored = False
        elif category_id in self._cached_channels:
            category_ignored = self._cached_channels[category_id]
        else:
            category_ignored = await self._config.channel_from_id(category_id).ignored()
            self._cached_channels[category_id] = category_ignored

        return channel_ignored or category_ignored

    async def set_ignored_channel(
        self,
        channel: Union[
            discord.TextChannel,
            discord.VoiceChannel,
            discord.StageChannel,
            discord.Thread,
            discord.ForumChannel,
            discord.CategoryChannel,
        ],
        set_to: bool,
    ):
        """Cache ``set_to`` for ``channel``; store it if truthy, else clear the stored value."""
        channel_id: int = channel.id
        self._cached_channels[channel_id] = set_to
        ignored_value = self._config.channel_from_id(channel_id).ignored
        if set_to:
            await ignored_value.set(set_to)
        else:
            await ignored_value.clear()

    async def get_ignored_guild(self, guild: discord.Guild) -> bool:
        """Return whether ``guild`` is ignored, reading config only on a cache miss."""
        guild_id: int = guild.id

        if guild_id in self._cached_guilds:
            return self._cached_guilds[guild_id]

        guild_ignored: bool = await self._config.guild_from_id(guild_id).ignored()
        self._cached_guilds[guild_id] = guild_ignored
        return guild_ignored

    async def set_ignored_guild(self, guild: discord.Guild, set_to: bool):
        """Cache ``set_to`` for ``guild``; store it if truthy, else clear the stored value."""
        guild_id: int = guild.id
        self._cached_guilds[guild_id] = set_to
        ignored_value = self._config.guild_from_id(guild_id).ignored
        if set_to:
            await ignored_value.set(set_to)
        else:
            await ignored_value.clear()


class WhitelistBlacklistManager:
    """Cache of the global and per-guild whitelist and blacklist ID sets.

    Both caches are keyed by guild ID, with ``None`` for the global lists.
    Every public method runs entirely under ``_access_lock``.
    """

    def __init__(self, config: Config):
        self._config: Config = config
        self._cached_whitelist: Dict[Optional[int], Set[int]] = {}
        self._cached_blacklist: Dict[Optional[int], Set[int]] = {}
        # because of discord deletion
        # we now have sync and async access that may need to happen at the
        # same time.
        # blame discord for this.
        self._access_lock = asyncio.Lock()

    async def discord_deleted_user(self, user_id: int):
        """Remove ``user_id`` from every cached and stored whitelist and blacklist."""
        async with self._access_lock:
            for cache in (self._cached_whitelist, self._cached_blacklist):
                async for guild_id_or_none, ids in AsyncIter(cache.items(), steps=100):
                    ids.discard(user_id)

            for grp in (self._config.whitelist, self._config.blacklist):
                async with grp() as ul:
                    try:
                        ul.remove(user_id)
                    except ValueError:
                        pass

            # don't use this in extensions, it's optimized and controlled for here,
            # but can't be safe in 3rd party use

            async with self._config._get_base_group("GUILD").all() as abuse:
                for guild_str, guild_data in abuse.items():
                    for l_name in ("whitelist", "blacklist"):
                        try:
                            guild_data[l_name].remove(user_id)
                        except (ValueError, KeyError):
                            pass  # this is raw access not filled with defaults

    def _list_value(self, list_name: str, gid: Optional[int]):
        """Return the config value ``list_name`` at global scope or for guild ``gid``."""
        scope = self._config if gid is None else self._config.guild_from_id(gid)
        return getattr(scope, list_name)

    async def _cached_ids(
        self, cache: Dict[Optional[int], Set[int]], list_name: str, gid: Optional[int]
    ) -> Set[int]:
        """Return the live cached set for ``gid``, loading it from config on a miss.

        The caller must already hold ``_access_lock``.
        """
        if gid not in cache:
            cache[gid] = set(await self._list_value(list_name, gid)())
        return cache[gid]

    async def _get(
        self,
        cache: Dict[Optional[int], Set[int]],
        list_name: str,
        guild: Optional[discord.Guild],
    ) -> Set[int]:
        """Return a copy of the ID set of ``list_name`` for ``guild`` (global if falsy)."""
        async with self._access_lock:
            gid: Optional[int] = guild.id if guild else None
            return (await self._cached_ids(cache, list_name, gid)).copy()

    async def _edit(
        self,
        cache: Dict[Optional[int], Set[int]],
        list_name: str,
        guild: Optional[discord.Guild],
        role_or_user: Iterable[int],
        *,
        remove: bool,
    ) -> None:
        """Add IDs to (or remove them from) ``list_name``, in the cache and in config."""
        async with self._access_lock:
            gid: Optional[int] = guild.id if guild else None
            # Materialise once: validating a one-shot iterator (e.g. a generator)
            # in place would exhaust it and leave nothing to apply afterwards.
            ids: List[int] = list(role_or_user or [])
            if not all(isinstance(r_or_u, int) for r_or_u in ids):
                raise TypeError("`role_or_user` must be an iterable of `int`s.")

            cached = await self._cached_ids(cache, list_name, gid)
            if remove:
                cached.difference_update(ids)
            else:
                cached.update(ids)
            await self._list_value(list_name, gid).set(list(cached))

    async def _clear(
        self,
        cache: Dict[Optional[int], Set[int]],
        list_name: str,
        guild: Optional[discord.Guild],
    ) -> None:
        """Empty the cached set of ``list_name`` for ``guild`` and clear it in config."""
        async with self._access_lock:
            gid: Optional[int] = guild.id if guild else None
            cache[gid] = set()
            await self._list_value(list_name, gid).clear()

    async def get_whitelist(self, guild: Optional[discord.Guild] = None) -> Set[int]:
        """Return a copy of the whitelist for ``guild``, or the global one if falsy."""
        return await self._get(self._cached_whitelist, "whitelist", guild)

    async def add_to_whitelist(self, guild: Optional[discord.Guild], role_or_user: Iterable[int]):
        """Add IDs to the whitelist of ``guild`` (global if falsy).

        Raises
        ------
        TypeError
            If ``role_or_user`` contains anything that is not an ``int``.
        """
        await self._edit(self._cached_whitelist, "whitelist", guild, role_or_user, remove=False)

    async def clear_whitelist(self, guild: Optional[discord.Guild] = None):
        """Remove every ID from the whitelist of ``guild`` (global if falsy)."""
        await self._clear(self._cached_whitelist, "whitelist", guild)

    async def remove_from_whitelist(
        self, guild: Optional[discord.Guild], role_or_user: Iterable[int]
    ):
        """Remove IDs from the whitelist of ``guild`` (global if falsy).

        Raises
        ------
        TypeError
            If ``role_or_user`` contains anything that is not an ``int``.
        """
        await self._edit(self._cached_whitelist, "whitelist", guild, role_or_user, remove=True)

    async def get_blacklist(self, guild: Optional[discord.Guild] = None) -> Set[int]:
        """Return a copy of the blacklist for ``guild``, or the global one if falsy."""
        return await self._get(self._cached_blacklist, "blacklist", guild)

    async def add_to_blacklist(self, guild: Optional[discord.Guild], role_or_user: Iterable[int]):
        """Add IDs to the blacklist of ``guild`` (global if falsy).

        Raises
        ------
        TypeError
            If ``role_or_user`` contains anything that is not an ``int``.
        """
        await self._edit(self._cached_blacklist, "blacklist", guild, role_or_user, remove=False)

    async def clear_blacklist(self, guild: Optional[discord.Guild] = None):
        """Remove every ID from the blacklist of ``guild`` (global if falsy)."""
        await self._clear(self._cached_blacklist, "blacklist", guild)

    async def remove_from_blacklist(
        self, guild: Optional[discord.Guild], role_or_user: Iterable[int]
    ):
        """Remove IDs from the blacklist of ``guild`` (global if falsy).

        Raises
        ------
        TypeError
            If ``role_or_user`` contains anything that is not an ``int``.
        """
        await self._edit(self._cached_blacklist, "blacklist", guild, role_or_user, remove=True)


class DisabledCogCache:
    """Cache of per-guild "cog disabled" flags, keyed by cog name and then guild ID."""

    def __init__(self, config: Config):
        self._config = config
        self._disable_map: Dict[str, Dict[int, bool]] = defaultdict(dict)

    async def cog_disabled_in_guild(self, cog_name: str, guild_id: int) -> bool:
        """
        Tell whether a cog is disabled in a guild.

        The guild's own setting wins; when it is unset, the cog's default
        (stored under ID ``0``) is used, and when that is unset too the cog
        counts as enabled. The answer is cached.

        Parameters
        ----------
        cog_name: str
            This should be the cog's qualified name, not necessarily the classname
        guild_id: int

        Returns
        -------
        bool
        """
        known = self._disable_map[cog_name]
        if guild_id in known:
            return known[guild_id]

        disabled = await self._config.custom(
            "COG_DISABLE_SETTINGS", cog_name, guild_id
        ).disabled()
        if disabled is None:
            # No guild-specific setting: fall back to the cog's default entry.
            disabled = await self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0).disabled()
        if disabled is None:
            disabled = False

        # Index the map again rather than reusing ``known``: the cog's entry
        # may have been dropped while the reads above were awaited.
        self._disable_map[cog_name][guild_id] = disabled
        return disabled

    async def default_disable(self, cog_name: str):
        """
        Make "disabled" the default state of a cog.

        All cached answers for the cog are dropped afterwards.

        Parameters
        ----------
        cog_name: str
            This should be the cog's qualified name, not necessarily the classname
        """
        default_entry = self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0)
        await default_entry.disabled.set(True)
        self._disable_map.pop(cog_name, None)

    async def default_enable(self, cog_name: str):
        """
        Make "enabled" the default state of a cog by clearing its stored default.

        All cached answers for the cog are dropped afterwards.

        Parameters
        ----------
        cog_name: str
            This should be the cog's qualified name, not necessarily the classname
        """
        default_entry = self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0)
        await default_entry.disabled.clear()
        self._disable_map.pop(cog_name, None)

    async def disable_cog_in_guild(self, cog_name: str, guild_id: int) -> bool:
        """
        Turn a cog off in one guild.

        Parameters
        ----------
        cog_name: str
            This should be the cog's qualified name, not necessarily the classname
        guild_id: int

        Returns
        -------
        bool
            ``True`` if the setting was changed, ``False`` if the cog was
            already disabled there. This may be useful for settings commands.
        """
        already_disabled = await self.cog_disabled_in_guild(cog_name, guild_id)
        if already_disabled:
            return False

        self._disable_map[cog_name][guild_id] = True
        guild_entry = self._config.custom("COG_DISABLE_SETTINGS", cog_name, guild_id)
        await guild_entry.disabled.set(True)
        return True

    async def enable_cog_in_guild(self, cog_name: str, guild_id: int) -> bool:
        """
        Turn a cog on in one guild.

        Parameters
        ----------
        cog_name: str
            This should be the cog's qualified name, not necessarily the classname
        guild_id: int

        Returns
        -------
        bool
            ``True`` if the setting was changed, ``False`` if the cog was
            already enabled there. This may be useful for settings commands.
        """
        currently_disabled = await self.cog_disabled_in_guild(cog_name, guild_id)
        if not currently_disabled:
            return False

        self._disable_map[cog_name][guild_id] = False
        guild_entry = self._config.custom("COG_DISABLE_SETTINGS", cog_name, guild_id)
        await guild_entry.disabled.set(False)
        return True