itsVale/Vale.py

View on GitHub
utils/paginator.py

Summary

Maintainability
C
7 hrs
Test Coverage
import asyncio
import collections
import contextlib
import functools
import itertools
import re

import discord
from discord.ext import commands
from more_itertools import chunked, consume, iter_except, unique_everseen

from .colors import random_color
from .misc import maybe_awaitable
from .queue import SimpleQueue

# Metadata record that ``trigger`` attaches to a function as ``__trigger__``.
_Trigger = collections.namedtuple('_Trigger', ['emoji', 'pattern', 'blocking', 'fallback'])


def trigger(emoji, pattern=None, *, blocking=False, fallback=None):
    """Mark a function as the handler for a reaction, and optionally for messages.

    emoji is the reaction that invokes the function.

    pattern, when given, is a regex; messages matching it invoke the function too.

    fallback, when given, is a regex used like pattern, but only when the bot
    can't add reactions.

    blocking, when True, makes the session ignore further reactions while the
    function is executing.

    Returns a decorator that stores these settings on the function as
    ``__trigger__`` and hands the function back unchanged.
    """

    spec = _Trigger(emoji, pattern, blocking, fallback)

    def attach(func):
        func.__trigger__ = spec
        return func

    return attach


# Check decorator factory: ``paginated()`` is ``commands.bot_has_permissions(embed_links=True)``.
paginated = functools.partial(commands.bot_has_permissions, embed_links=True)
# Extract (not remove) the predicate from the check: decorate a throwaway lambda and
# read back the first entry of its ``__commands_checks__`` so the same test can be
# called directly with a context (InteractiveSession.run does this before starting).
# NOTE(review): relies on the check decorator recording its predicate in
# ``__commands_checks__`` on plain callables -- confirm against the installed discord.py.
_validate_context = paginated()(lambda: 0).__commands_checks__[0]


class _TriggerCooldown(commands.CooldownMapping):
    """Cooldown mapping whose buckets are keyed by whatever is passed to ``get_bucket``.

    Built around a ``commands.Cooldown(rate=5, per=2, type=BucketType.user)``.
    """

    def __init__(self):
        limit = commands.Cooldown(rate=5, per=2, type=commands.BucketType.user)
        super().__init__(limit)

    def _bucket_key(self, tup):
        # The value handed in is used verbatim as the bucket key.
        return tup

    def is_rate_limited(self, message_id, user_id):
        """Update the bucket for (message_id, user_id) and report whether it is limited.

        Returns True when the bucket's ``update_rate_limit()`` returns anything
        other than None (presumably a retry-after delay -- see discord.py).
        """

        key = (message_id, user_id)
        retry_after = self.get_bucket(key).update_rate_limit()
        return retry_after is not None


# Single module-level cooldown instance; the session listeners consult it with a
# (message id, user id) pair before queueing a callback.
_trigger_cooldown = _TriggerCooldown()


class _Callback(collections.namedtuple('_Callback', ['func', 'blocking'])):
    """Pairs a resolved trigger handler (``func``) with its ``blocking`` flag."""

    __slots__ = ()

    # Instances report the wrapped function's docstring as their own, which lets
    # format strings such as '{1.__doc__}' describe the handler directly.
    __doc__ = property(lambda self: self.func.__doc__)


