Skip to content

Commit

Permalink
Site: Refactor discord member and fix cache (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorNelson authored Oct 18, 2024
1 parent c50f759 commit 6f58580
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 69 deletions.
52 changes: 18 additions & 34 deletions dojo_plugin/api/v1/discord.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import hmac
import datetime
import json
from datetime import datetime, date
from itertools import islice

from flask import request
from flask_restx import Namespace, Resource
from sqlalchemy import and_
from CTFd.cache import cache
from CTFd.models import db
from CTFd.utils.decorators import authed_only
from CTFd.utils.user import get_current_user

from ...config import DISCORD_CLIENT_SECRET
from ...models import DiscordUsers, DiscordUserActivity
from ...utils.discord import get_discord_member, get_discord_member_by_discord_id
from ...utils.discord import get_discord_member
from ...utils.dojo import get_current_dojo_challenge

discord_namespace = Namespace("discord", description="Endpoint to manage discord")
Expand All @@ -32,9 +30,10 @@ class Discord(Resource):
@authed_only
def delete(self):
user = get_current_user()
DiscordUsers.query.filter_by(user=user).delete()
db.session.commit()
cache.delete_memoized(get_discord_member, user.id)
discord_user = DiscordUsers.query.filter_by(user=user).first()
if discord_user:
db.session.delete(discord_user)
db.session.commit()
return {"success": True}


Expand Down Expand Up @@ -129,7 +128,7 @@ def post_user_activity(discord_id, activity, request):
'channel_id': data.get("channel_id"),
'message_id': data.get("message_id"),
'timestamp': data.get("timestamp"),
'message_timestamp': datetime.datetime.fromisoformat(data.get("message_timestamp")),
'message_timestamp': datetime.fromisoformat(data.get("message_timestamp")),
'type': activity
}
entry = DiscordUserActivity(**kwargs)
Expand Down Expand Up @@ -158,37 +157,22 @@ def post(self, discord_id):
@discord_namespace.route("/thanks/leaderboard", methods=["GET"])
class GetDiscordLeaderBoard(Resource):
def get(self):
start_stamp = request.args.get("start")

def year_stamp():
year = datetime.datetime.now().year
return datetime.datetime(year, 1, 1)

try:
if start_stamp is None:
start = year_stamp()
else:
start = datetime.datetime.fromisoformat(start_stamp)
except:
return {"success": False, "error": "invalid start format"}, 400
start = datetime.fromisoformat(request.args.get("start", f"{date.today().year}-01-01"))
except ValueError:
return {"success": False, "error": "Invalid start format"}, 400

sq = DiscordUserActivity.query.where(
DiscordUserActivity.type == 'thanks').where(
DiscordUserActivity.message_timestamp >= start).with_entities(
DiscordUserActivity.user_id, DiscordUserActivity.source_user_id, DiscordUserActivity.message_id).distinct().subquery()
thanks_scores = db.session.execute(db.select(sq.c.user_id, db.func.count(sq.c.user_id)).select_from(sq).group_by(sq.c.user_id).order_by(db.func.count(sq.c.user_id).desc())).all()

def get_name(discord_id):
try:
response = get_discord_member_by_discord_id(discord_id)
if not response:
return
except:
return

return response['user']['global_name']

results = [[get_name(discord_id), score] for discord_id, score in thanks_scores]
results = [[name, score] for name, score in results if name is not None][:25]
leaderboard = list(islice(
((discord_member["user"]["global_name"], score)
for discord_id, score in thanks_scores
if (discord_member := get_discord_member(discord_id))),
25
))

return {"success": True, "leaderboard": results}, 200
return {"success": True, "leaderboard": leaderboard}, 200
1 change: 1 addition & 0 deletions dojo_plugin/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ class DiscordUserActivity(db.Model):
message_id = db.Column(db.BigInteger)
message_timestamp = db.Column(db.DateTime, default=datetime.datetime.utcnow)


