From 5c7a7ce2ed819ed1bc4fafd64fb7ad4b14eb18d3 Mon Sep 17 00:00:00 2001 From: Jareth Gomes Date: Fri, 31 May 2024 10:19:15 -0500 Subject: [PATCH] Migrated invitational collection to ODM --- src/discord/globals.py | 3 +- src/discord/invitationals.py | 42 ++-- src/discord/reporter.py | 17 +- src/discord/staff/invitationals.py | 334 +++++++++++++---------------- src/mongo/models.py | 20 ++ src/mongo/mongo.py | 12 -- 6 files changed, 183 insertions(+), 245 deletions(-) diff --git a/src/discord/globals.py b/src/discord/globals.py index b4bd50b..9f77c3b 100644 --- a/src/discord/globals.py +++ b/src/discord/globals.py @@ -2,6 +2,7 @@ Holds global variables shared between cogs and variables that are initialized when the bot is first setup. """ +from src.mongo.models import Invitational ############## # CONSTANTS @@ -121,7 +122,7 @@ CENSOR = {} EVENT_INFO = [] PING_INFO = [] -INVITATIONAL_INFO = [] +INVITATIONAL_INFO: list[Invitational] = [] REPORTS = [] TAGS = [] CURRENT_WIKI_PAGE = None diff --git a/src/discord/invitationals.py b/src/discord/invitationals.py index 79c33ad..d62aa49 100644 --- a/src/discord/invitationals.py +++ b/src/discord/invitationals.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING import discord +from beanie import SortDirection +from beanie.odm.operators.update.array import Push from discord.ext import commands from env import env @@ -18,6 +20,7 @@ ROLE_AT, ROLE_GM, ) +from src.mongo.models import Invitational if TYPE_CHECKING: from bot import PiBot @@ -28,24 +31,6 @@ logger = logging.getLogger(__name__) -class Invitational: - official_name: str - voters: list - - def __init__(self, objects): - self._properties = objects - self.doc_id = objects.get("_id") - self.official_name = objects.get("official_name") - self.channel_name = objects.get("channel_name") - self.emoji = objects.get("emoji") - self.aliases = objects.get("aliases") - self.tourney_date = objects.get("tourney_date") - self.open_days = objects.get("open_days") - self.closed_days = objects.get("closed_days") - self.voters = objects.get("voters") - self.status = objects.get("status") - - class AllInvitationalsView(discord.ui.View): """ A view class for holding the button to toggle visibility of all invitationals for a user. @@ -144,7 +129,7 @@ async def callback(self, interaction: discord.Interaction): else: # This dropdown is being used for voting - need_to_update = [] + need_to_update: list[Invitational] = [] already_voted_for = [] for value in self.values: @@ -153,6 +138,8 @@ async def callback(self, interaction: discord.Interaction): self.invitationals, official_name=value, ) + if not invitational: + continue if member.id in invitational.voters: # This user has already voted for this invitational. already_voted_for.append(invitational) @@ -164,13 +151,9 @@ async def callback(self, interaction: discord.Interaction): # Update invitationals DB if len(need_to_update) > 0: # Some docs need to be updated - docs_to_update = [t._properties for t in need_to_update] - await self.bot.mongo_database.update_many( - "data", - "invitationals", - docs_to_update, - {"$push": {"voters": member.id}}, - ) + + for invy in need_to_update: + invy.update(Push({Invitational.voters: member.id})) # Format output result_string = "" @@ -201,9 +184,10 @@ async def update_invitational_list(bot: PiBot, rename_dict: dict = {}) -> None: :param rename_dict: A dictionary containing renames of channels and roles that need to be completed. """ # Fetch invitationals - invitationals = await bot.mongo_database.get_invitationals() - invitationals = [Invitational(t) for t in invitationals] - invitationals.sort(key=lambda t: t.official_name) + invitationals = await Invitational.find_all( + sort=[(Invitational.official_name, SortDirection.ASCENDING)], + ignore_cache=True, + ).to_list() # Update global invitational info global INVITATIONAL_INFO diff --git a/src/discord/reporter.py b/src/discord/reporter.py index 25bd7da..5d67cfa 100644 --- a/src/discord/reporter.py +++ b/src/discord/reporter.py @@ -14,7 +14,8 @@ from env import env from src.discord.globals import CHANNEL_CLOSED_REPORTS -from src.discord.invitationals import Invitational, update_invitational_list +from src.discord.invitationals import update_invitational_list +from src.mongo.models import Invitational if TYPE_CHECKING: from bot import PiBot @@ -207,12 +208,7 @@ async def callback(self, interaction: discord.Interaction): await interaction.message.delete() # Update the invitationals database - await self.report_view.bot.mongo_database.update( - "data", - "invitationals", - self.report_view.invitational_obj.doc_id, - {"$set": {"status": "archived"}}, - ) + self.report_view.invitational_obj.set({Invitational.status: "archived"}) # Send an informational message about the report being updated closed_reports = discord.utils.get( @@ -251,12 +247,7 @@ async def callback(self, interaction: discord.Interaction): await interaction.message.delete() # Update the invitationals database - await self.report_view.bot.mongo_database.update( - "data", - "invitationals", - self.report_view.invitational_obj.doc_id, - {"$inc": {"closed_days": 15}}, - ) + await self.report_view.invitational_obj.inc({Invitational.closed_days: 15}) # Send an informational message about the report being updated closed_reports = discord.utils.get( diff --git a/src/discord/staff/invitationals.py b/src/discord/staff/invitationals.py index 26f1491..ae78a88 100644 --- a/src/discord/staff/invitationals.py +++ b/src/discord/staff/invitationals.py @@ -1,11 +1,12 @@ from __future__ import annotations +import collections import datetime import re from typing import TYPE_CHECKING, Literal import discord -from discord import app_commands +from discord import Emoji, Guild, app_commands from discord.ext import commands import commandchecks @@ -13,12 +14,14 @@ from src.discord.globals import ( CATEGORY_ARCHIVE, CATEGORY_INVITATIONALS, + DISCORD_AUTOCOMPLETE_MAX_ENTRIES, EMOJI_LOADING, ROLE_STAFF, ROLE_VIP, ) from src.discord.invitationals import update_invitational_list from src.discord.views import YesNo +from src.mongo.models import Invitational if TYPE_CHECKING: from bot import PiBot @@ -59,16 +62,17 @@ async def invitational_add( commandchecks.is_staff_from_ctx(interaction) # Create invitational doc - new_tourney_doc = { - "official_name": official_name, - "channel_name": channel_name, - "tourney_date": datetime.datetime.strptime(tourney_date, "%Y-%m-%d"), - "aliases": [], - "open_days": 10, - "closed_days": 30, - "voters": [], - "status": "open" if status == "add_immediately" else "voting", - } + new_tourney_doc = Invitational( + official_name=official_name, + channel_name=channel_name, + tourney_date=datetime.datetime.strptime(tourney_date, "%Y-%m-%d"), + emoji=None, + aliases=[], + open_days=10, + closed_days=30, + voters=[], + status="open" if status == "add_immediately" else "voting", + ) # Send default message await interaction.response.send_message(f"{EMOJI_LOADING} Loading...") @@ -158,8 +162,8 @@ async def invitational_add( description = f""" **Official Name:** {official_name} **Channel Name:** `#{channel_name}` - **Tournament Date:** {discord.utils.format_dt(new_tourney_doc['tourney_date'], 'D')} - **Closes After:** {new_tourney_doc['closed_days']} days (the invitational channel is expected to close on {discord.utils.format_dt(new_tourney_doc['tourney_date'] + datetime.timedelta(days=new_tourney_doc['closed_days']), 'D')}) + **Tournament Date:** {discord.utils.format_dt(new_tourney_doc.tourney_date, 'D')} + **Closes After:** {new_tourney_doc.closed_days} days (the invitational channel is expected to close on {discord.utils.format_dt(new_tourney_doc.tourney_date + datetime.timedelta(days=new_tourney_doc.closed_days), 'D')}) **Emoji:** {emoji} """ @@ -194,12 +198,8 @@ async def invitational_add( await view.wait() if view.value: # Staff member responded with "Yes" - new_tourney_doc["emoji"] = str(emoji) - await self.bot.mongo_database.insert( - "data", - "invitationals", - new_tourney_doc, - ) + new_tourney_doc.emoji = str(emoji) + await new_tourney_doc.insert() await interaction.edit_original_response( content="The invitational was added successfully! The invitational list will now be refreshed.", embed=None, @@ -234,45 +234,34 @@ async def invitational_approve( f"{EMOJI_LOADING} Attempting to approve...", ) - invitationals = await self.bot.mongo_database.get_invitationals() - found_invitationals = [ - i for i in invitationals if i["channel_name"] == short_name - ] + invitational = await Invitational.find_one( + Invitational.channel_name == short_name, + ignore_cache=True, + ) # If invitational is not found - if len(found_invitationals) < 1: - await interaction.edit_original_response( + if not invitational: + return await interaction.edit_original_response( content=f"Sorry, I couldn't find an invitational with the short name of `{short_name}`.", ) # If an invitational is found - elif len(found_invitationals) == 1: - - # Check to see if invitational is already open - if found_invitationals[0]["status"] == "open": - await interaction.edit_original_response( - content=f"The `{short_name}` invitational is already open.", - ) - - # If not, update invitational to be open - await self.bot.mongo_database.update( - "data", - "invitationals", - found_invitationals[0]["_id"], - {"$set": {"status": "open"}}, - ) + # Check to see if invitational is already open + if invitational.status == "open": await interaction.edit_original_response( - content=f"The status of the `{short_name}` invitational was updated.", + content=f"The `{short_name}` invitational is already open.", ) - # Update invitational list - await update_invitational_list(self.bot, {}) + # If not, update invitational to be open + invitational.status = "open" + await invitational.save() - else: - await interaction.edit_original_response( - content="I found more than one invitational with a matching name. Contact an administrator - " - "something is wrong. ", - ) + await interaction.edit_original_response( + content=f"The status of the `{short_name}` invitational was updated.", + ) + + # Update invitational list + await update_invitational_list(self.bot, {}) @invitational_status_group.command( name="edit", @@ -303,10 +292,9 @@ async def invitational_edit( ) # Attempt to find invitational - invitationals = await self.bot.mongo_database.get_invitationals() - found_invitationals = [ - i for i in invitationals if i["channel_name"] == short_name - ] + found_invitationals = await Invitational.find( + Invitational.channel_name == short_name, + ).to_list() # If no invitational was found if len(found_invitationals) < 1: @@ -349,18 +337,14 @@ async def invitational_edit( # Make sure to rename the roles rename_dict = { "roles": { - invitational["official_name"]: content_message.content, + invitational.official_name: content_message.content, }, } # and update the DB - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"official_name": content_message.content}}, - ) + invitational.official_name = content_message.content + await invitational.save() await interaction.edit_original_response( - content=f"`{invitational['official_name']}` was renamed to **`{content_message.content}`**.", + content=f"`{invitational.official_name}` was renamed to **`{content_message.content}`**.", ) # If editing invitational's short name @@ -368,18 +352,14 @@ async def invitational_edit( # Make sure to rename the channel rename_dict = { "channels": { - invitational["channel_name"]: content_message.content, + invitational.channel_name: content_message.content, }, } # and update the DB - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"channel_name": content_message.content}}, - ) + invitational.channel_name = content_message.content + await invitational.save() await interaction.edit_original_response( - content=f"The channel for {invitational['official_name']} was renamed from `{invitational['channel_name']}` to **`{content_message.content}`**.", + content=f"The channel for {invitational.official_name} was renamed from `{invitational.channel_name}` to **`{content_message.content}`**.", ) # If editing invitational's emoji @@ -407,11 +387,11 @@ async def invitational_edit( # Create new emoji, delete old emoji created_emoji = False for guild_id in env.emoji_guilds: - guild = self.bot.get_guild(guild_id) + guild: Guild = self.bot.get_guild(guild_id) for emoji in guild.emojis: if ( emoji.name - == f"tournament_{invitational['channel_name']}" + == f"tournament_{invitational.channel_name}" ): await emoji.delete( reason=f"Replaced with alternate emoji by {interaction.user}.", @@ -422,7 +402,7 @@ async def invitational_edit( ): # The guild can fit more custom emojis emoji = await guild.create_custom_emoji( - name=f"tournament_{invitational['channel_name']}", + name=f"tournament_{invitational.channel_name}", image=await emoji_attachment.read(), reason=f"Created by {interaction.user}.", ) @@ -438,17 +418,16 @@ async def invitational_edit( else: emoji = content_message.content + if isinstance(emoji, Emoji): + invitational.emoji = str(emoji) + else: + invitational.emoji = emoji # Update the DB with info - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"emoji": emoji}}, - ) + await invitational.save() # Send confirmation message await interaction.edit_original_response( - content=f"The emoji for `{invitational['official_name']}` was updated to: {emoji}.", + content=f"The emoji for `{invitational.official_name}` was updated to: {emoji}.", ) # If editing the invitational date @@ -457,15 +436,11 @@ async def invitational_edit( date_str = content_message.content date_dt = datetime.datetime.strptime(date_str, "%Y-%m-%d") # and update DB - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"tourney_date": date_dt}}, - ) + invitational.tourney_date = date_dt + await invitational.save() # and send user confirmation await interaction.edit_original_response( - content=f"The tournament date for `{invitational['official_name']}` was updated to {discord.utils.format_dt(date_dt, 'D')}.", + content=f"The tournament date for `{invitational.official_name}` was updated to {discord.utils.format_dt(date_dt, 'D')}.", ) await update_invitational_list(self.bot, rename_dict) else: @@ -494,27 +469,21 @@ async def invitational_archive( content=f"{EMOJI_LOADING} Attempting to archive the `{short_name}` invitational...", ) - invitationals = await self.bot.mongo_database.get_invitationals() - found_invitationals = [ - i for i in invitationals if i["channel_name"] == short_name - ] - if not len(found_invitationals): - await interaction.edit_original_response( + invitational = await Invitational.find_one( + Invitational.channel_name == short_name, + ignore_cache=True, + ) + if not invitational: + return await interaction.edit_original_response( content=f"Sorry, I couldn't find an invitational with a short name of {short_name}.", ) - # Invitational was found - invitational = found_invitationals[0] - # Update the database and invitational list - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"status": "archived"}}, - ) + invitational.status = "archived" + await invitational.save() + await interaction.edit_original_response( - content=f"The **`{invitational['official_name']}`** is now being archived.", + content=f"The **`{invitational.official_name}`** is now being archived.", ) await update_invitational_list(self.bot, {}) @@ -540,61 +509,53 @@ async def invitational_delete( ) # Attempt to find invitational - invitationals = await self.bot.mongo_database.get_invitationals() - found_invitationals = [ - i for i in invitationals if i["channel_name"] == short_name - ] + invitational = await Invitational.find_one( + Invitational.channel_name == short_name, + ignore_cache=True, + ) - if not len(found_invitationals): - await interaction.edit_original_response( + if not invitational: + return await interaction.edit_original_response( content=f"Sorry, I couldn't find an invitational with a short name of {short_name}.", ) - else: - # Find the relevant invitational - invitational = found_invitationals[0] - - # Get the relevant channel and role - server = self.bot.get_guild(env.server_id) - ch = discord.utils.get( - server.text_channels, - name=invitational["channel_name"], - ) - r = discord.utils.get(server.roles, name=invitational["official_name"]) - - # Delete the channel and role - if ( - ch - and ch.category - and ch.category.name - in [ - CATEGORY_ARCHIVE, - CATEGORY_INVITATIONALS, - ] - ): - await ch.delete() - if r: - await r.delete() - - # Delete the invitational emoji - search = re.findall(r"<:.*:\d+>", invitational["emoji"]) - if len(search): - emoji = self.bot.get_emoji(search[0]) - if emoji: - await emoji.delete() - - # Delete from the DB - await self.bot.mongo_database.delete( - "data", - "invitationals", - invitational["_id"], - ) - await interaction.edit_original_response( - content=f"Deleted the **`{invitational['official_name']}`**.", - ) + # Get the relevant channel and role + server = self.bot.get_guild(env.server_id) + ch = discord.utils.get( + server.text_channels, + name=invitational.channel_name, + ) + r = discord.utils.get(server.roles, name=invitational.official_name) + + # Delete the channel and role + if ( + ch + and ch.category + and ch.category.name + in [ + CATEGORY_ARCHIVE, + CATEGORY_INVITATIONALS, + ] + ): + await ch.delete() + if r: + await r.delete() + + # Delete the invitational emoji + search = re.findall(r"<:.*:\d+>", invitational.emoji) + if len(search): + emoji = self.bot.get_emoji(search[0]) + if emoji: + await emoji.delete() + + # Delete from the DB + await invitational.delete() + await interaction.edit_original_response( + content=f"Deleted the **`{invitational.official_name}`**.", + ) - # Update the invitational list to reflect - await update_invitational_list(self.bot, {}) + # Update the invitational list to reflect + await update_invitational_list(self.bot, {}) @invitational_status_group.command( name="season", @@ -651,14 +612,12 @@ async def invitational_season(self, interaction: discord.Interaction): ) # Remove voters from all tourneys - invitationals = await self.bot.mongo_database.get_invitationals() - for invitational in invitationals: - await self.bot.mongo_database.update( - "data", - "tournaments", - invitational["_id"], - {"$set": {"voters": []}}, - ) + u = collections.UserDict( + [(Invitational.voters, [])], + ) # Needed to avoid using normal + # dict since `Invitational.voters` + # is a list and cannot be hashed + Invitational.update_all(u) # Update the invitational list to reflect await update_invitational_list(self.bot, {}) @@ -697,22 +656,18 @@ async def invitational_renew( content=f"{EMOJI_LOADING} Attempting to renew the `{short_name}` invitational...", ) - invitationals = await self.bot.mongo_database.get_invitationals() - found_invitationals = [ - i for i in invitationals if i["channel_name"] == short_name - ] + invitational = await Invitational.find_one( + Invitational.channel_name == short_name, + ignore_cache=True, + ) - if not len(found_invitationals): - await interaction.edit_original_response( + if not invitational: + return await interaction.edit_original_response( content=f"Sorry, I couldn't find an invitational with a short name of {short_name}.", ) - invitational = found_invitationals[0] - await self.bot.mongo_database.update( - "data", - "invitationals", - invitational["_id"], - {"$set": {"status": "voting" if voting == "yes" else "open"}}, + await invitational.set( + {Invitational.status: "voting" if voting == "yes" else "open"}, ) # Update the invitational list to reflect @@ -729,14 +684,14 @@ async def short_name_voting_autocomplete( interaction: discord.Interaction, current: str, ) -> list[discord.app_commands.Choice[str]]: - invitationals = await self.bot.mongo_database.get_invitationals() + invitationals = await Invitational.find_all().to_list() return [ discord.app_commands.Choice( - name=f"#{i['channel_name']} ({len(i['voters'])} voters)", - value=i["channel_name"], + name=f"#{i.channel_name} ({len(i.voters)} voters)", + value=i.channel_name, ) for i in invitationals - if current.lower() in i["channel_name"].lower() and i["status"] == "voting" + if current.lower() in i.channel_name.lower() and i.status == "voting" ][:25] @invitational_edit.autocomplete("short_name") @@ -746,14 +701,14 @@ async def short_name_autocomplete( interaction: discord.Interaction, current: str, ) -> list[discord.app_commands.Choice[str]]: - invitationals = await self.bot.mongo_database.get_invitationals() + invitationals = await Invitational.find_all().to_list() return [ discord.app_commands.Choice( - name=f"#{i['channel_name']}", - value=i["channel_name"], + name=f"#{i.channel_name}", + value=i.channel_name, ) for i in invitationals - if current.lower() in i["channel_name"].lower() + if current.lower() in i.channel_name.lower() ][:25] @invitational_archive.autocomplete("short_name") @@ -762,31 +717,30 @@ async def short_name_archive_autocomplete( interaction: discord.Interaction, current: str, ) -> list[discord.app_commands.Choice[str]]: - invitationals = await self.bot.mongo_database.get_invitationals() + invitationals = await Invitational.find_all().to_list() return [ discord.app_commands.Choice( - name=f"#{i['channel_name']}", - value=i["channel_name"], + name=f"#{i.channel_name}", + value=i.channel_name, ) for i in invitationals - if current.lower() in i["channel_name"].lower() and i["status"] == "open" - ][:25] + if current.lower() in i.channel_name.lower() and i.status == "open" + ][:DISCORD_AUTOCOMPLETE_MAX_ENTRIES] @invitational_renew.autocomplete("short_name") async def short_name_renew_autocomplete( self, - interaction: discord.Interaction, + _interaction: discord.Interaction, current: str, ) -> list[discord.app_commands.Choice[str]]: - invitationals = await self.bot.mongo_database.get_invitationals() + invitationals = await Invitational.find_all().to_list() return [ discord.app_commands.Choice( - name=f"#{i['channel_name']}", - value=i["channel_name"], + name=f"#{i.channel_name}", + value=i.channel_name, ) for i in invitationals - if current.lower() in i["channel_name"].lower() - and i["status"] == "archived" + if current.lower() in i.channel_name.lower() and i.status == "archived" ][:25] diff --git a/src/mongo/models.py b/src/mongo/models.py index 64cd474..e3e0c04 100644 --- a/src/mongo/models.py +++ b/src/mongo/models.py @@ -1,3 +1,23 @@ """ Contains all database models """ +from datetime import datetime +from typing import Literal + +from beanie import Document + + +class Invitational(Document): + official_name: str + channel_name: str + emoji: str | None + aliases: list[str] + tourney_date: datetime + open_days: int + closed_days: int + voters: list[int] # FIXME: is this a list of ids? if so, are they str or ints? + status: Literal["voting", "open", "archived"] # FIX: Are there any more statuses? + + class Settings: + name = "invitationals" + use_cache = True diff --git a/src/mongo/mongo.py b/src/mongo/mongo.py index eef3f2e..0f05de6 100644 --- a/src/mongo/mongo.py +++ b/src/mongo/mongo.py @@ -83,18 +83,6 @@ async def get_entire_collection( result.append(doc) return result - async def get_invitationals(self): - """ - Gets all documents in the invitationals collection. - """ - return await self.get_entire_collection("data", "invitationals") - - async def get_cron(self): - """ - Gets all documents in the CRON collection. - """ - return await self.get_entire_collection("data", "cron") - async def get_censor(self): """ Gets the document containing censor information from the censor collection.