itsVale/Vale.py

View on GitHub
utils/disambiguate.py

Summary

Maintainability
A
0 mins
Test Coverage
import asyncio
import functools
import random
import re
import weakref
from itertools import starmap

import discord
from discord.ext import commands

from .examples import get_example
from .colors import random_color

_ID_REGEX = re.compile(r'([0-9]{15,21})$')


async def disambiguate(ctx, matches, transform=str, *, tries=3):
    """Prompts the user to choose from a list of matches."""

    if not matches:
        raise commands.BadArgument('No results found.')

    num_matches = len(matches)
    if num_matches == 1:
        return matches[0]

    entries = '\n'.join(starmap('{0}: {1}'.format, enumerate(map(transform, matches), 1)))

    permissions = ctx.channel.permissions_for(ctx.me)
    if permissions.embed_links:
        # Build the embed as we go. And make it nice and pretty.
        embed = discord.Embed(color=random_color(), description=entries)
        embed.set_author(name=f"There were {num_matches} matches found. Which one did you mean?")

        index = random.randrange(len(matches))
        instructions = f'Just type the number.\nFor example, typing `{index + 1}` will return {matches[index]}'
        embed.add_field(name='Instructions:', value=instructions)

        message = await ctx.send(embed=embed)
    else:
        await ctx.send('There are too many matches. Which one did you mean? **Only say the number**.')
        message = await ctx.send(entries)

    def check(m):
        return (m.author.id == ctx.author.id
                and m.channel.id == ctx.channel.id
                and m.content.isdigit())

    await ctx.release()

    try:
        for i in range(tries):
            try:
                msg = await ctx.bot.wait_for('message', check=check, timeout=30.0)
            except asyncio.TimeoutError:
                raise commands.BadArgument('Took too long. Goodbye.')

            index = int(msg.content)
            try:
                return matches[index - 1]
            except IndexError:
                await ctx.send(f'Please give me a valid number. {tries - i - 1} tries remaining...')

        raise commands.BadArgument('Too many tries. Goodbye.')
    finally:
        await message.delete()
        await ctx.acquire()


class _DisambiguateExampleGenerator:
    def __get__(self, obj, cls):
        cls_name = cls.__name__.replace('Disambiguate', '')
        return functools.partial(get_example, getattr(discord, cls_name))


class Converter(commands.Converter):
    """This is the base class for all disambiguating converters.

    By default, if there is more than one thing with a given name, the ext converters will only pick the first results.
    These allow you to pick from multiple results.

    Especially important when the args become case-insensitive.
    """

    _transform = str
    __converters__ = weakref.WeakValueDictionary()
    random_example = _DisambiguateExampleGenerator()

    def __init__(self, *, ignore_case=True):
        self.ignore_case = ignore_case

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        Converter.__converters__[cls.__name__] = cls

    def _get_possible_entries(self, ctx):
        """Returns an iterable of possible entries to find matches with.

        Subclasses must provide this to allow disambiguating.
        """

        raise NotImplementedError

    def _exact_match(self, ctx, argument):
        """Returns an "exact" match given an argument.

        If this returns anything but None, that result will be returned without going through disambiguating.

        Subclasses may override this method to provide an "exact" functionality.
        """

        return None

    # The following predicates can be overridden if necessary

    def _predicate(self, obj, argument):
        """Standard predicate for filtering."""

        return obj.name == argument

    def _predicate_ignore_case(self, obj, argument):
        """Same thing like `predicate` but with case-insensitive filtering."""

        return obj.name.lower() == argument

    def _get_possible_results(self, ctx, argument):
        entries = self._get_possible_entries(ctx)

        if self.ignore_case:
            lowered = argument.lower()
            predicate = self._predicate_ignore_case
        else:
            lowered = argument
            predicate = self._predicate

        return [obj for obj in entries if predicate(obj, lowered)]

    async def convert(self, ctx, argument):
        exact_match = self._exact_match(ctx, argument)
        if exact_match:
            return exact_match

        matches = self._get_possible_results(ctx, argument)
        return await disambiguate(ctx, matches, transform=self._transform)