class DiscordUsers(db.Model):
__tablename__ = "discord_users"
user_id = db.Column(
Expand Down
26 changes: 11 additions & 15 deletions dojo_plugin/pages/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,17 +368,10 @@ def view_course(dojo, resource=None):

discord_role = dojo.course.get("discord_role")
if discord_role:
if DiscordUsers.query.filter_by(user=user).first():
setup["create_discord"] = "complete"
setup["link_discord"] = "complete"
else:
setup["create_discord"] = "incomplete"
setup["link_discord"] = "incomplete"

if user and get_discord_member(user.id):
setup["join_discord"] = "complete"
else:
setup["join_discord"] = "incomplete"
discord_user = DiscordUsers.query.filter_by(user=user).first()
setup["create_discord"] = "complete" if discord_user else "incomplete"
setup["link_discord"] = "complete" if discord_user else "incomplete"
setup["join_discord"] = "complete" if discord_user and get_discord_member(discord_user.discord_id) else "incomplete"

setup_complete = all(status == "complete" for status in setup.values())

Expand Down Expand Up @@ -421,10 +414,11 @@ def update_identity(dojo):

discord_role = dojo.course.get("discord_role")
if discord_role:
discord_member = get_discord_member(user.id)
if discord_member is None:
discord_user = DiscordUsers.query.filter_by(user=user).first()
if not discord_user:
return {"success": True, "warning": "Your Discord account is not linked"}
if discord_member is False:
discord_member = get_discord_member(discord_user.discord_id)
if not discord_member:
return {"success": True, "warning": "Your Discord account has not joined the official Discord server"}
add_role(discord_member["user"]["id"], discord_role)

Expand Down Expand Up @@ -551,7 +545,9 @@ def view_user_info(dojo, user_id):
student = DojoStudents.query.filter_by(dojo=dojo, user=user).first()
identity = dict(identity_name=dojo.course.get("student_id", "Identity"),
identity_value=student.token if student else None)
discord_member = get_discord_member(user.id) if dojo.course.get("discord_role") else None
discord_member = (get_discord_member(DiscordUsers.query.filter_by(user=user)
.with_entities(DiscordUsers.discord_id).scalar())
if dojo.course.get("discord_role") else None)

return render_template("dojo_admin_user.html",
dojo=dojo,
Expand Down
4 changes: 1 addition & 3 deletions dojo_plugin/pages/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sqlalchemy.exc import IntegrityError
from itsdangerous.url_safe import URLSafeTimedSerializer
from CTFd.models import db
from CTFd.cache import cache
from CTFd.utils.user import get_current_user
from CTFd.utils.decorators import authed_only

Expand Down Expand Up @@ -66,8 +65,7 @@ def discord_redirect():
else:
existing_discord_user.discord_id = discord_id
db.session.commit()
cache.delete_memoized(get_discord_member, user_id)
if get_discord_member(user_id):
if get_discord_member(discord_id):
add_role(discord_id, "White Belt")
update_awards(user)
except IntegrityError:
Expand Down
5 changes: 3 additions & 2 deletions dojo_plugin/pages/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from CTFd.utils.decorators import authed_only
from CTFd.utils.user import get_current_user

from ..models import Dojos, SSHKeys, DojoMembers
from ..models import SSHKeys, DiscordUsers
from ..config import DISCORD_CLIENT_ID
from ..utils.discord import get_discord_member, discord_avatar_asset

Expand All @@ -19,7 +19,8 @@ def settings_override():

ssh_keys = SSHKeys.query.filter_by(user_id=user.id).all()

discord_member = get_discord_member(user.id)
discord_member = get_discord_member(DiscordUsers.query.filter_by(user=user)
.with_entities(DiscordUsers.discord_id).scalar())

prevent_name_change = get_config("prevent_name_change")

Expand Down
7 changes: 4 additions & 3 deletions dojo_plugin/utils/awards.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flask import url_for

from .discord import get_discord_roles, get_discord_member, add_role, send_message
from ..models import Dojos, Belts, Emojis
from ..models import Dojos, Belts, Emojis, DiscordUsers


BELT_ORDER = [ "orange", "yellow", "green", "purple", "blue", "brown", "red", "black" ]
Expand Down Expand Up @@ -102,7 +102,8 @@ def update_awards(user):
db.session.commit()
current_belts.append(belt)

discord_member = get_discord_member(user.id)
discord_member = get_discord_member(DiscordUsers.query.filter_by(user=user)
.with_entities(DiscordUsers.discord_id).scalar())
discord_roles = get_discord_roles()
for belt in BELT_REQUIREMENTS:
if belt not in current_belts:
Expand All @@ -115,7 +116,7 @@ def update_awards(user):
message = f"{user_mention} earned their {belt_role}! :tada:"
add_role(discord_member["user"]["id"], belt_role)
send_message(message, "belting-ceremony")
cache.delete_memoized(get_discord_member, user.id)
cache.delete_memoized(get_discord_member, discord_member["user"]["id"])

current_emojis = get_user_emojis(user)
for emoji,dojo_name,dojo_id in current_emojis:
Expand Down
15 changes: 3 additions & 12 deletions dojo_plugin/utils/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,9 @@ def get_discord_id(auth_code):


@cache.memoize(timeout=3600)
def get_discord_member(user_id):
if not DISCORD_BOT_TOKEN:
return

discord_user = DiscordUsers.query.filter_by(user_id=user_id).first()
if not discord_user:
def get_discord_member(discord_id):
if not DISCORD_BOT_TOKEN or not discord_id:
return

return get_discord_member_by_discord_id(discord_user.discord_id)


@cache.memoize(timeout=3600)
def get_discord_member_by_discord_id(discord_id):
try:
result = guild_request(f"/members/{discord_id}")
except requests.exceptions.RequestException:
Expand All @@ -91,6 +81,7 @@ def get_discord_member_by_discord_id(discord_id):
return None
return result


@cache.memoize(timeout=3600)
def get_discord_roles():
if not DISCORD_BOT_TOKEN:
Expand Down

0 comments on commit 6f58580

Please sign in to comment.