PyDrocsid/cogs

View on GitHub
integrations/reddit/cog.py

Summary

Maintainability
A
0 mins
Test Coverage
import re
from datetime import datetime
from typing import List, Optional

from aiohttp import ClientSession
from discord import Embed, TextChannel
from discord.ext import commands, tasks
from discord.ext.commands import CommandError, Context, UserInputError, guild_only

from PyDrocsid.cog import Cog
from PyDrocsid.command import reply
from PyDrocsid.config import Config
from PyDrocsid.database import db, db_wrapper, filter_by, select
from PyDrocsid.logger import get_logger
from PyDrocsid.translations import t
from PyDrocsid.util import check_message_send_permissions

from .colors import Colors
from .models import RedditChannel, RedditPost
from .permissions import RedditPermission
from .settings import RedditSettings
from ...contributor import Contributor
from ...pubsub import send_alert, send_to_changelog


tg = t.g
t = t.reddit

logger = get_logger(__name__)


def remove_prefix(subreddit: str) -> str:
    return re.sub("^(/r/|r/)", "", subreddit)


async def get_subreddit_name(subreddit: str) -> str | None:
    subreddit = remove_prefix(subreddit)
    async with ClientSession() as session, session.get(
        # raw_json=1 as parameter to get unicode characters instead of html escape sequences
        f"https://www.reddit.com/r/{subreddit}/about.json?raw_json=1",
        headers={"User-agent": f"{Config.NAME}/{Config.VERSION}"},
        allow_redirects=False,
    ) as response:
        if response.status != 200:
            return None

        return (await response.json())["data"]["display_name"]


async def fetch_reddit_posts(subreddit: str, limit: int) -> Optional[List[dict]]:
    subreddit = remove_prefix(subreddit)
    async with ClientSession() as session, session.get(
        # raw_json=1 as parameter to get unicode characters instead of html escape sequences
        f"https://www.reddit.com/r/{subreddit}/hot.json?raw_json=1",
        headers={"User-agent": f"{Config.NAME}/{Config.VERSION}"},
        params={"limit": str(limit)},
    ) as response:
        if response.status != 200:
            return None

        data = (await response.json())["data"]

    filter_nsfw = await RedditSettings.filter_nsfw.get()
    posts: List[dict] = []
    for post in data["children"]:
        # t3 = link
        if post["kind"] == "t3" and post["data"].get("post_hint") == "image":
            if post["data"]["over_18"] and filter_nsfw:
                continue

            posts.append(
                {
                    "id": post["data"]["id"],
                    "author": post["data"]["author"],
                    "title": post["data"]["title"],
                    "created_utc": post["data"]["created_utc"],
                    "score": post["data"]["score"],
                    "num_comments": post["data"]["num_comments"],
                    "permalink": post["data"]["permalink"],
                    "url": post["data"]["url"],
                    "subreddit": post["data"]["subreddit"],
                }
            )
    return posts


def create_embed(post: dict) -> Embed:
    embed = Embed(
        # add a blank character after every : and . to prevent wrong redirects for titles
        title=post["title"].replace(":", ":\u200b").replace(".", ".\u200b"),
        url=f"https://reddit.com{post['permalink']}",
        description=f"{post['score']} :thumbsup: \u00B7 {post['num_comments']} :speech_balloon:",
        colour=Colors.Reddit,  # Reddit's brand color
    )
    embed.set_author(name=f"u/{post['author']}", url=f"https://reddit.com/u/{post['author']}")
    embed.set_image(url=post["url"])
    embed.set_footer(text=f"r/{post['subreddit']}")
    embed.timestamp = datetime.utcfromtimestamp(post["created_utc"])
    return embed


