Skip to content

Commit

Permalink
feat: clean up field casting using Pydantic + other QoL (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkellyBG authored Mar 5, 2024
1 parent 310f9a7 commit 3c64369
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 63 deletions.
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
DISCORD_TOKEN=insert_bot_token_here
DATABASE_URL=postgres://postgres:password@127.0.0.1:5432/puzzlehunt_bot
DATABASE_URL=postgres://postgres:password@127.0.0.1:5432/puzzlehunt_bot
ADMIN_CHANNEL_ID=
VICTOR_ROLE_ID=
VICTOR_TEXT_CHANNEL_ID=
17 changes: 8 additions & 9 deletions cogs/puzzle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from src.config import config

from datetime import datetime
from typing import Literal
from zoneinfo import ZoneInfo
import discord
from discord.ext import commands
Expand All @@ -10,7 +9,7 @@
from src.queries.puzzle import get_puzzle, get_completed_puzzles
from src.queries.submission import (
create_submission,
find_submissions_by_player_id_and_puzzle_id,
find_submissions_by_discord_id_and_puzzle_id,
)
from src.queries.player import get_player
from src.queries.team import get_team, get_team_members
Expand All @@ -34,12 +33,12 @@ async def submit_answer(

puzzle_id = puzzle_id.upper()
puzzle = await get_puzzle(puzzle_id)
player = await get_player(str(interaction.user.id))
player = await get_player(interaction.user.id)
if not puzzle or not await can_access_puzzle(puzzle, player.team_name):
return await interaction.followup.send(
"No puzzle with the corresponding ID exists!"
)
submissions = await find_submissions_by_player_id_and_puzzle_id(
submissions = await find_submissions_by_discord_id_and_puzzle_id(
player.discord_id, puzzle_id
)

Expand Down Expand Up @@ -91,7 +90,7 @@ async def submit_answer(
team_members = await get_team_members(player.team_name)

for team_member in team_members:
discord_member = guild.get_member(int(team_member.discord_id))
discord_member = guild.get_member(team_member.discord_id)
await discord_member.add_roles(guild.get_role(config["VICTOR_ROLE_ID"]))

await interaction.followup.send(
Expand All @@ -105,15 +104,15 @@ async def submit_answer(
@in_team_channel
async def list_puzzles(self, interaction: discord.Interaction):
await interaction.response.defer()
player = await get_player(str(interaction.user.id))
player = await get_player(interaction.user.id)
puzzles = await get_accessible_puzzles(player.team_name)
embed = discord.Embed(title="Current Puzzles", color=discord.Color.greyple())

puzzle_ids = []
puzzle_name_links = []
for puzzle in puzzles:
submissions = await find_submissions_by_player_id_and_puzzle_id(
str(player.discord_id), puzzle.puzzle_id
submissions = await find_submissions_by_discord_id_and_puzzle_id(
player.discord_id, puzzle.puzzle_id
)

if any([submission.submission_is_correct for submission in submissions]):
Expand All @@ -131,7 +130,7 @@ async def list_puzzles(self, interaction: discord.Interaction):
@in_team_channel
async def hint(self, interaction: discord.Interaction):
await interaction.response.defer()
team = await get_player(str(interaction.user.id))
team = await get_player(interaction.user.id)
await interaction.client.get_channel(config["ADMIN_CHANNEL_ID"]).send(
f"Hint request submitted from team {team.team_name}! {interaction.channel.mention}"
)
Expand Down
32 changes: 15 additions & 17 deletions cogs/team.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import discord
from discord.ext import commands
from discord import app_commands
from discord import Guild

import src.queries.team as team_query
import src.queries.player as player_query
Expand All @@ -23,7 +22,7 @@ async def create_team(self, interaction: discord.Interaction, team_name: str):

await interaction.response.defer(ephemeral=True)

if await player_query.get_player(str(user.id)):
if await player_query.get_player(user.id):
await interaction.followup.send(
"You are already in a team. Please leave the team before creating a new one.",
ephemeral=True,
Expand Down Expand Up @@ -62,14 +61,14 @@ async def create_team(self, interaction: discord.Interaction, team_name: str):
# create team in database
await team_query.create_team(
team_name,
str(category.id),
str(voice_channel.id),
str(text_channel.id),
str(team_role.id),
category.id,
voice_channel.id,
text_channel.id,
team_role.id,
)

# add player to database
await player_query.add_player(str(user.id), team_name)
await player_query.add_player(user.id, team_name)

# give role to user
await user.add_roles(team_role)
Expand All @@ -88,7 +87,7 @@ async def leave_team(self, interaction: discord.Interaction):

# get the team name from the user
discord_id = user.id
player = await player_query.get_player(str(discord_id))
player = await player_query.get_player(discord_id)

if not player:
await interaction.followup.send(
Expand All @@ -99,13 +98,12 @@ async def leave_team(self, interaction: discord.Interaction):
team_name = player.team_name

team = await team_query.get_team(team_name)
team_role_id = int(team.team_role_id)
role = guild.get_role(team_role_id)
role = guild.get_role(team.team_role_id)

await user.remove_roles(role)

# delete player
await player_query.remove_player(str(discord_id))
await player_query.remove_player(discord_id)

# check amount of people still in team
# if none, delete team and respective channels
Expand All @@ -121,9 +119,9 @@ async def leave_team(self, interaction: discord.Interaction):
)

# if here, then there are no members remaining in the teams
text_channel = guild.get_channel(int(team.text_channel_id))
voice_channel = guild.get_channel(int(team.voice_channel_id))
category_channel = guild.get_channel(int(team.category_channel_id))
text_channel = guild.get_channel(team.text_channel_id)
voice_channel = guild.get_channel(team.voice_channel_id)
category_channel = guild.get_channel(team.category_channel_id)

await text_channel.delete()
await voice_channel.delete()
Expand All @@ -145,7 +143,7 @@ async def invite(
await interaction.response.defer(ephemeral=True)

# check user is already in a team
player = await player_query.get_player(str(user.id))
player = await player_query.get_player(user.id)
if not player:
await interaction.followup.send(
"You must be in a team to use this command.", ephemeral=True
Expand All @@ -155,7 +153,7 @@ async def invite(
team_name = player.team_name

# check invited user is not already in a team
if await player_query.get_player(str(invited_user.id)):
if await player_query.get_player(invited_user.id):
await interaction.followup.send(
"The user you're trying to invite is already in a team.", ephemeral=True
)
Expand Down Expand Up @@ -218,7 +216,7 @@ async def accept_callback(interaction: discord.Interaction):

# add new user to team
await invited_user.add_roles(guild.get_role(int(team.team_role_id)))
await player_query.add_player(str(new_player.id), team_name)
await player_query.add_player(new_player.id, team_name)

# edit message so that user can't click again
accept_embed = discord.Embed(
Expand Down
14 changes: 10 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
bot = commands.Bot(command_prefix="!", help_command=None, intents=intents)


async def load_cogs():
for filename in os.listdir("cogs/"):
if filename.endswith(".py"):
await bot.load_extension(f"{COGS_DIR}.{filename[:-3]}")
print(f"Loaded {filename}")


@bot.event
async def on_ready():
await load_cogs()
await bot.tree.sync()
print(f"Connected as {bot.user}.")


# load all available cogs on startup
@bot.command()
@commands.has_role(EXEC_ID)
async def startup(ctx: commands.context.Context):
for filename in os.listdir("cogs/"):
if filename.endswith(".py"):
await bot.load_extension(f"{COGS_DIR}.{filename[:-3]}")
print(f"Loaded {filename}")
await load_cogs()
await ctx.send(f"Loaded all cogs")


Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
async-timeout==4.0.3
attrs==23.2.0
discord==2.3.2
Expand All @@ -13,6 +14,8 @@ packaging==23.2
pluggy==1.4.0
psycopg==3.1.18
psycopg-binary==3.1.18
pydantic==2.6.3
pydantic_core==2.16.3
pytest==8.0.2
pytest-asyncio==0.23.5
python-dotenv==1.0.1
Expand Down
2 changes: 1 addition & 1 deletion src/context/puzzle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List
from src.models.puzzle import Puzzle
from src.queries.puzzle import get_puzzles, get_completed_puzzles
from src.queries.player import get_player


NUMBER_OF_FEEDERS = {"UTS": 4, "UNSW": 4, "USYD": 6}

Expand Down
7 changes: 3 additions & 4 deletions src/models/player.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from pydantic import BaseModel


@dataclass
class Player:
discord_id: str
class Player(BaseModel):
discord_id: int
team_name: str
13 changes: 6 additions & 7 deletions src/models/team.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dataclasses import dataclass
from pydantic import BaseModel


@dataclass
class Team:
class Team(BaseModel):
team_name: str
category_channel_id: str
voice_channel_id: str
text_channel_id: str
team_role_id: str
category_channel_id: int
voice_channel_id: int
text_channel_id: int
team_role_id: int
puzzle_solved: int
12 changes: 6 additions & 6 deletions src/queries/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
DATABASE_URL = config["DATABASE_URL"]


async def get_player(discord_id: str):
async def get_player(discord_id: int):
aconn = await psycopg.AsyncConnection.connect(DATABASE_URL)
acur = aconn.cursor(row_factory=class_row(Player))

Expand All @@ -16,7 +16,7 @@ async def get_player(discord_id: str):
SELECT * FROM public.players AS p
WHERE p.discord_id = %s
""",
(discord_id,),
(str(discord_id),),
)

player = await acur.fetchone()
Expand All @@ -27,7 +27,7 @@ async def get_player(discord_id: str):
return player


async def add_player(discord_id: str, team_name: str):
async def add_player(discord_id: int, team_name: str):
aconn = await psycopg.AsyncConnection.connect(DATABASE_URL)
acur = aconn.cursor()

Expand All @@ -37,7 +37,7 @@ async def add_player(discord_id: str, team_name: str):
(discord_id, team_name)
VALUES (%s, %s)
""",
(discord_id, team_name),
(str(discord_id), team_name),
)

await aconn.commit()
Expand All @@ -46,7 +46,7 @@ async def add_player(discord_id: str, team_name: str):
await aconn.close()


async def remove_player(discord_id: str):
async def remove_player(discord_id: int):
aconn = await psycopg.AsyncConnection.connect(DATABASE_URL)
acur = aconn.cursor()

Expand All @@ -55,7 +55,7 @@ async def remove_player(discord_id: str):
DELETE FROM public.players AS p
WHERE p.discord_id = %s
""",
(discord_id,),
(str(discord_id),),
)

await aconn.commit()
Expand Down
6 changes: 3 additions & 3 deletions src/queries/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ async def find_submissions_by_team(team_name: str):
return await acur.fetchall()


async def find_submissions_by_player_id_and_puzzle_id(
player_id: str, puzzle_id: str
async def find_submissions_by_discord_id_and_puzzle_id(
discord_id: int, puzzle_id: str
) -> List[Submission]:
async with await psycopg.AsyncConnection.connect(DATABASE_URL) as aconn:
async with aconn.cursor(row_factory=class_row(Submission)) as acur:
Expand All @@ -75,7 +75,7 @@ async def find_submissions_by_player_id_and_puzzle_id(
INNER JOIN public.players AS p ON p.team_name = t.team_name
WHERE p.discord_id = %s AND s.puzzle_id = %s
""",
(player_id, puzzle_id),
(str(discord_id), puzzle_id),
)

return await acur.fetchall()
18 changes: 9 additions & 9 deletions src/queries/team.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import psycopg
from psycopg.rows import class_row
from datetime import datetime
from src.models.puzzle import Puzzle

from src.config import config
from src.models.team import Team
Expand Down Expand Up @@ -53,10 +53,10 @@ async def get_team_members(team_name: str):

async def create_team(
team_name: str,
category_channel_id: str,
voice_channel_id: str,
text_channel_id: str,
team_role_id: str,
category_channel_id: int,
voice_channel_id: int,
text_channel_id: int,
team_role_id: int,
):
aconn = await psycopg.AsyncConnection.connect(DATABASE_URL)
acur = aconn.cursor()
Expand All @@ -69,10 +69,10 @@ async def create_team(
""",
(
team_name,
category_channel_id,
voice_channel_id,
text_channel_id,
team_role_id,
str(category_channel_id),
str(voice_channel_id),
str(text_channel_id),
str(team_role_id),
),
)

Expand Down
4 changes: 2 additions & 2 deletions src/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def wrapper(self, *args, **kwargs):
if not isinstance(user, discord.Member):
return await interaction.response.send_message("Something went wrong!")
channel = interaction.channel
player = await get_player(str(user.id))
player = await get_player(user.id)

if not player:
return await interaction.response.send_message(
Expand All @@ -24,7 +24,7 @@ async def wrapper(self, *args, **kwargs):

team = await get_team(player.team_name)

if team.text_channel_id != str(channel.id):
if team.text_channel_id != channel.id:
return await interaction.response.send_message(
"You can only use this command in your team's channel!"
)
Expand Down

0 comments on commit 3c64369

Please sign in to comment.