From 3c643691a9cc893d07a988a5fc03ead8074d0bbc Mon Sep 17 00:00:00 2001 From: Hung Pham Date: Wed, 6 Mar 2024 09:54:28 +1100 Subject: [PATCH] feat: clean up field casting using Pydantic + other QoL (#15) --- .env.example | 5 ++++- cogs/puzzle.py | 17 ++++++++--------- cogs/team.py | 32 +++++++++++++++----------------- main.py | 14 ++++++++++---- requirements.txt | 3 +++ src/context/puzzle.py | 2 +- src/models/player.py | 7 +++---- src/models/team.py | 13 ++++++------- src/queries/player.py | 12 ++++++------ src/queries/submission.py | 6 +++--- src/queries/team.py | 18 +++++++++--------- src/utils/decorators.py | 4 ++-- 12 files changed, 70 insertions(+), 63 deletions(-) diff --git a/.env.example b/.env.example index 144942d..c8b6a94 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,5 @@ DISCORD_TOKEN=insert_bot_token_here -DATABASE_URL=postgres://postgres:password@127.0.0.1:5432/puzzlehunt_bot \ No newline at end of file +DATABASE_URL=postgres://postgres:password@127.0.0.1:5432/puzzlehunt_bot +ADMIN_CHANNEL_ID= +VICTOR_ROLE_ID= +VICTOR_TEXT_CHANNEL_ID= \ No newline at end of file diff --git a/cogs/puzzle.py b/cogs/puzzle.py index 31359e8..a61eadb 100644 --- a/cogs/puzzle.py +++ b/cogs/puzzle.py @@ -1,7 +1,6 @@ from src.config import config from datetime import datetime -from typing import Literal from zoneinfo import ZoneInfo import discord from discord.ext import commands @@ -10,7 +9,7 @@ from src.queries.puzzle import get_puzzle, get_completed_puzzles from src.queries.submission import ( create_submission, - find_submissions_by_player_id_and_puzzle_id, + find_submissions_by_discord_id_and_puzzle_id, ) from src.queries.player import get_player from src.queries.team import get_team, get_team_members @@ -34,12 +33,12 @@ async def submit_answer( puzzle_id = puzzle_id.upper() puzzle = await get_puzzle(puzzle_id) - player = await get_player(str(interaction.user.id)) + player = await get_player(interaction.user.id) if not puzzle or not await can_access_puzzle(puzzle, player.team_name): return await interaction.followup.send( "No puzzle with the corresponding ID exists!" ) - submissions = await find_submissions_by_player_id_and_puzzle_id( + submissions = await find_submissions_by_discord_id_and_puzzle_id( player.discord_id, puzzle_id ) @@ -91,7 +90,7 @@ async def submit_answer( team_members = await get_team_members(player.team_name) for team_member in team_members: - discord_member = guild.get_member(int(team_member.discord_id)) + discord_member = guild.get_member(team_member.discord_id) await discord_member.add_roles(guild.get_role(config["VICTOR_ROLE_ID"])) await interaction.followup.send( @@ -105,15 +104,15 @@ async def submit_answer( @in_team_channel async def list_puzzles(self, interaction: discord.Interaction): await interaction.response.defer() - player = await get_player(str(interaction.user.id)) + player = await get_player(interaction.user.id) puzzles = await get_accessible_puzzles(player.team_name) embed = discord.Embed(title="Current Puzzles", color=discord.Color.greyple()) puzzle_ids = [] puzzle_name_links = [] for puzzle in puzzles: - submissions = await find_submissions_by_player_id_and_puzzle_id( - str(player.discord_id), puzzle.puzzle_id + submissions = await find_submissions_by_discord_id_and_puzzle_id( + player.discord_id, puzzle.puzzle_id ) if any([submission.submission_is_correct for submission in submissions]): @@ -131,7 +130,7 @@ async def list_puzzles(self, interaction: discord.Interaction): @in_team_channel async def hint(self, interaction: discord.Interaction): await interaction.response.defer() - team = await get_player(str(interaction.user.id)) + team = await get_player(interaction.user.id) await interaction.client.get_channel(config["ADMIN_CHANNEL_ID"]).send( f"Hint request submitted from team {team.team_name}! {interaction.channel.mention}" ) diff --git a/cogs/team.py b/cogs/team.py index 7873487..edca420 100644 --- a/cogs/team.py +++ b/cogs/team.py @@ -1,7 +1,6 @@ import discord from discord.ext import commands from discord import app_commands -from discord import Guild import src.queries.team as team_query import src.queries.player as player_query @@ -23,7 +22,7 @@ async def create_team(self, interaction: discord.Interaction, team_name: str): await interaction.response.defer(ephemeral=True) - if await player_query.get_player(str(user.id)): + if await player_query.get_player(user.id): await interaction.followup.send( "You are already in a team. Please leave the team before creating a new one.", ephemeral=True, @@ -62,14 +61,14 @@ async def create_team(self, interaction: discord.Interaction, team_name: str): # create team in database await team_query.create_team( team_name, - str(category.id), - str(voice_channel.id), - str(text_channel.id), - str(team_role.id), + category.id, + voice_channel.id, + text_channel.id, + team_role.id, ) # add player to database - await player_query.add_player(str(user.id), team_name) + await player_query.add_player(user.id, team_name) # give role to user await user.add_roles(team_role) @@ -88,7 +87,7 @@ async def leave_team(self, interaction: discord.Interaction): # get the team name from the user discord_id = user.id - player = await player_query.get_player(str(discord_id)) + player = await player_query.get_player(discord_id) if not player: await interaction.followup.send( @@ -99,13 +98,12 @@ async def leave_team(self, interaction: discord.Interaction): team_name = player.team_name team = await team_query.get_team(team_name) - team_role_id = int(team.team_role_id) - role = guild.get_role(team_role_id) + role = guild.get_role(team.team_role_id) await user.remove_roles(role) # delete player - await player_query.remove_player(str(discord_id)) + await player_query.remove_player(discord_id) # check amount of people still in team # if none, delete team and respective channels @@ -121,9 +119,9 @@ async def leave_team(self, interaction: discord.Interaction): ) # if here, then there are no members remaining in the teams - text_channel = guild.get_channel(int(team.text_channel_id)) - voice_channel = guild.get_channel(int(team.voice_channel_id)) - category_channel = guild.get_channel(int(team.category_channel_id)) + text_channel = guild.get_channel(team.text_channel_id) + voice_channel = guild.get_channel(team.voice_channel_id) + category_channel = guild.get_channel(team.category_channel_id) await text_channel.delete() await voice_channel.delete() @@ -145,7 +143,7 @@ async def invite( await interaction.response.defer(ephemeral=True) # check user is already in a team - player = await player_query.get_player(str(user.id)) + player = await player_query.get_player(user.id) if not player: await interaction.followup.send( "You must be in a team to use this command.", ephemeral=True @@ -155,7 +153,7 @@ async def invite( team_name = player.team_name # check invited user is not already in a team - if await player_query.get_player(str(invited_user.id)): + if await player_query.get_player(invited_user.id): await interaction.followup.send( "The user you're trying to invite is already in a team.", ephemeral=True ) @@ -218,7 +216,7 @@ async def accept_callback(interaction: discord.Interaction): # add new user to team await invited_user.add_roles(guild.get_role(int(team.team_role_id))) - await player_query.add_player(str(new_player.id), team_name) + await player_query.add_player(new_player.id, team_name) # edit message so that user can't click again accept_embed = discord.Embed( diff --git a/main.py b/main.py index 8775ac4..078415f 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,17 @@ bot = commands.Bot(command_prefix="!", help_command=None, intents=intents) +async def load_cogs(): + for filename in os.listdir("cogs/"): + if filename.endswith(".py"): + await bot.load_extension(f"{COGS_DIR}.{filename[:-3]}") + print(f"Loaded {filename}") + + @bot.event async def on_ready(): + await load_cogs() + await bot.tree.sync() print(f"Connected as {bot.user}.") @@ -21,10 +30,7 @@ async def on_ready(): @bot.command() @commands.has_role(EXEC_ID) async def startup(ctx: commands.context.Context): - for filename in os.listdir("cogs/"): - if filename.endswith(".py"): - await bot.load_extension(f"{COGS_DIR}.{filename[:-3]}") - print(f"Loaded {filename}") + await load_cogs() await ctx.send(f"Loaded all cogs") diff --git a/requirements.txt b/requirements.txt index bbf1d5d..38ddabe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiohttp==3.9.3 aiosignal==1.3.1 +annotated-types==0.6.0 async-timeout==4.0.3 attrs==23.2.0 discord==2.3.2 @@ -13,6 +14,8 @@ packaging==23.2 pluggy==1.4.0 psycopg==3.1.18 psycopg-binary==3.1.18 +pydantic==2.6.3 +pydantic_core==2.16.3 pytest==8.0.2 pytest-asyncio==0.23.5 python-dotenv==1.0.1 diff --git a/src/context/puzzle.py b/src/context/puzzle.py index b911b6b..77da402 100644 --- a/src/context/puzzle.py +++ b/src/context/puzzle.py @@ -1,7 +1,7 @@ from typing import List from src.models.puzzle import Puzzle from src.queries.puzzle import get_puzzles, get_completed_puzzles -from src.queries.player import get_player + NUMBER_OF_FEEDERS = {"UTS": 4, "UNSW": 4, "USYD": 6} diff --git a/src/models/player.py b/src/models/player.py index 1e9d58b..f4f4f60 100644 --- a/src/models/player.py +++ b/src/models/player.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class Player: - discord_id: str +class Player(BaseModel): + discord_id: int team_name: str diff --git a/src/models/team.py b/src/models/team.py index 7b8430b..9b47782 100644 --- a/src/models/team.py +++ b/src/models/team.py @@ -1,11 +1,10 @@ -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class Team: +class Team(BaseModel): team_name: str - category_channel_id: str - voice_channel_id: str - text_channel_id: str - team_role_id: str + category_channel_id: int + voice_channel_id: int + text_channel_id: int + team_role_id: int puzzle_solved: int diff --git a/src/queries/player.py b/src/queries/player.py index be80d8f..a46246d 100644 --- a/src/queries/player.py +++ b/src/queries/player.py @@ -7,7 +7,7 @@ DATABASE_URL = config["DATABASE_URL"] -async def get_player(discord_id: str): +async def get_player(discord_id: int): aconn = await psycopg.AsyncConnection.connect(DATABASE_URL) acur = aconn.cursor(row_factory=class_row(Player)) @@ -16,7 +16,7 @@ async def get_player(discord_id: str): SELECT * FROM public.players AS p WHERE p.discord_id = %s """, - (discord_id,), + (str(discord_id),), ) player = await acur.fetchone() @@ -27,7 +27,7 @@ async def get_player(discord_id: str): return player -async def add_player(discord_id: str, team_name: str): +async def add_player(discord_id: int, team_name: str): aconn = await psycopg.AsyncConnection.connect(DATABASE_URL) acur = aconn.cursor() @@ -37,7 +37,7 @@ async def add_player(discord_id: str, team_name: str): (discord_id, team_name) VALUES (%s, %s) """, - (discord_id, team_name), + (str(discord_id), team_name), ) await aconn.commit() @@ -46,7 +46,7 @@ async def add_player(discord_id: str, team_name: str): await aconn.close() -async def remove_player(discord_id: str): +async def remove_player(discord_id: int): aconn = await psycopg.AsyncConnection.connect(DATABASE_URL) acur = aconn.cursor() @@ -55,7 +55,7 @@ async def remove_player(discord_id: str): DELETE FROM public.players AS p WHERE p.discord_id = %s """, - (discord_id,), + (str(discord_id),), ) await aconn.commit() diff --git a/src/queries/submission.py b/src/queries/submission.py index a13d3f9..7784946 100644 --- a/src/queries/submission.py +++ b/src/queries/submission.py @@ -62,8 +62,8 @@ async def find_submissions_by_team(team_name: str): return await acur.fetchall() -async def find_submissions_by_player_id_and_puzzle_id( - player_id: str, puzzle_id: str +async def find_submissions_by_discord_id_and_puzzle_id( + discord_id: int, puzzle_id: str ) -> List[Submission]: async with await psycopg.AsyncConnection.connect(DATABASE_URL) as aconn: async with aconn.cursor(row_factory=class_row(Submission)) as acur: @@ -75,7 +75,7 @@ async def find_submissions_by_player_id_and_puzzle_id( INNER JOIN public.players AS p ON p.team_name = t.team_name WHERE p.discord_id = %s AND s.puzzle_id = %s """, - (player_id, puzzle_id), + (str(discord_id), puzzle_id), ) return await acur.fetchall() diff --git a/src/queries/team.py b/src/queries/team.py index c6ae952..bd6cd06 100644 --- a/src/queries/team.py +++ b/src/queries/team.py @@ -1,6 +1,6 @@ import psycopg from psycopg.rows import class_row -from datetime import datetime +from src.models.puzzle import Puzzle from src.config import config from src.models.team import Team @@ -53,10 +53,10 @@ async def get_team_members(team_name: str): async def create_team( team_name: str, - category_channel_id: str, - voice_channel_id: str, - text_channel_id: str, - team_role_id: str, + category_channel_id: int, + voice_channel_id: int, + text_channel_id: int, + team_role_id: int, ): aconn = await psycopg.AsyncConnection.connect(DATABASE_URL) acur = aconn.cursor() @@ -69,10 +69,10 @@ async def create_team( """, ( team_name, - category_channel_id, - voice_channel_id, - text_channel_id, - team_role_id, + str(category_channel_id), + str(voice_channel_id), + str(text_channel_id), + str(team_role_id), ), ) diff --git a/src/utils/decorators.py b/src/utils/decorators.py index 8af7323..6d251b8 100644 --- a/src/utils/decorators.py +++ b/src/utils/decorators.py @@ -15,7 +15,7 @@ async def wrapper(self, *args, **kwargs): if not isinstance(user, discord.Member): return await interaction.response.send_message("Something went wrong!") channel = interaction.channel - player = await get_player(str(user.id)) + player = await get_player(user.id) if not player: return await interaction.response.send_message( @@ -24,7 +24,7 @@ async def wrapper(self, *args, **kwargs): team = await get_team(player.team_name) - if team.text_channel_id != str(channel.id): + if team.text_channel_id != channel.id: return await interaction.response.send_message( "You can only use this command in your team's channel!" )