class InteractiveSession:
    """Base class for all interactive sessions.

    Subclasses must implement the 'default' method. If necessary, they can override 'start'.

    A page should either return a discord.Embed, or None to indicate the page was invalid somehow.
    e.g. the page number given was out of bounds, or there were side effects associated with it.

    Subclasses may pass these class keyword arguments (see ``__init_subclass__``):

    stop_emoji -- reaction that stops the session.
    stop_pattern -- regex; a message fully matching it stops the session.
    stop_fallback -- regex used like stop_pattern, but only when reactions can't be used.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        self._bot = ctx.bot
        self._channel = ctx.channel
        # IDs of the users whose reactions/messages are allowed to drive the session.
        self._users = {ctx.author.id}
        # The message being edited; must be set by start().
        self._message = None
        # Set from the queued callback's ``blocking`` flag; while True the listeners ignore input.
        self._blocking = False

        # The embed most recently produced by a callback (or by default()).
        self._current = None

        # Jobs produced by the listeners: (callback, cleanup) pairs, or None to stop.
        self._queue = SimpleQueue()

        self._using_reactions = False

    def __init_subclass__(cls, stop_emoji='\N{BLACK SQUARE FOR STOP}', stop_pattern=None, stop_fallback='exit', **kwargs):
        """Collects the ``@trigger``-decorated members of the subclass being created.

        Builds three class attributes:

        _reaction_map -- ordered mapping of emoji -> _Callback.
        _message_callbacks -- list of (regex, _Callback) that are always active.
        _message_fallbacks -- list of (regex, _Callback) used only without reactions.

        The stop trigger is appended last, unless all three stop_* arguments are falsy.
        """

        super().__init_subclass__(**kwargs)
        cls._reaction_map = callbacks = collections.OrderedDict()

        cls._message_callbacks = message_callbacks = []
        cls._message_fallbacks = message_fallbacks = []
        known_patterns = set()

        def trigger_iterator():
            # Walk the MRO from the most-derived class down; unique_everseen keeps only the
            # first occurrence of each attribute name, so an override shadows the base member.
            name_members = itertools.chain.from_iterable(b.__dict__.items() for b in cls.__mro__)
            for name, member in unique_everseen(name_members, key=lambda p: p[0]):
                spec = getattr(member, '__trigger__', None)
                if not spec:
                    continue

                # Resolve through the class so descriptors are unwrapped.
                resolved = getattr(cls, name)
                callback = _Callback(resolved, spec.blocking)
                yield spec.emoji, spec.pattern, spec.fallback, callback

            if stop_emoji or stop_pattern or stop_fallback:
                yield stop_emoji, stop_pattern, stop_fallback, _Callback(cls.stop, False)

        for emoji, pattern, fallback, callback in trigger_iterator():
            # A trigger may be message-only (no emoji); never register an empty emoji,
            # otherwise add_reactions() would try to react with it.
            if emoji and emoji not in callbacks:
                callbacks[emoji] = callback

            # The first callback registered for an emoji or pattern wins.
            if pattern and pattern not in known_patterns:
                known_patterns.add(pattern)
                message_callbacks.append((pattern, callback))

            if fallback and fallback not in known_patterns:
                known_patterns.add(fallback)
                message_fallbacks.append((fallback, callback))

    def using_reactions(self):
        """Returns True if reactions are being used for the current session, False otherwise.

        This can return False if the session is not running.
        """

        return self._using_reactions

    def check(self, reaction, _):
        """Extra checks for reactions.

        By default a reaction is accepted when its emoji has a registered callback.
        """

        return reaction.emoji in self._reaction_map

    async def add_reactions(self):
        """Adds the reactions to the message, in registration order."""

        for emoji in self._reaction_map:
            await self._message.add_reaction(emoji)

    def default(self):
        """Returns the first embed to start the controller.

        Must be implemented by subclasses.
        """

        raise NotImplementedError

    async def start(self):
        """First thing that gets called.

        Sends the embed returned by default() and stores the resulting message.
        """

        self._current = embed = await maybe_awaitable(self.default)
        self._message = await self._channel.send(embed=embed)

    async def stop(self):
        """Stops running the controller."""
        # NOTE: the docstring above is user-visible through reaction_help; keep it short.
        # None is the sentinel that makes run() leave its loop.

        await self._queue.put(None)

    async def cleanup(self, *, delete_after):
        """Cleans up anything else after stopping.

        Deletes the session message when delete_after is true, otherwise clears its
        reactions. This is best-effort: any exception raised while doing so is suppressed.
        """

        # Pick the bound method WITHOUT calling it; it is called (and awaited) below,
        # inside the suppress block.
        method = self._message.delete if delete_after else self._message.clear_reactions
        with contextlib.suppress(Exception):
            await method()

    async def run(self, *, timeout=120, delete_after=True):
        """Runs the interactive loop.

        timeout -- seconds to wait for the next reaction/message before the session ends.
        delete_after -- forwarded to cleanup() once the session is over.

        Raises RuntimeError if start() did not set self._message, and whatever the
        embed-links permission check raises for self.ctx.
        """

        _validate_context(self.ctx)

        self._using_reactions = self._channel.permissions_for(self.ctx.me).add_reactions

        await self.start()
        if not self._message:
            raise RuntimeError('start() must set self._message.')

        message = self._message
        triggers = self._message_callbacks.copy()
        task = None
        listeners = []

        def listen(func):
            # Remember every listener so it can be unregistered when the session ends.
            listeners.append(func)
            return self._bot.listen()(func)

        if self._using_reactions:
            # Reactions are added in the background so the session is usable right away.
            task = self._bot.loop.create_task(self.add_reactions())

            @listen
            async def on_reaction_add(reaction, user):
                if (
                    not self._blocking
                    and reaction.message.id == message.id
                    and user.id in self._users
                    and self.check(reaction, user)
                    and not _trigger_cooldown.is_rate_limited(message.id, user.id)
                ):
                    callback, self._blocking = self._reaction_map[reaction.emoji]
                    # After the callback ran, the user's reaction is removed again.
                    cleanup = functools.partial(message.remove_reaction, reaction.emoji, user)
                    await self._queue.put((callback, cleanup))

        else:
            # Without reactions, the fallback patterns stand in for the emojis.
            triggers.extend(self._message_fallbacks)

        if triggers:
            @listen
            async def on_message(msg):
                if (
                    self._blocking
                    or msg.channel != self._channel
                    or msg.author.id not in self._users
                ):
                    return

                # First trigger whose pattern matches the whole message wins.
                matched = next(
                    (cb for pattern, cb in triggers if re.fullmatch(pattern, msg.content)),
                    None,
                )
                if matched is None:
                    return

                if _trigger_cooldown.is_rate_limited(message.id, msg.author.id):
                    return

                callback, self._blocking = matched
                # After the callback ran, the triggering message is deleted.
                await self._queue.put((callback, msg.delete))

        try:
            while True:
                # Would async_timeout be better here?
                try:
                    job = await asyncio.wait_for(self._queue.get(), timeout=timeout)
                except asyncio.TimeoutError:
                    break

                if job is None:
                    # Sentinel queued by stop().
                    break

                callback, after = job

                result = await maybe_awaitable(callback, self)

                with contextlib.suppress(discord.HTTPException):
                    await after()

                self._blocking = False

                if result is None:
                    # Invalid page / side-effect-only callback: nothing to display.
                    continue

                self._current = result

                try:
                    await message.edit(embed=result)
                except discord.NotFound:
                    # The message is gone; there is nothing left to drive.
                    break

        finally:
            self._using_reactions = False
            # Don't stay blocked if a callback raised before the flag was reset.
            self._blocking = False

            for listener in listeners:
                self._bot.remove_listener(listener)

            if not (task is None or task.done()):
                task.cancel()

            # Drop any jobs that were queued but never processed.
            consume(iter_except(self._queue.get_nowait, asyncio.QueueEmpty))
            await self.cleanup(delete_after=delete_after)

    interact = run

    @property
    def reaction_help(self):
        """One '<emoji> => <callback docstring>' line per registered reaction."""

        return '\n'.join(itertools.starmap('{0} => {1.__doc__}'.format, self._reaction_map.items()))


class Paginator(InteractiveSession):
    """A paginator takes an iterable of entries and paginates them.

    Entries are split into pages of per_page items; each page is rendered as an
    embed whose description is the page's entries joined by newlines.
    """

    def __init__(self, ctx, entries, *, per_page=15, title=discord.Embed.Empty, color=None):
        super().__init__(ctx)

        self._pages = tuple(chunked(entries, per_page))
        # Index of the page currently shown; updated by page_at().
        self._index = 0

        if not color:
            color = random_color()

        self.title = title
        self.color = color

    @property
    def single_page(self):
        """Returns whether there's just a single page or not."""

        return len(self._pages) == 1

    @property
    def small(self):
        """Returns whether there are five pages or less or not."""

        return len(self._pages) <= 5

    @property
    def total(self):
        """Returns the total number of entries in the list."""

        return sum(map(len, self._pages))

    async def start(self):
        """Starts paginating."""

        await super().start()
        if self.single_page:
            # If there's only one page, paginating is unnecessary
            await self.stop()

    async def cleanup(self, *, delete_after):
        """Cleans up like the base class, except that a single-page message is left untouched."""

        if not self.single_page:
            await super().cleanup(delete_after=delete_after)

    async def add_reactions(self):
        """Adds the navigation reactions; first/last are skipped when there are few pages."""

        if self.single_page:
            return

        # The two "jump to first/last page" emojis registered on default() and last().
        fast_forwards = {'\U000023ed', '\U000023ee'}
        small = self.small

        for emoji in self._reaction_map:
            if not (small and emoji in fast_forwards):
                await self._message.add_reaction(emoji)

    def create_embed(self, page):
        """Creates an embed given a slice of entries."""

        return (discord.Embed(title=self.title, description='\n'.join(page), color=self.color)
                .set_footer(text=f'Page {self._index + 1}/{len(self._pages)} ({self.total} total)'))

    def page_at(self, index):
        """Returns the embed that would be created at a certain page. None if the index is out of bounds.

        On success the current page index is updated to index.
        """

        if not 0 <= index < len(self._pages):
            return None

        self._index = index
        return self.create_embed(self._pages[index])

    # NOTE: the one-line docstrings of the trigger methods below are user-visible
    # through reaction_help; keep them short.

    @trigger('\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}', fallback=r'\<\<')
    def default(self):
        """First page."""

        return self.page_at(0)

    @trigger('\N{BLACK LEFT-POINTING TRIANGLE}', fallback=r'\<')
    def previous(self):
        """Previous page."""

        return self.page_at(self._index - 1)

    @trigger('\N{BLACK RIGHT-POINTING TRIANGLE}', fallback=r'\>')
    def next(self):
        """Next page."""

        return self.page_at(self._index + 1)

    @trigger('\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}', fallback=r'\>\>')
    def last(self):
        """Last page."""

        return self.page_at(len(self._pages) - 1)

    def _goto_embed(self):
        # Prompt shown while goto() waits for the user to type a page number.
        ctx = self.ctx
        description = (
            f'Please enter a number from 1 to {len(self._pages)}.\n\n'
            'To cancel, click \N{INPUT SYMBOL FOR NUMBERS} again.'
        )

        return (discord.Embed(description=description, color=self.color)
                .set_author(name=f'What page do you want to go, {ctx.author.display_name}?'))

    def _goto_parse_input(self, content):
        # Turn a 1-based page number typed by the user into that page's embed;
        # None when the text isn't an integer or the page doesn't exist.
        try:
            index = int(content)
        except ValueError:
            return None

        return self.page_at(index - 1)

    @trigger('\N{INPUT SYMBOL FOR NUMBERS}', blocking=True)
    async def goto(self):
        """Go to a certain page."""
        # Sends a prompt, then waits up to 60 seconds for either a valid page number
        # from the invoking user or for them to remove the reaction (cancel).
        # Returns the chosen page's embed, or None on cancel/timeout.

        ctx = self.ctx
        return_result = None
        user_message = None
        # Bound up-front so the finally clause is safe even if sending the prompt fails.
        prompt = None
        waiters = []

        def check(m):
            nonlocal return_result, user_message
            if not (m.channel.id == self._channel.id and m.author.id == ctx.author.id):
                return False

            result = self._goto_parse_input(m.content)
            if result is None:
                return False

            # Keep the embed and the message around: the embed is what goto() returns,
            # the message gets deleted afterwards.
            return_result = result
            user_message = m
            return True

        def remove_check(reaction, user):
            return (reaction.message.id == self._message.id
                    and user.id == ctx.author.id
                    and reaction.emoji == '\N{INPUT SYMBOL FOR NUMBERS}')

        try:
            prompt = await self._channel.send(embed=self._goto_embed())

            # asyncio.wait() must be given tasks: its loop parameter was removed in
            # Python 3.10 and bare coroutines are rejected since 3.11.
            create_task = self._bot.loop.create_task
            waiters = [
                create_task(self._bot.wait_for('message', check=check)),
                create_task(self._bot.wait_for('reaction_remove', check=remove_check)),
            ]

            done, _ = await asyncio.wait(waiters, timeout=60, return_when=asyncio.FIRST_COMPLETED)
            if not done:
                # Timed out.
                return None

            result = done.pop().result()

            if isinstance(result, discord.Message):
                return return_result

            # The reaction was removed: the user cancelled.
            return None

        finally:
            # Never leave a waiter behind, whichever way we exit.
            for waiter in waiters:
                if not waiter.done():
                    waiter.cancel()

            for m in (prompt, user_message):
                if m is None:
                    continue
                with contextlib.suppress(Exception):
                    await m.delete()


class FieldPaginator(Paginator):
    """Similar to Paginator, but uses the fields instead of the description.

    Each entry must be a (name, value) pair; every pair becomes one embed field.
    """

    def __init__(self, ctx, entries, *, inline=True, **kwargs):
        """Builds the pages and validates the number of fields per page.

        inline -- passed to every add_field call.
        Remaining keyword arguments are forwarded to Paginator.

        Raises ValueError if any page would hold more than 25 fields.
        """

        super().__init__(ctx, entries, **kwargs)

        self.inline = inline

        # The limit applies to the fields inside ONE embed (i.e. one page),
        # not to the number of pages.
        if any(len(page) > 25 for page in self._pages):
            raise ValueError('Too many fields per page! The maximum is 25!')

    def create_embed(self, page):
        """Creates an embed with one field per (name, value) entry of the page."""

        embed = (discord.Embed(title=self.title, color=self.color)
                 .set_footer(text=f'Page: {self._index + 1} / {len(self._pages)} ({self.total} total)'))

        add_field = functools.partial(embed.add_field, inline=self.inline)
        for name, value in page:
            add_field(name=name, value=value)

        return embed