class IDConverter(Converter):
    MENTION_REGEX = None

    def _get_from_id(self, ctx, id):
        """Returns an object via a given ID."""

        raise NotImplementedError

    def __get_id_from_mention(self, argument):
        return re.match(self.MENTION_REGEX, argument) if self.MENTION_REGEX else None

    def _exact_match(self, ctx, argument):
        match = _ID_REGEX.match(argument) or self.__get_id_from_mention(argument)
        if not match:
            return None

        return self._get_from_id(ctx, int(match[1]))


class UserConverterMixin:
    MENTION_REGEX = r'<@!?([0-9]+)>$'

    def _exact_match(self, ctx, argument):
        result = super()._exact_match(ctx, argument)
        if result is not None:
            return result

        if not (len(argument) > 5 and argument[-5] == '#'):
            # No discriminator provided which makes an exact match impossible
            return None

        name, _, discriminator = argument.rpartition('#')
        return discord.utils.find(
            lambda u: u.name == name and u.discriminator == discriminator,
            self._get_possible_entries(ctx)
        )


class User(UserConverterMixin, IDConverter):
    def _get_from_id(self, ctx, id):
        return ctx.bot.get_user(id)

    def _get_possible_entries(self, ctx):
        return ctx._state._users.values()


class Member(UserConverterMixin, IDConverter):
    def _get_from_id(self, ctx, id):
        return ctx.guild.get_member(id)

    def _get_possible_entries(self, ctx):
        return ctx.guild._members.values()

    # Overriding these is necessary due to members having nicknames
    def _predicate(self, obj, argument):
        return super()._predicate(obj, argument) or (obj.nick and obj.nick == argument)

    def _predicate_ignore_case(self, obj, argument):
        return (
            super()._predicate_ignore_case(obj, argument)
            or (obj.nick and obj.nick.lower() == argument)
        )


class Role(IDConverter):
    MENTION_REGEX = r'<@&([0-9]+)>$'

    def _get_from_id(self, ctx, id):
        return discord.utils.get(self._get_possible_entries(ctx), id=id)

    def _get_possible_entries(self, ctx):
        return ctx.guild.roles


class TextChannel(IDConverter):
    MENTION_REGEX = r'<#([0-9]+)>$'

    def _get_from_id(self, ctx, id):
        return ctx.guild.get_channel(id)

    def _get_possible_entries(self, ctx):
        return ctx.guild.text_channels


class Guild(IDConverter):
    def _get_from_id(self, ctx, id):
        return ctx.bot.get_guild(id)

    def _get_possible_entries(self, ctx):
        return ctx._state._guilds.values()


def _is_discord_py_type(cls):
    module = getattr(cls, '__module__', '')
    return module.startswith('discord.') and not module.endswith('converter')


def _disambiguated(type_):
    """Return the corresponding disambiguating converter if one exists
    for that type.
    If no such converter exists, it returns the type.
    """
    if not _is_discord_py_type(type_):
        return type_

    return Converter.__converters__.get(type_.__name__, type_)


def _get_current_parameter(ctx):
    parameters = list(ctx.command.params.values())

    # Need to account for varargs and consume-rest kwarg only
    index = min(len(ctx.args) + len(ctx.kwargs), len(parameters) - 1)
    return parameters[index]


class Union(commands.Converter):
    _transform = '{0} ({0.__class__.__name__})'.format

    def __init__(self, *types, ignore_case=True):
        self.types = [
            type_(ignore_case=ignore_case)
            if isinstance(type_, type) and issubclass(type_, Converter)
            else type_
            for type_ in map(_disambiguated, types)
        ]

    async def convert(self, ctx, argument):
        param = _get_current_parameter(ctx)
        results = []
        for converter in self.types:
            if isinstance(converter, Converter):
                exact = converter._exact_match(ctx, argument)
                if exact is not None:
                    return exact

                results.extend(converter._get_possible_results(ctx, argument))
            else:
                # Standard type, so standard conversion
                try:
                    result = await ctx.command.do_conversion(ctx, converter, argument, param)
                except commands.BadArgument:
                    continue
                else:
                    results.append(result)
        return await disambiguate(ctx, results, transform=self._transform)

    def random_example(self, ctx):
        return get_example(random.choice(self.types), ctx)