NabDev/NabBot

View on GitHub
cogs/utils/context.py

Summary

Maintainability
A
0 mins
Test Coverage
#  Copyright 2019 Allan Galarza
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import asyncio
import functools
import re
from typing import Any, Callable, Optional, Sequence, TypeVar, Union

import aiohttp
import asyncpg
import discord
from discord.ext import commands

import nabbot
from . import config, safe_delete_message
from .database import get_server_property

_mention = re.compile(r'<@!?([0-9]{1,19})>')

T = TypeVar('T')


class NabCtx(commands.Context):
    """An override of :class:`commands.Context` that provides properties and methods for NabBot."""
    bot: "nabbot.NabBot"
    guild: discord.Guild
    message: discord.Message
    channel: discord.TextChannel
    author: Union[discord.User, discord.Member]
    me: Union[discord.Member, discord.ClientUser]
    command: commands.Command

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Shortcuts to the bot's database pool and HTTP session.
        self.pool: asyncpg.pool.Pool = self.bot.pool
        self.session: aiohttp.ClientSession = self.bot.session
        # Regional indicator letters Y and N, used by react_confirm when use_checkmark is False.
        self.yes_no_reactions = ("🇾", "🇳")
        # (positive, negative) pair: tick() uses index 0 for True and index 1 for False.
        self.check_reactions = (config.true_emoji, config.false_emoji)

    # region Properties
    @property
    def author_permissions(self) -> discord.Permissions:
        """Shortcut to check the command author's permission to the current channel.

        :return: The permissions for the author in the current channel.
        """
        return self.channel.permissions_for(self.author)

    @property
    def bot_permissions(self) -> discord.Permissions:
        """Shortcut to check the bot's permission to the current channel.

        :return: The permissions for the bot in the current channel."""
        return self.channel.permissions_for(self.me)

    @property
    def clean_prefix(self) -> str:
        """Gets the clean prefix used in the command invocation.

        If the prefix starts with a user mention and that user can be resolved,
        it is rendered as ``@name `` so it reads as plain text; otherwise the raw prefix is returned."""
        m = _mention.match(self.prefix)
        if m:
            user = self.bot.get_user(int(m.group(1)))
            if user:
                return f'@{user.name} '
        return self.prefix

    @property
    def is_lite(self) -> bool:
        """Checks if the current context is limited to lite mode.

        In a guild, the context is lite if the guild's id is in ``config.lite_servers``.
        In private messages, the context is lite only if every guild the author shares with the bot
        is in ``config.lite_servers``. If the author is in at least ONE guild that is not lite,
        or shares no guilds with the bot, the context is not lite."""
        if self.guild is not None:
            return self.guild.id in config.lite_servers
        # Private message: lite only when there is at least one shared guild and none of them is non-lite.
        shares_any_guild = False
        for g in self.bot.get_user_guilds(self.author.id):
            if g.id not in config.lite_servers:
                return False
            shares_any_guild = True
        return shares_any_guild

    @property
    def is_private(self) -> bool:
        """Whether the current context is a private channel or not."""
        return self.guild is None

    @property
    def usage(self) -> str:
        """Shows the parameters signature of the invoked command.

        Uses the command's explicit ``usage`` if set; otherwise builds it from the parameters:
        ``<name>`` for required, ``[name]`` / ``[name=default]`` for optional, ``[name...]`` for *args."""
        if self.command.usage:
            return self.command.usage

        params = self.command.clean_params
        if not params:
            return ''
        result = []
        for name, param in params.items():
            if param.default is not param.empty:
                # We don't want None or '' to trigger the [name=value] case and instead it should
                # do [name] since [name=None] or [name=] are not exactly useful for the user.
                should_print = param.default if isinstance(param.default, str) else param.default is not None
                if should_print:
                    result.append(f'[{name}={param.default!r}]')
                else:
                    result.append(f'[{name}]')
            elif param.kind == param.VAR_POSITIONAL:
                result.append(f'[{name}...]')
            else:
                result.append(f'<{name}>')

        return ' '.join(result)

    @property
    def world(self) -> Optional[str]:
        """Check the world that is currently being tracked by the guild

        :return: The world that the server is tracking, or None in private messages or if none is tracked.
        :rtype: str | None
        """
        if self.guild is None:
            return None
        return self.bot.tracked_worlds.get(self.guild.id, None)

    async def _get_ask_channel(self) -> Optional["discord.abc.GuildChannel"]:
        """Looks up the channel stored as the guild's ``ask_channel`` server property.

        :return: The channel, or None outside a guild or if the property is unset or the channel is not found.
        """
        if self.guild is None:
            return None
        ask_channel_id = await get_server_property(self.pool, self.guild.id, "ask_channel")
        return self.guild.get_channel(ask_channel_id)

    async def ask_channel_name(self) -> Optional[str]:
        """Gets the name of the ask channel for the current server.

        Falls back to ``config.ask_channel_name`` when no configured channel is found.

        :return: The name of the ask channel if applicable, None in private messages.
        :rtype: str or None"""
        if self.guild is None:
            return None
        ask_channel = await self._get_ask_channel()
        if ask_channel is None:
            return config.ask_channel_name
        return ask_channel.name
    # endregion

    async def choose(self, matches: Sequence[Any], title="Suggestions", not_found=True):
        """Shows a numbered list of options and awaits for the user's answer.

        :param matches: The options to choose from.
        :param title: The title of the embed listing the options.
        :param not_found: Whether the prompt says nothing exact was found, instead of a plain suggestion text.
        :return: The chosen item, the only item if there is just one, or None if the author cancelled (0),
                 timed out, or answered with a number outside the choices.
        :raises ValueError: If ``matches`` is empty.
        """
        if not matches:
            raise ValueError('No results found.')

        if len(matches) == 1:
            return matches[0]

        embed = discord.Embed(colour=discord.Colour.blurple(), title=title,
                              description='\n'.join(f'{index}: {item}' for index, item in enumerate(matches, 1)))

        suggestion_text = "Please choose one of the options.\n"
        not_found_text = "I couldn't find what you were looking for, maybe you mean one of these?\n"
        cancel_text = "**Only say the number** (*0 to cancel*)"
        text = not_found_text + cancel_text if not_found else suggestion_text + cancel_text
        msg = await self.send(text, embed=embed)

        def check(m: discord.Message):
            # isdecimal() rather than isdigit(): isdigit() accepts characters like '²' that int() rejects.
            return m.content.isdecimal() and m.author.id == self.author.id and m.channel.id == self.channel.id
        message = None
        try:
            message = await self.bot.wait_for('message', check=check, timeout=30.0)
            index = int(message.content)
            if index == 0:
                await self.send("Alright, choosing cancelled.", delete_after=10)
                return None
            try:
                await msg.delete()
                return matches[index - 1]
            except IndexError:
                await self.send(f"{self.tick(False)} That wasn't in the choices.", delete_after=10)
        except asyncio.TimeoutError:
            return None
        finally:
            # Best-effort cleanup of the author's answer; may be forbidden (e.g. in DMs) or already gone.
            try:
                if message:
                    await message.delete()
            except (discord.Forbidden, discord.NotFound):
                pass

    # region Methods
    async def error(self, content, *, embed=None, file=None, files=None, delete_after=None):
        """Sends a message prefixed by a cross."""
        content = f"{self.tick(False)} {content}"
        return await self.send(content, embed=embed, file=file, files=files, delete_after=delete_after)

    async def execute_async(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Executes a synchronous function inside an executor.

        :param func: The function to call inside the executor.
        :param args: The function's arguments
        :param kwargs: The function's keyword arguments.
        :return: The value returned by the function, if any.
        """
        return await self.bot.loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    async def input(self, *, timeout=60.0, clean=False, delete_response=False) \
            -> Optional[str]:
        """Waits for text input from the author.

        :param timeout: Maximum time to wait for a message.
        :param clean: Whether the content should be cleaned or not.
        :param delete_response: Whether to delete the author's message after.
        :return: The content of the message replied by the author, or None on timeout.
        """
        def check(_message):
            return _message.channel == self.channel and _message.author == self.author

        try:
            value = await self.bot.wait_for("message", timeout=timeout, check=check)
            if clean:
                ret = value.clean_content
            else:
                ret = value.content
            if delete_response:
                try:
                    await value.delete()
                except discord.HTTPException:
                    pass
            return ret
        except asyncio.TimeoutError:
            return None

    async def is_askchannel(self) -> bool:
        """Checks if the current channel is the command channel.

        If the guild has no configured ask channel (or it can't be found), a channel whose name equals
        ``config.ask_channel_name`` is considered the ask channel. Always False outside a guild.
        """
        if self.guild is None:
            return False
        ask_channel = await self._get_ask_channel()
        if ask_channel is None:
            return self.channel.name == config.ask_channel_name
        return ask_channel == self.channel

    async def is_long(self) -> bool:
        """Whether the current context allows long replies or not

        Private messages and command channels allow long replies.
        """
        if self.guild is None:
            return True
        return await self.is_askchannel()

    async def react_confirm(self, message: discord.Message, *, timeout=60.0, delete_after=False,
                            use_checkmark=False) -> Optional[bool]:
        """Waits for the command author to reply with a Y or N reaction.

        Returns True if the user reacted with Y
        Returns False if the user reacted with N
        Returns None if the user didn't react at all

        :param message: The message that will contain the reactions.
        :param timeout: The maximum time to wait for reactions
        :param delete_after: Whether to delete or not the message after finishing.
        :param use_checkmark: Whether to use or not checkmarks instead of Y/N
        :return: True if reacted with Y, False if reacted with N, None if timeout.
        :raises RuntimeError: If the bot lacks the Add Reactions permission in this channel.
        """
        if not self.channel.permissions_for(self.me).add_reactions:
            raise RuntimeError('Bot does not have Add Reactions permission.')

        reactions = self.check_reactions if use_checkmark else self.yes_no_reactions
        for emoji in reactions:
            # Strip angle brackets so custom emoji strings like "<:name:id>" are passed as "name:id".
            emoji = emoji.replace("<", "").replace(">", "")
            await message.add_reaction(emoji)

        def check_react(reaction: discord.Reaction, user: discord.User):
            if reaction.message.id != message.id:
                return False
            if user.id != self.author.id:
                return False
            # Custom emojis arrive as Emoji/PartialEmoji objects, which never equal the configured strings;
            # str() renders them as "<:name:id>". Unicode emojis are already str, so str() is a no-op for them.
            return str(reaction.emoji) in reactions

        try:
            reaction, _ = await self.bot.wait_for("reaction_add", timeout=timeout, check=check_react)
        except asyncio.TimeoutError:
            return None
        finally:
            if delete_after:
                await safe_delete_message(message)
            elif self.guild is not None:
                # Best-effort: missing permissions or an already-deleted message must not mask the answer.
                try:
                    await message.clear_reactions()
                except discord.HTTPException:
                    pass
        # check_react guarantees the emoji is one of the pair; index 1 is the negative answer.
        return str(reaction.emoji) != reactions[1]

    async def success(self, content, *, embed=None, file=None, files=None, delete_after=None):
        """Sends a message prefixed by a checkmark."""
        content = f"{self.tick(True)} {content}"
        return await self.send(content, embed=embed, file=file, files=files, delete_after=delete_after)

    def tick(self, value: bool = True, label: Optional[str] = None) -> str:
        """Displays a checkmark or a cross depending on the value.

        :param value: The value to evaluate
        :param label: An optional label appended directly after the emoji (no separator is added).
        :return: A checkmark or cross
        """
        emoji = self.check_reactions[int(not value)]
        if label:
            return emoji + label
        return emoji

    # endregion

    async def show_help(self, command=None):
        """Shows the help command for the specified command if given.
        If no command is given, then it'll show help for the current
        command.
        """
        cmd = self.bot.get_command('help')
        command = command or self.command.qualified_name
        await self.invoke(cmd, command=command)