bot/cogs/get_birds.py
# get_birds.py | commands for getting bird images or songs
# Copyright (C) 2019-2021 EraserBird, person_v1.32, hmmm
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import random
import string
from typing import Optional
from discord import app_commands
from discord.ext import commands
import bot.voice as voice_functions
from bot.core import send_bird
from bot.data import GenericError, database, goatsuckers, logger, states, taxons
from bot.data_functions import bird_setup, session_increment
from bot.filters import Filter, MediaType, arg_autocomplete
from bot.functions import CustomCooldown, build_id_list, check_state_role
# Instruction template shown to new users; placeholders are filled per command.
BASE_MESSAGE = (
    "*Here you go!* \n**Use `b!{new_cmd}` again to get a new {media} of the same bird, "
    "or `b!{skip_cmd}` to get a new bird. Use `b!{check_cmd} guess` to check your answer. "
    "Use `b!{hint_cmd}` for a hint.**"
)

# Commands that are identical across every media-sending command.
_COMMON_COMMANDS = {"skip_cmd": "skip", "check_cmd": "check", "hint_cmd": "hint"}

# Instructions for the bird (image) command.
BIRD_MESSAGE = BASE_MESSAGE.format(media="image", new_cmd="bird", **_COMMON_COMMANDS)
# Instructions for the goatsucker (image) command.
GS_MESSAGE = BASE_MESSAGE.format(media="image", new_cmd="gs", **_COMMON_COMMANDS)
# Instructions for the song command.
SONG_MESSAGE = BASE_MESSAGE.format(media="song", new_cmd="song", **_COMMON_COMMANDS)
class Birds(commands.Cog):
    """Commands that send bird images/songs (and goatsuckers) for users to identify."""

    def __init__(self, bot):
        self.bot = bot

    async def _send_next_race_media(self, ctx):
        """Send the next race bird if a race is running in ctx's channel.

        Voice playback is stopped first when the race filter has the VC flag.
        Does nothing when the channel has no race data.
        """
        race_key = f"race.data:{ctx.channel.id}"
        if not database.exists(race_key):
            return
        if Filter.from_int(int(database.hget(race_key, "filter"))).vc:
            await voice_functions.stop(ctx, silent=True)
        # read the race settings after the (awaited) voice stop, as before
        media, filter_int, taxon, state = database.hmget(
            race_key, ["media", "filter", "taxon", "state"]
        )
        media = media.decode("utf-8")
        logger.info(f"auto sending next bird {media}")
        await self.send_bird_(
            ctx,
            media,
            Filter.from_int(int(filter_int)),
            taxon.decode("utf-8"),
            state.decode("utf-8"),
        )

    def error_handle(
        self,
        ctx,
        media_type: MediaType,
        filters: Filter,
        taxon_str,
        role_str,
        retries,
    ):
        """Return a function to pass to send_bird() as on_error.

        The returned coroutine function marks the current bird as skipped
        (bird "" / answered "1"). If ``error`` is a GenericError with code 100
        it resends via send_bird_ with ``retries`` incremented, allowing at most
        two retries. Otherwise, or once retries are exhausted, it asks the user
        to try again and sends the next race bird when a race is running.
        """

        async def inner(error):
            nonlocal retries
            # skip current bird so the next request picks a new one
            database.hset(f"channel:{ctx.channel.id}", "bird", "")
            database.hset(f"channel:{ctx.channel.id}", "answered", "1")
            if retries >= 2:  # only retry twice
                await ctx.send("**Too many retries.**\n*Please try again.*")
                await self._send_next_race_media(ctx)
                return
            if isinstance(error, GenericError) and error.code == 100:
                retries += 1
                await ctx.send("**Retrying...**")
                await self.send_bird_(
                    ctx, media_type, filters, taxon_str, role_str, retries
                )
            else:
                await ctx.send("*Please try again.*")
                await self._send_next_race_media(ctx)

        return inner

    @staticmethod
    def error_skip(ctx):
        """Return an on_error handler that skips the current bird and asks to retry."""

        async def inner(error):
            # pylint: disable=unused-argument
            # skip current bird
            database.hset(f"channel:{ctx.channel.id}", "bird", "")
            database.hset(f"channel:{ctx.channel.id}", "answered", "1")
            await ctx.send("*Please try again.*")

        return inner

    @staticmethod
    def increment_bird_frequency(ctx, bird):
        """Run bird_setup and bump the bird's count in the global frequency zset."""
        bird_setup(ctx, bird)
        database.zincrby("frequency.bird:global", 1, string.capwords(bird))

    @staticmethod
    def _resolve_media_type(media) -> MediaType:
        """Map a media alias (or a MediaType) to a MediaType.

        :raises GenericError: code 990 when ``media`` is not a known alias.
        """
        if media in ("images", "image", "i", "p", MediaType.IMAGE):
            return MediaType.IMAGE
        if media in ("songs", "song", "s", "a", MediaType.SONG):
            return MediaType.SONG
        raise GenericError("Invalid media type", code=990)

    @staticmethod
    def _pick_bird(birds, previous):
        """Pick a random bird, avoiding ``previous`` whenever another bird exists.

        Equivalent in distribution to re-drawing until the pick differs from
        ``previous``, but cannot loop forever when every entry equals it.
        """
        candidates = [bird for bird in birds if bird != previous]
        return random.choice(candidates or birds)

    async def send_bird_(
        self,
        ctx,
        media: Optional[str],
        filters: Filter,
        taxon_str: str = "",
        role_str: str = "",
        retries=0,
    ):
        """Send a bird of the requested media type to ctx's channel.

        If the channel's previous bird was answered, a new bird is chosen from
        the taxon/state lists; otherwise the current bird is sent again.

        :param media: an alias ("images"/"image"/"i"/"p", "songs"/"song"/"s"/"a")
            or a MediaType
        :param filters: filters passed through to send_bird
        :param taxon_str: space-separated taxons, "" for none
        :param role_str: space-separated states/lists, "" for none
        :param retries: retries already attempted (see error_handle)
        :raises GenericError: code 990 for an unknown media alias
        """
        media_type = self._resolve_media_type(media)
        channel_key = f"channel:{ctx.channel.id}"

        if media_type is MediaType.SONG and filters.vc:
            current_voice = database.get(f"voice.server:{ctx.guild.id}")
            if current_voice is not None and current_voice.decode("utf-8") != str(
                ctx.channel.id
            ):
                logger.info("already vc race")
                await ctx.send("**The voice channel is currently in use!**")
                return

        taxon = taxon_str.split(" ") if taxon_str else []
        roles = role_str.split(" ") if role_str else []

        logger.info("bird: " + database.hget(channel_key, "bird").decode("utf-8"))

        currently_in_race = bool(database.exists(f"race.data:{ctx.channel.id}"))
        # zscore returns None for users without a score; treat them as new users
        new_user = (database.zscore("users:global", str(ctx.author.id)) or 0) < 10
        answered = int(database.hget(channel_key, "answered"))
        logger.info(f"answered: {answered}")

        # full instructions only for new users outside of races
        if not currently_in_race and new_user:
            message = SONG_MESSAGE if media_type is MediaType.SONG else BIRD_MESSAGE
        else:
            message = "*Here you go!*"
        on_error = self.error_handle(
            ctx, media_type, filters, taxon_str, role_str, retries
        )

        # previous bird not answered yet: give the same bird again
        if not answered:
            await ctx.send(f"**Active Filters**: `{'`, `'.join(filters.display())}`")
            await send_bird(
                ctx,
                database.hget(channel_key, "bird").decode("utf-8"),
                media_type,
                filters,
                on_error=on_error,
                message=message,
            )
            return

        # previous bird answered: give a new bird
        session_increment(ctx, "total", 1)
        logger.info(f"filters: {filters}; taxon: {taxon}; roles: {roles}")
        if not currently_in_race and retries == 0:
            await ctx.send(
                "**Recognized arguments:** "
                + f"*Active Filters*: `{'`, `'.join(filters.display())}`, "
                + f"*Taxons*: `{'None' if taxon_str == '' else taxon_str}`, "
                + f"*Detected State*: `{'None' if role_str == '' else role_str}`"
            )

        custom_roles = {role for role in roles if role.startswith("CUSTOM:")}
        if currently_in_race and len(custom_roles) == 1:
            # in races the custom list is given as "CUSTOM:<user id>"; build the
            # list for that user id instead of the command author
            custom_role = custom_roles.pop()
            roles.remove(custom_role)
            roles.append("CUSTOM")
            list_owner = custom_role.split(":")[1]
        else:
            list_owner = ctx.author.id
        birds = build_id_list(
            user_id=list_owner, taxon=taxon, state=roles, media_type=media_type
        )

        if not birds:
            logger.info("no birds for taxon/state")
            await ctx.send(
                "**Sorry, no birds could be found for the taxon/state combo.**\n*Please try again*"
            )
            return
        if len(birds) < 2:
            logger.info("list less than 2 items")
            await ctx.send(
                "**Sorry, you must have at least 2 birds in the taxon/state combo."
                + "**\n*Please try again with a different set of taxons/lists.*"
            )
            return

        prev_bird = database.hget(channel_key, "prevB").decode("utf-8")
        current_bird = self._pick_bird(birds, prev_bird)
        # count the bird that is actually sent (not a discarded pick)
        self.increment_bird_frequency(ctx, current_bird)
        database.hset(channel_key, "prevB", str(current_bird))
        database.hset(channel_key, "bird", str(current_bird))
        logger.info("currentBird: " + str(current_bird))
        database.hset(channel_key, "answered", "0")
        await send_bird(
            ctx,
            current_bird,
            media_type,
            filters,
            on_error=on_error,
            message=message,
        )

    @staticmethod
    async def _parse_inline_filters(ctx, args_str: str, **parse_kwargs) -> Filter:
        """Parse filters from args_str, disabling (and warning about) an inline VC filter."""
        filters = Filter.parse(args_str, **parse_kwargs)
        if filters.vc:
            filters.vc = False
            await ctx.send("**The VC filter is not allowed inline!**")
        return filters

    @staticmethod
    def _combine_with_stored_filter(filters: Filter, stored_filter: int) -> Filter:
        """Toggle the stored (session/race) filter int into the inline filters.

        When the stored filter uses the default quality and the user asked for
        a different quality inline, the defaults are XORed in first ("clear
        defaults"), as in the original implementation.
        """
        default_quality = Filter().quality
        if (
            Filter.from_int(stored_filter).quality == default_quality
            and filters.quality
            and filters.quality != default_quality
        ):
            filters ^= Filter()  # clear defaults
        filters ^= stored_filter
        return filters

    @staticmethod
    def _without_custom(state: str) -> str:
        """Remove the CUSTOM list (any letter case) from a space-separated state string."""
        return " ".join(item for item in state.split(" ") if item.upper() != "CUSTOM")

    @staticmethod
    async def _parse_race_options(ctx, args_str: str):
        """Return (filters, taxon, state) for a channel with an active race."""
        logger.info("race parameters")
        race_key = f"race.data:{ctx.channel.id}"
        race_filter = int(database.hget(race_key, "filter"))
        filters = await Birds._parse_inline_filters(ctx, args_str, defaults=False)
        filters = Birds._combine_with_stored_filter(filters, race_filter)
        taxon = database.hget(race_key, "taxon").decode("utf-8")
        state = database.hget(race_key, "state").decode("utf-8")
        return filters, taxon, state

    @staticmethod
    async def _parse_user_options(ctx, args_str: str, args):
        """Return (filters, taxon, state) outside of a race.

        Inline taxons/states toggle against the user's session settings when a
        session exists; otherwise they are used as given. A CUSTOM state is
        dropped (with a message) when the user's custom list is missing or
        still awaiting confirmation.
        """
        roles = check_state_role(ctx)
        taxon_args = set(taxons.keys()).intersection({arg.lower() for arg in args})
        taxon = " ".join(taxon_args).strip()
        state_args = set(states.keys()).intersection({arg.upper() for arg in args})

        session_key = f"session.data:{ctx.author.id}"
        if database.exists(session_key):
            logger.info("session parameters")
            if taxon_args:
                current_taxons = set(
                    database.hget(session_key, "taxon").decode("utf-8").split(" ")
                )
                logger.info(f"toggle taxons: {taxon_args}")
                logger.info(f"current taxons: {current_taxons}")
                taxon_args.symmetric_difference_update(current_taxons)
                taxon_args.discard("")
                logger.info(f"new taxons: {taxon_args}")
                taxon = " ".join(taxon_args).strip()
            else:
                taxon = database.hget(session_key, "taxon").decode("utf-8")
            roles = database.hget(session_key, "state").decode("utf-8").split(" ")
            if roles[0] == "":
                roles = []
            if not roles:
                logger.info("no session lists")
                roles = check_state_role(ctx)
            session_filter = int(database.hget(session_key, "filter"))
            filters = await Birds._parse_inline_filters(ctx, args_str, defaults=False)
            filters = Birds._combine_with_stored_filter(filters, session_filter)
        else:
            filters = await Birds._parse_inline_filters(ctx, args_str)

        if state_args:
            logger.info(f"toggle states: {state_args}")
            logger.info(f"current states: {roles}")
            state_args.symmetric_difference_update(set(roles))
            state_args.discard("")
            logger.info(f"new states: {state_args}")
            state = " ".join(state_args).strip()
        else:
            state = " ".join(roles).strip()

        if "CUSTOM" in state.upper().split(" "):
            if not database.exists(f"custom.list:{ctx.author.id}"):
                await ctx.send("**You don't have a custom list set!**")
                state = Birds._without_custom(state)
            elif database.exists(f"custom.confirm:{ctx.author.id}"):
                await ctx.send(
                    "**Please verify or confirm your custom list before using!**"
                )
                state = Birds._without_custom(state)
        return filters, taxon, state

    @staticmethod
    async def parse(ctx, args_str: str):
        """Parse arguments for options.

        :returns: (filters, taxon, state) where taxon and state are
            space-separated strings. During a race the race's settings are used.
        """
        args = args_str.split(" ")
        logger.info(f"args: {args}")
        if database.exists(f"race.data:{ctx.channel.id}"):
            filters, taxon, state = await Birds._parse_race_options(ctx, args_str)
        else:
            filters, taxon, state = await Birds._parse_user_options(
                ctx, args_str, args
            )
        logger.info(f"args: filters: {filters}; taxon: {taxon}; state: {state}")
        return (filters, taxon, state)

    async def _send_media_command(self, ctx, args_str: str, default_media: str):
        """Shared body of the bird/song commands; a running race overrides the media."""
        filters, taxon, state = await self.parse(ctx, args_str)
        media = default_media
        race_key = f"race.data:{ctx.channel.id}"
        if database.exists(race_key):
            media = database.hget(race_key, "media").decode("utf-8")
        await self.send_bird_(ctx, media, filters, taxon, state)

    # Bird command: sends a random bird image (or the race's media) to ID.
    # (comments instead of a docstring so discord.py help text is unchanged)
    @commands.hybrid_command(
        help="- Sends a random bird image for you to ID",
        aliases=["b"],
        usage="[filters] [order/family] [state]",
    )
    # 5 second cooldown
    @commands.check(CustomCooldown(5.0, bucket=commands.BucketType.channel))
    @app_commands.rename(args_str="options")
    @app_commands.describe(
        args_str="Macaulay Library filters, bird lists, or taxons. Multiple options can be used at once (even if it doesn't autocomplete)"
    )
    @app_commands.autocomplete(args_str=arg_autocomplete)
    async def bird(self, ctx: commands.Context, *, args_str: str = ""):
        logger.info("command: bird")
        await self._send_media_command(ctx, args_str, "images")

    # Song command: sends a random bird song (or the race's media) to ID.
    @commands.hybrid_command(
        help="- Sends a random bird song for you to ID",
        aliases=["s"],
        usage="[filters] [order/family] [state]",
    )
    # 5 second cooldown
    @commands.check(CustomCooldown(5.0, bucket=commands.BucketType.channel))
    @app_commands.rename(args_str="options")
    @app_commands.describe(
        args_str="Macaulay Library filters, bird lists, or taxons. Multiple options can be used at once (even if it doesn't autocomplete)"
    )
    @app_commands.autocomplete(args_str=arg_autocomplete)
    async def song(self, ctx: commands.Context, *, args_str: str = ""):
        logger.info("command: song")
        await self._send_media_command(ctx, args_str, "songs")

    # Goatsucker command (no args, just for fun): sends a random goatsucker
    # image with default filters. Disabled during races.
    @commands.hybrid_command(help="- Sends a random goatsucker to ID", aliases=["gs"])
    @commands.check(CustomCooldown(5.0, bucket=commands.BucketType.channel))
    async def goatsucker(self, ctx: commands.Context):
        logger.info("command: goatsucker")
        if database.exists(f"race.data:{ctx.channel.id}"):
            await ctx.send("This command is disabled during races.")
            return

        channel_key = f"channel:{ctx.channel.id}"
        answered = int(database.hget(channel_key, "answered"))
        if answered:  # previous bird answered: give a new goatsucker
            session_increment(ctx, "total", 1)
            database.hset(channel_key, "answered", "0")
            current_bird = random.choice(goatsuckers)
            self.increment_bird_frequency(ctx, current_bird)
            database.hset(channel_key, "bird", str(current_bird))
            logger.info("currentBird: " + str(current_bird))
        else:  # not answered yet: resend the current bird
            current_bird = database.hget(channel_key, "bird").decode("utf-8")

        await send_bird(
            ctx,
            current_bird,
            MediaType.IMAGE,
            Filter(),
            on_error=self.error_skip(ctx),
            message=GS_MESSAGE,
        )
async def setup(bot):
    """Extension entry point: register the Birds cog on ``bot``."""
    cog = Birds(bot)
    await bot.add_cog(cog)