class RedditCog(Cog, name="Reddit"):
    CONTRIBUTORS = [Contributor.Scriptim, Contributor.Defelo, Contributor.wolflu, Contributor.Anorak]

    async def on_ready(self):
        interval = await RedditSettings.interval.get()
        await self.start_loop(interval)

    @tasks.loop()
    @db_wrapper
    async def reddit_loop(self):
        await self.pull_hot_posts()

    async def pull_hot_posts(self):
        logger.info("pulling hot reddit posts")
        limit = await RedditSettings.limit.get()
        async for reddit_channel in await db.stream(select(RedditChannel)):  # type: RedditChannel
            text_channel: Optional[TextChannel] = self.bot.get_channel(reddit_channel.channel)
            if text_channel is None:
                await db.delete(reddit_channel)
                continue

            try:
                check_message_send_permissions(text_channel, check_embed=True)
            except CommandError:
                await send_alert(self.bot.guilds[0], t.cannot_send(text_channel.mention))
                continue

            posts = await fetch_reddit_posts(reddit_channel.subreddit, limit)
            if posts is None:
                await send_alert(self.bot.guilds[0], t.could_not_fetch(reddit_channel.subreddit))
                continue

            for post in posts:
                if await RedditPost.post(post["id"]):
                    await text_channel.send(embed=create_embed(post))

        await RedditPost.clean()

    async def start_loop(self, interval):
        self.reddit_loop.cancel()
        self.reddit_loop.change_interval(hours=interval)
        try:
            self.reddit_loop.start()
        except RuntimeError:
            self.reddit_loop.restart()

    @commands.group()
    @RedditPermission.read.check
    @guild_only()
    async def reddit(self, ctx: Context):
        """
        manage reddit integration
        """

        if ctx.subcommand_passed is not None:
            if ctx.invoked_subcommand is None:
                raise UserInputError
            return

        embed = Embed(title=t.reddit, colour=Colors.Reddit)

        interval = await RedditSettings.interval.get()
        embed.add_field(name=t.interval, value=t.x_hours(cnt=interval))

        limit = await RedditSettings.limit.get()
        embed.add_field(name=t.limit, value=str(limit))

        filter_nsfw = await RedditSettings.filter_nsfw.get()
        embed.add_field(name=t.nsfw_filter, value=tg.enabled if filter_nsfw else tg.disabled, inline=False)

        out = []
        async for reddit_channel in await db.stream(select(RedditChannel)):  # type: RedditChannel
            text_channel: Optional[TextChannel] = self.bot.get_channel(reddit_channel.channel)
            if text_channel is None:
                await db.delete(reddit_channel)
            else:
                sub = reddit_channel.subreddit
                out.append(f":small_orange_diamond: [r/{sub}](https://reddit.com/r/{sub}) -> {text_channel.mention}")
        embed.add_field(name=t.reddit_links, value="\n".join(out) or t.no_reddit_links, inline=False)

        await reply(ctx, embed=embed)

    @reddit.command(name="add", aliases=["a", "+"])
    @RedditPermission.write.check
    async def reddit_add(self, ctx: Context, subreddit: str, channel: TextChannel):
        """
        create a link between a subreddit and a channel
        """

        if not (subreddit := await get_subreddit_name(subreddit)):
            raise CommandError(t.subreddit_not_found)

        check_message_send_permissions(channel, check_embed=True)

        if await db.exists(filter_by(RedditChannel, subreddit=subreddit, channel=channel.id)):
            raise CommandError(t.reddit_link_already_exists)

        await RedditChannel.create(subreddit, channel.id)
        embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.reddit_link_created)
        await reply(ctx, embed=embed)
        await send_to_changelog(ctx.guild, t.log_reddit_link_created(subreddit, channel.mention))

    @reddit.command(name="remove", aliases=["r", "del", "d", "-"])
    @RedditPermission.write.check
    async def reddit_remove(self, ctx: Context, subreddit: str, channel: TextChannel):
        """
        remove a reddit link
        """

        subreddit = await get_subreddit_name(subreddit) or subreddit
        link: Optional[RedditChannel] = await db.get(RedditChannel, subreddit=subreddit, channel=channel.id)
        if link is None:
            raise CommandError(t.reddit_link_not_found)

        await db.delete(link)
        embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.reddit_link_removed)
        await reply(ctx, embed=embed)
        await send_to_changelog(ctx.guild, t.log_reddit_link_removed(subreddit, channel.mention))

    @reddit.command(name="interval", aliases=["int", "i"])
    @RedditPermission.write.check
    async def reddit_interval(self, ctx: Context, hours: int):
        """
        change lookup interval (in hours)
        """

        if not 0 < hours < (1 << 31):
            raise CommandError(t.invalid_interval)

        await RedditSettings.interval.set(hours)
        await self.start_loop(hours)
        embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.reddit_interval_set)
        await reply(ctx, embed=embed)
        await send_to_changelog(ctx.guild, t.log_reddit_interval_set(cnt=hours))

    @reddit.command(name="limit", aliases=["lim"])
    @RedditPermission.write.check
    async def reddit_limit(self, ctx: Context, limit: int):
        """
        change limit of posts to be sent concurrently
        """

        if not 0 < limit < (1 << 31):
            raise CommandError(t.invalid_limit)

        await RedditSettings.limit.set(limit)
        embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.reddit_limit_set)
        await reply(ctx, embed=embed)
        await send_to_changelog(ctx.guild, t.log_reddit_limit_set(limit))

    @reddit.command(name="nsfw_filter", aliases=["nsfw"])
    @RedditPermission.write.check
    async def reddit_nsfw_filter(self, ctx: Context, enabled: bool):
        """
        enable/disable nsfw filter for posts
        """

        embed = Embed(title=t.reddit, colour=Colors.Reddit)
        await RedditSettings.filter_nsfw.set(enabled)
        if enabled:
            embed.description = t.nsfw_filter_now_enabled
            await send_to_changelog(ctx.guild, t.log_nsfw_filter_now_enabled)
        else:
            embed.description = t.nsfw_filter_now_disabled
            await send_to_changelog(ctx.guild, t.log_nsfw_filter_now_disabled)
        await reply(ctx, embed=embed)

    @reddit.command(name="trigger", aliases=["t"])
    @RedditPermission.trigger.check
    async def reddit_trigger(self, ctx: Context):
        """
        pull hot posts now and reset the timer
        """

        await self.start_loop(await RedditSettings.interval.get())
        embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.done)
        await reply(ctx, embed=embed)