diff --git a/src/extensions/automod.py b/src/extensions/automod.py index 389d2df..eaf023b 100644 --- a/src/extensions/automod.py +++ b/src/extensions/automod.py @@ -8,13 +8,11 @@ import hikari import kosu -import src.utils as utils from src.etc import const from src.etc.settings_static import notices from src.models.client import SnedClient, SnedPlugin from src.models.events import AutoModMessageFlagEvent from src.utils import helpers -from src.utils.ratelimiter import MemberBucket INVITE_REGEX = re.compile(r"(?:https?://)?discord(?:app)?\.(?:com/invite|gg)/[a-zA-Z0-9]+/?") """Used to detect and handle Discord invites.""" @@ -23,12 +21,14 @@ DISCORD_FORMATTING_REGEX = re.compile(r"<\S+>") """Remove Discord-specific formatting. Performance is key so some false-positives are acceptable here.""" -SPAM_RATELIMITER = utils.RateLimiter(10, 8, bucket=MemberBucket, wait=False) -PUNISH_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) -ATTACH_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=MemberBucket, wait=False) -LINK_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=MemberBucket, wait=False) -ESCALATE_PREWARN_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) -ESCALATE_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) +MEMBER_KEY_FUNC: t.Callable[[hikari.PartialMessage], str] = lambda m: f"{m.author.id if m.author else None}{m.guild_id}" # noqa: E731 + +SPAM_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](10, 8, get_key_with=MEMBER_KEY_FUNC) +PUNISH_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](30, 1, get_key_with=MEMBER_KEY_FUNC) +ATTACH_SPAM_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](30, 2, get_key_with=MEMBER_KEY_FUNC) +LINK_SPAM_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](30, 2, get_key_with=MEMBER_KEY_FUNC) +ESCALATE_PREWARN_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](30, 1, get_key_with=MEMBER_KEY_FUNC) +ESCALATE_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](30, 1, get_key_with=MEMBER_KEY_FUNC) logger = logging.getLogger(__name__) @@ -156,9 +156,9 @@ async def punish( silencers.append(AutoModState.ESCALATE.value) if state in silencers: - await PUNISH_RATELIMITER.acquire(message) - - if PUNISH_RATELIMITER.is_rate_limited(message): + try: + await PUNISH_RATELIMITER.acquire(message, wait=False) + except arc.utils.RateLimiterExhaustedError: return if not original_action: @@ -209,9 +209,10 @@ async def punish( return elif state == AutoModState.ESCALATE.value: - await ESCALATE_PREWARN_RATELIMITER.acquire(message) - - if not ESCALATE_PREWARN_RATELIMITER.is_rate_limited(message): + try: + # Check if the user has been warned before + await ESCALATE_PREWARN_RATELIMITER.acquire(message, wait=False) + # If not, issue a notice await message.respond( content=offender.mention, embed=hikari.Embed( @@ -230,19 +231,20 @@ async def punish( f"Message flagged by auto-moderator for {reason} ({action.name}).", ) ) - - elif ESCALATE_PREWARN_RATELIMITER.is_rate_limited(message): - embed = await plugin.client.mod.warn( - offender, - me, - f"Warned by auto-moderator for previous offenses ({action.name}).", - ) - await message.respond(embed=embed) - return - - else: - await ESCALATE_RATELIMITER.acquire(message) - if ESCALATE_RATELIMITER.is_rate_limited(message): + except hikari.RateLimitTooLongError: + # If yes, then check if we should escalate + try: + await ESCALATE_RATELIMITER.acquire(message, wait=False) + # If not, issue a warning + embed = await plugin.client.mod.warn( + offender, + me, + f"Warned by auto-moderator for previous offenses ({action.name}).", + ) + await message.respond(embed=embed) + return + except arc.utils.RateLimiterExhaustedError: + # Escalate to a full punishment return await punish( message=message, policies=policies, @@ -335,8 +337,12 @@ async def detect_spam(message: hikari.PartialMessage, policies: dict[str, t.Any] bool Whether or not the analysis should proceed to the next check. """ - await SPAM_RATELIMITER.acquire(message) - if policies["spam"]["state"] != AutoModState.DISABLED.value and SPAM_RATELIMITER.is_rate_limited(message): + if policies["spam"]["state"] == AutoModState.DISABLED.value: + return True + + try: + await SPAM_RATELIMITER.acquire(message, wait=False) + except arc.utils.RateLimiterExhaustedError: await punish(message, policies, AutomodActionType.SPAM, reason="spam") return False return True @@ -357,14 +363,14 @@ async def detect_attach_spam(message: hikari.PartialMessage, policies: dict[str, bool Whether or not the analysis should proceed to the next check. """ - if policies["attach_spam"]["state"] != AutoModState.DISABLED.value and message.attachments: - await ATTACH_SPAM_RATELIMITER.acquire(message) + if policies["attach_spam"]["state"] == AutoModState.DISABLED.value or not message.attachments: + return True - if ATTACH_SPAM_RATELIMITER.is_rate_limited(message): - await punish( - message, policies, AutomodActionType.ATTACH_SPAM, reason="posting images/attachments too quickly" - ) - return False + try: + await ATTACH_SPAM_RATELIMITER.acquire(message, wait=False) + except arc.utils.RateLimiterExhaustedError: + await punish(message, policies, AutomodActionType.ATTACH_SPAM, reason="posting images/attachments too quickly") + return False return True @@ -451,31 +457,30 @@ async def detect_link_spam(message: hikari.PartialMessage, policies: dict[str, t bool Whether or not the analysis should proceed to the next check. """ - if not message.content: + if not message.content or policies["link_spam"]["state"] == AutoModState.DISABLED.value: return True - if policies["link_spam"]["state"] != AutoModState.DISABLED.value: - link_matches = URL_REGEX.findall(message.content) - if len(link_matches) > 7: + link_matches = URL_REGEX.findall(message.content) + if len(link_matches) > 7: + await punish( + message, + policies, + AutomodActionType.LINK_SPAM, + reason="having too many links in a single message", + ) + return False + + if link_matches: + try: + await LINK_SPAM_RATELIMITER.acquire(message, wait=False) + except arc.utils.RateLimiterExhaustedError: await punish( message, policies, AutomodActionType.LINK_SPAM, - reason="having too many links in a single message", + reason="posting links too quickly", ) return False - - if link_matches: - await LINK_SPAM_RATELIMITER.acquire(message) - - if LINK_SPAM_RATELIMITER.is_rate_limited(message): - await punish( - message, - policies, - AutomodActionType.LINK_SPAM, - reason="posting links too quickly", - ) - return False return True diff --git a/src/extensions/dev.py b/src/extensions/dev.py index aa4201e..eaba57d 100644 --- a/src/extensions/dev.py +++ b/src/extensions/dev.py @@ -1,27 +1,33 @@ import ast +import contextlib import logging import os import shlex import subprocess import textwrap import traceback -from contextlib import suppress +import typing as t +import arc import hikari -import lightbulb import miru +from config import Config from miru.ext import nav from src.etc import const -from src.models import AuthorOnlyNavigator, SnedPrefixContext -from src.models.bot import SnedBot -from src.models.plugin import SnedPlugin +from src.models import AuthorOnlyNavigator +from src.models.client import SnedClient, SnedContext, SnedPlugin from src.models.views import AuthorOnlyView logger = logging.getLogger(__name__) -dev = SnedPlugin("Development") -dev.add_checks(lightbulb.owner_only) +plugin = SnedPlugin( + "Development", default_enabled_guilds=Config().DEBUG_GUILDS, default_permissions=hikari.Permissions.ADMINISTRATOR +).add_hook(arc.owner_only) + + +class CodeInputModal(miru.Modal): + code = miru.TextInput(label="Code", placeholder="Enter Python code here", custom_id="code", required=True) class TrashButton(nav.NavButton): @@ -64,7 +70,7 @@ def format_output(text: str) -> str: async def send_paginated( - ctx: SnedPrefixContext, + ctx: SnedContext, messageable: hikari.SnowflakeishOr[hikari.TextableChannel | hikari.User], text: str, *, @@ -80,16 +86,16 @@ async def send_paginated( if len(text) <= 2000: if channel_id: view = TrashView(ctx.author, timeout=300) - message = await ctx.app.rest.create_message( + message = await ctx.client.rest.create_message( channel_id, f"{prefix}{format_output(text)}{suffix}", components=view ) - return await view.start(message) + ctx.client.miru.start_view(view, bind_to=message) else: assert isinstance(messageable, (hikari.TextableChannel, hikari.User)) await messageable.send(f"{prefix}{format_output(text)}{suffix}") return - buttons = [ + buttons: list[nav.NavItem] = [ nav.FirstButton(emoji=const.EMOJI_FIRST), nav.PrevButton(emoji=const.EMOJI_PREV), nav.IndicatorButton(), @@ -97,92 +103,81 @@ async def send_paginated( nav.LastButton(emoji=const.EMOJI_LAST), TrashButton(), ] - paginator = lightbulb.utils.StringPaginator(prefix=prefix, suffix=suffix, max_chars=2000) + paginator = nav.Paginator(prefix=prefix, suffix=suffix, max_len=2000) for line in text.split("\n"): paginator.add_line(format_output(line)) - navmenu = OutputNav(ctx.author, pages=list(paginator.build_pages()), items=buttons, timeout=300) + navmenu = OutputNav(ctx.author, pages=list(paginator.pages), items=buttons, timeout=300) if not channel_id: assert isinstance(messageable, hikari.User) channel_id = hikari.Snowflake(await messageable.fetch_dm_channel()) - await navmenu.send(channel_id) + builder = await navmenu.build_response_async(ctx.client.miru) + await builder.send_to_channel(channel_id) -async def run_shell(ctx: SnedPrefixContext, code: str) -> None: +async def run_shell(ctx: SnedContext, code: str) -> None: """Run code in shell and return output to Discord.""" code = str(code).replace("```py", "").replace("`", "").strip() - await ctx.app.rest.trigger_typing(ctx.channel_id) + await ctx.defer() try: result = subprocess.run(shlex.split(code), stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=10.0) except subprocess.TimeoutExpired as e: - await ctx.event.message.add_reaction("❗") out_bytes = e.stderr or e.stdout out = ":\n" + out_bytes.decode("utf-8") if out_bytes else "" return await send_paginated(ctx, ctx.channel_id, "Process timed out" + out, prefix="```ansi\n", suffix="```") - if result.returncode != 0: - await ctx.event.message.add_reaction("❗") - if result.stderr and result.stderr.decode("utf-8"): - return await send_paginated( - ctx, ctx.channel_id, result.stderr.decode("utf-8"), prefix="```ansi\n", suffix="```" - ) + if result.returncode != 0 and result.stderr and result.stderr.decode("utf-8"): + return await send_paginated( + ctx, ctx.channel_id, result.stderr.decode("utf-8"), prefix="```ansi\n", suffix="```" + ) - await ctx.event.message.add_reaction("✅") if result.stdout and result.stdout.decode("utf-8"): await send_paginated(ctx, ctx.channel_id, result.stdout.decode("utf-8"), prefix="```ansi\n", suffix="```") -@dev.command -@lightbulb.option("extension_name", "The name of the extension to reload.") -@lightbulb.command("reload", "Reload an extension.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def reload_cmd(ctx: SnedPrefixContext, extension_name: str) -> None: - ctx.app.reload_extensions(extension_name) - await ctx.event.message.add_reaction("✅") - await ctx.respond(f"🔃 `{extension_name}`") - - -@dev.command -@lightbulb.option("extension_name", "The name of the extension to load.") -@lightbulb.command("load", "Load an extension.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def load_cmd(ctx: SnedPrefixContext, extension_name: str) -> None: - ctx.app.load_extensions(extension_name) - await ctx.event.message.add_reaction("✅") +@plugin.include +@arc.slash_command("load", "Load an extension.") +async def load_cmd( + ctx: SnedContext, extension_name: arc.Option[str, arc.StrParams("The name of the extension to load.")] +) -> None: + ctx.client.load_extension(extension_name) await ctx.respond(f"📥 `{extension_name}`") -@dev.command -@lightbulb.option("extension_name", "The name of the extension to unload.") -@lightbulb.command("unload", "Unload an extension.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def unload_cmd(ctx: SnedPrefixContext, extension_name: str) -> None: - ctx.app.unload_extensions(extension_name) - await ctx.event.message.add_reaction("✅") +@plugin.include +@arc.slash_command("unload", "Unload an extension.") +async def unload_cmd( + ctx: SnedContext, extension_name: arc.Option[str, arc.StrParams("The name of the extension to unload.")] +) -> None: + ctx.client.unload_extension(extension_name) await ctx.respond(f"📤 `{extension_name}`") -@dev.command -@lightbulb.option("code", "Code to execute.", modifier=lightbulb.OptionModifier.CONSUME_REST) -@lightbulb.command("py", "Run code.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def eval_py(ctx: SnedPrefixContext, code: str) -> None: - globals_dict = { +@plugin.include +@arc.slash_command("py", "Run code.") +async def eval_py(ctx: SnedContext) -> None: + globals_dict: dict[str, t.Any] = { "_author": ctx.author, - "_bot": ctx.bot, - "_app": ctx.app, + "_bot": ctx.client.app, + "_app": ctx.client.app, + "_client": ctx.client, "_channel": ctx.get_channel(), "_guild": ctx.get_guild(), - "_message": ctx.event.message, "_ctx": ctx, } - code = code.replace("```py", "").replace("`", "").strip() + modal = CodeInputModal() + await ctx.respond_with_builder(modal.build_response(ctx.client.miru)) + await modal.wait() + if modal.code.value is None: + return + + code = modal.code.value.replace("```py", "").replace("`", "").strip() # Check if last line is an expression and return it if so abstract_syntax_tree = ast.parse(code, filename=f"{ctx.guild_id}{ctx.channel_id}.py") @@ -195,98 +190,73 @@ async def eval_py(ctx: SnedPrefixContext, code: str) -> None: code_func = "async def _container():\n" + textwrap.indent(code, " ") - async with ctx.app.rest.trigger_typing(ctx.channel_id): - try: - exec(code_func, globals_dict, locals()) - return_value = await locals()["_container"]() - - await ctx.event.message.add_reaction("✅") - - if return_value is not None: - await send_paginated(ctx, ctx.channel_id, return_value, prefix="```py\n", suffix="```") - - except Exception as e: - try: - await ctx.event.message.add_reaction("❗") - await ctx.respond( - embed=hikari.Embed( - title="❌ Exception encountered", - description=f"```{e.__class__.__name__}: {e}```", - color=const.ERROR_COLOR, - ) - ) - except hikari.ForbiddenError: - pass - - traceback_msg = "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) - await send_paginated(ctx, ctx.author, traceback_msg, prefix="```py\n", suffix="```") - - -def load(bot: SnedBot) -> None: - bot.add_plugin(dev) - - -def unload(bot: SnedBot) -> None: - bot.remove_plugin(dev) + await ctx.defer() + try: + exec(code_func, globals_dict, locals()) + return_value = await locals()["_container"]() + + if return_value is not None: + await send_paginated(ctx, ctx.channel_id, return_value, prefix="```py\n", suffix="```") + + except Exception as e: + with contextlib.suppress(hikari.ForbiddenError): + await ctx.respond( + embed=hikari.Embed( + title="❌ Exception encountered", + description=f"```{e.__class__.__name__}: {e}```", + color=const.ERROR_COLOR, + ) + ) -@dev.command -@lightbulb.option("code", "Code to execute.", modifier=lightbulb.OptionModifier.CONSUME_REST) -@lightbulb.command("sh", "Run code.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def eval_sh(ctx: SnedPrefixContext, code: str) -> None: - await run_shell(ctx, code) + traceback_msg = "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) + await send_paginated(ctx, ctx.author, traceback_msg, prefix="```py\n", suffix="```") -@dev.command -@lightbulb.option("code", "Code to execute.", modifier=lightbulb.OptionModifier.CONSUME_REST) -@lightbulb.command("git", "Run git commands.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def dev_git_pull(ctx: SnedPrefixContext, code: str) -> None: - await run_shell(ctx, f"git {code}") +@plugin.include +@arc.slash_command("sh", "Run code.") +async def eval_sh(ctx: SnedContext) -> None: + modal = CodeInputModal() + await ctx.respond_with_builder(modal.build_response(ctx.client.miru)) + await modal.wait() + if modal.code.value is None: + return + await run_shell(ctx, modal.code.value) -@dev.command -@lightbulb.option( - "--force", "If True, purges application commands before re-registering them.", type=bool, required=False -) -@lightbulb.command("sync", "Sync application commands.") -@lightbulb.implements(lightbulb.PrefixCommand) -async def resync_app_cmds(ctx: SnedPrefixContext) -> None: - await ctx.app.rest.trigger_typing(ctx.channel_id) - if ctx.options["--force"]: - await ctx.app.purge_application_commands(*ctx.app.default_enabled_guilds, global_commands=True) - await ctx.app.sync_application_commands() - await ctx.event.message.add_reaction("✅") +@plugin.include +@arc.slash_command("sync", "Sync application commands.") +async def resync_app_cmds(ctx: SnedContext) -> None: + await ctx.defer() + await ctx.client.resync_commands() await ctx.respond("🔃 Synced application commands.") -@dev.command -@lightbulb.command("sql", "Execute an SQL file") -@lightbulb.implements(lightbulb.PrefixCommand) -async def run_sql(ctx: SnedPrefixContext) -> None: - if not ctx.attachments or not ctx.attachments[0].filename.endswith(".sql"): +@plugin.include +@arc.slash_command("sql", "Execute an SQL file") +async def run_sql( + ctx: SnedContext, file: arc.Option[hikari.Attachment, arc.AttachmentParams("The SQL file to execute")] +) -> None: + if file.filename.endswith(".sql"): await ctx.respond( embed=hikari.Embed( title="❌ No valid attachment", - description="Expected a singular `.sql` file as attachment with `UTF-8` encoding!", + description="Expected a `.sql` file.", color=const.ERROR_COLOR, ) ) return - await ctx.app.rest.trigger_typing(ctx.channel_id) - sql: str = (await ctx.attachments[0].read()).decode("utf-8") - return_value = await ctx.app.db.execute(sql) - await ctx.event.message.add_reaction("✅") + await ctx.defer() + sql: str = (await file.read()).decode("utf-8") + return_value = await ctx.client.db.execute(sql) await send_paginated(ctx, ctx.channel_id, str(return_value), prefix="```sql\n", suffix="```") -@dev.command -@lightbulb.command("shutdown", "Shut down the bot.") -@lightbulb.implements(lightbulb.PrefixCommand) -async def shutdown_cmd(ctx: SnedPrefixContext) -> None: +@plugin.include +@arc.slash_command("shutdown", "Shut down the bot.") +async def shutdown_cmd(ctx: SnedContext) -> None: confirm_payload = {"content": "⚠️ Shutting down...", "components": []} cancel_payload = {"content": "❌ Shutdown cancelled", "components": []} confirmed = await ctx.confirm( @@ -295,51 +265,48 @@ async def shutdown_cmd(ctx: SnedPrefixContext) -> None: cancel_payload=cancel_payload, ) if confirmed: - await ctx.event.message.add_reaction("✅") - return await ctx.app.close() - await ctx.event.message.add_reaction("❌") + return await ctx.client.app.close() + await ctx.respond("❌ Shutdown cancelled") -@dev.command -@lightbulb.command("pg_dump", "Back up the database.", aliases=["dbbackup", "backup"]) -@lightbulb.implements(lightbulb.PrefixCommand) -async def backup_db_cmd(ctx: SnedPrefixContext) -> None: - await ctx.app.backup_db() - await ctx.event.message.add_reaction("✅") +@plugin.include +@arc.slash_command("backup", "Back up the database.") +async def backup_db_cmd(ctx: SnedContext) -> None: + await ctx.client.backup_db() await ctx.respond("📤 Database backup complete.") -@dev.command -@lightbulb.option("--ignore-errors", "Ignore all errors.", type=bool, default=False) -@lightbulb.command("pg_restore", "Restore database from attached dump file.", aliases=["restore"]) -@lightbulb.implements(lightbulb.PrefixCommand) -async def restore_db(ctx: SnedPrefixContext) -> None: - if not ctx.attachments or not ctx.attachments[0].filename.endswith(".pgdmp"): +@plugin.include +@arc.slash_command("restore", "Restore database from attached dump file.") +async def restore_db( + ctx: SnedContext, + file: arc.Option[hikari.Attachment, arc.AttachmentParams("The pgdmp file to restore from")], + ignore_errors: arc.Option[bool, arc.BoolParams("Ignore all errors. This is wildly unsafe to use!")] = False, +) -> None: + if file.filename.endswith(".pgdmp"): await ctx.respond( embed=hikari.Embed( title="❌ No valid attachment", - description="Required dump-file attachment not found. Expected a `.pgdmp` file.", + description="Expected a `.pgdmp` file.", color=const.ERROR_COLOR, - ) + ), + flags=hikari.MessageFlag.EPHEMERAL, ) return - await ctx.app.rest.trigger_typing(ctx.channel_id) - - if not os.path.isdir(os.path.join(ctx.app.base_dir, "src", "db", "backups")): - os.mkdir(os.path.join(ctx.app.base_dir, "src", "db", "backups")) + await ctx.defer() - path = os.path.join(ctx.app.base_dir, "src", "db", "backups", "dev_pg_restore_snapshot.pgdmp") - with open(path, "wb") as file: - file.write((await ctx.attachments[0].read())) + if not os.path.isdir(os.path.join(ctx.client.base_dir, "src", "db", "backups")): + os.mkdir(os.path.join(ctx.client.base_dir, "src", "db", "backups")) - with suppress(hikari.HikariError): - await ctx.event.message.delete() + path = os.path.join(ctx.client.base_dir, "src", "db", "backups", "dev_pg_restore_snapshot.pgdmp") + with open(path, "wb") as f: + f.write((await file.read())) - await ctx.app.db_cache.stop() + await ctx.client.db_cache.stop() # Drop all tables - async with ctx.app.db.acquire() as con: + async with ctx.client.db.acquire() as con: records = await con.fetch( """ SELECT * FROM pg_catalog.pg_tables @@ -349,74 +316,72 @@ async def restore_db(ctx: SnedPrefixContext) -> None: for record in records: await con.execute(f"""DROP TABLE IF EXISTS {record.get("tablename")} CASCADE""") - arg = "-e" if not ctx.options["--ignore-errors"] else "" - code = os.system(f"pg_restore {path} {arg} -n 'public' -j 4 -d {ctx.app.db.dsn}") + arg = "-e" if not ignore_errors else "" + code = os.system(f"pg_restore {path} {arg} -n 'public' -j 4 -d {ctx.client.db.dsn}") - if code != 0 and not ctx.options["--ignore-errors"]: - await ctx.respond("❌ **Fatal:** Failed to load database backup, database corrupted. Shutting down...") - return await ctx.app.close() + if code != 0 and not ignore_errors: + await ctx.respond( + "❌ **Fatal:** Failed to load database backup, database corrupted. Shutting down...", + flags=hikari.MessageFlag.EPHEMERAL, + ) + return await ctx.client.app.close() elif code != 0: await ctx.respond( - "❌ **Fatal:** Failed to load database backup, database may be corrupted. Shutdown recommended." + "❌ **Fatal:** Failed to load database backup, database may be corrupted. Shutdown recommended.", + flags=hikari.MessageFlag.EPHEMERAL, ) else: - await ctx.app.db.update_schema() - await ctx.app.db_cache.start() - ctx.app.scheduler.restart() + await ctx.client.db.update_schema() + await ctx.client.db_cache.start() + ctx.client.scheduler.restart() await ctx.respond("📥 Restored database from backup file.") -@dev.command -@lightbulb.option("user", "The user to manage.", type=hikari.User) -@lightbulb.option("mode", "The mode of operation.", type=str) -@lightbulb.command("blacklist", "Commands to manage the blacklist.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def blacklist_cmd(ctx: SnedPrefixContext, mode: str, user: hikari.User) -> None: +@plugin.include +@arc.slash_command("blacklist", "Commands to manage the blacklist.") +async def blacklist_cmd( + ctx: SnedContext, + mode: arc.Option[str, arc.StrParams("The mode of operation.", choices=["add", "del"])], + user: arc.Option[hikari.User, arc.UserParams("The user to manage.")], +) -> None: if user.id == ctx.user.id: - await ctx.event.message.add_reaction("❌") await ctx.respond("❌ Cannot blacklist self") return - records = await ctx.app.db_cache.get(table="blacklist", user_id=user.id) + records = await ctx.client.db_cache.get(table="blacklist", user_id=user.id) - if mode.casefold() == "add": + if mode == "add": if records: - await ctx.event.message.add_reaction("❌") await ctx.respond("❌ Already blacklisted") return - await ctx.app.db.execute("""INSERT INTO blacklist (user_id) VALUES ($1)""", user.id) - await ctx.app.db_cache.refresh(table="blacklist", user_id=user.id) - await ctx.event.message.add_reaction("✅") + await ctx.client.db.execute("""INSERT INTO blacklist (user_id) VALUES ($1)""", user.id) + await ctx.client.db_cache.refresh(table="blacklist", user_id=user.id) await ctx.respond("✅ User added to blacklist") - elif mode.casefold() in ["del", "delete", "remove"]: + elif mode == "del": if not records: - await ctx.event.message.add_reaction("❌") await ctx.respond("❌ Not blacklisted") return - await ctx.app.db.execute("""DELETE FROM blacklist WHERE user_id = $1""", user.id) - await ctx.app.db_cache.refresh(table="blacklist", user_id=user.id) - await ctx.event.message.add_reaction("✅") + await ctx.client.db.execute("""DELETE FROM blacklist WHERE user_id = $1""", user.id) + await ctx.client.db_cache.refresh(table="blacklist", user_id=user.id) await ctx.respond("✅ User removed from blacklist") else: - await ctx.event.message.add_reaction("❌") await ctx.respond("❌ Invalid mode\nValid modes:`add`, `del`.") -@dev.command -@lightbulb.option("guild_id", "The guild_id to reset all settings for.", type=int) -@lightbulb.command("resetsettings", "Reset all settings for the specified guild.", pass_options=True) -@lightbulb.implements(lightbulb.PrefixCommand) -async def resetsettings_cmd(ctx: SnedPrefixContext, guild_id: int) -> None: - guild = ctx.app.cache.get_guild(guild_id) +@plugin.include +@arc.slash_command("resetsettings", "Reset all settings for the specified guild.") +async def resetsettings_cmd( + ctx: SnedContext, guild_id: arc.Option[int, arc.IntParams("The guild_id to reset all settings for.")] +) -> None: + guild = ctx.client.cache.get_guild(guild_id) if not guild: - await ctx.event.message.add_reaction("❌") await ctx.respond("❌ Guild not found.") return @@ -429,13 +394,22 @@ async def resetsettings_cmd(ctx: SnedPrefixContext, guild_id: int) -> None: if not confirmed: return await ctx.event.message.add_reaction("❌") - await ctx.app.db.wipe_guild(guild) - await ctx.app.db_cache.wipe(guild) + await ctx.client.db.wipe_guild(guild) + await ctx.client.db_cache.wipe(guild) - await ctx.event.message.add_reaction("✅") await ctx.respond(f"✅ Wiped data for guild `{guild.id}`.") +@arc.loader +def load(client: SnedClient) -> None: + client.add_plugin(plugin) + + +@arc.unloader +def unload(client: SnedClient) -> None: + client.remove_plugin(plugin) + + # Copyright (C) 2022-present hypergonial # This program is free software: you can redistribute it and/or modify diff --git a/src/extensions/fun.py b/src/extensions/fun.py index 6762063..5a514f5 100644 --- a/src/extensions/fun.py +++ b/src/extensions/fun.py @@ -19,9 +19,8 @@ from src.etc import const from src.models.client import SnedClient, SnedContext, SnedPlugin from src.models.views import AuthorOnlyNavigator, AuthorOnlyView -from src.utils import GlobalBucket, RateLimiter, helpers +from src.utils import helpers from src.utils.dictionaryapi import DictionaryClient, DictionaryEntry, DictionaryError, UrbanEntry -from src.utils.ratelimiter import UserBucket from src.utils.rpn import InvalidExpressionError, Solver from ..config import Config @@ -39,9 +38,11 @@ "racoon": "🦝", } -ANIMAL_RATELIMITER = RateLimiter(60, 45, GlobalBucket, wait=False) -COMF_LIMITER = RateLimiter(60, 5, UserBucket, wait=False) -VESZTETTEM_LIMITER = RateLimiter(1800, 1, GlobalBucket, wait=False) +ANIMAL_RATELIMITER = arc.utils.RateLimiter[hikari.PartialMessage](60, 45, get_key_with=lambda _: "padoru") +COMF_LIMITER = arc.utils.RateLimiter[hikari.PartialMessage]( + 60, 5, get_key_with=lambda c: str(c.author.id if c.author else 0) +) +VESZTETTEM_LIMITER = arc.utils.RateLimiter[hikari.PartialMessage](1800, 1, get_key_with=lambda _: "padoru") COMF_PROGRESS_BAR_WIDTH = 20 logger = logging.getLogger(__name__) @@ -951,9 +952,9 @@ async def lose_autoresponse(event: hikari.GuildMessageCreateEvent) -> None: return if event.content and "vesztettem" in event.content.lower(): - await VESZTETTEM_LIMITER.acquire(event.message) - - if VESZTETTEM_LIMITER.is_rate_limited(event.message): + try: + await VESZTETTEM_LIMITER.acquire(event.message) + except arc.utils.RateLimiterExhaustedError: return await event.message.respond("Vesztettem") diff --git a/src/extensions/reports.py b/src/extensions/reports.py index bfd3c09..a3e07ed 100644 --- a/src/extensions/reports.py +++ b/src/extensions/reports.py @@ -2,8 +2,8 @@ import arc import hikari -import lightbulb import miru +import toolbox from src.etc import const from src.models.client import SnedClient, SnedContext, SnedPlugin @@ -114,7 +114,7 @@ async def report(ctx: SnedContext, member: hikari.Member, message: hikari.Messag me = ctx.client.cache.get_member(ctx.guild_id, ctx.client.user_id) assert me is not None - perms = lightbulb.utils.permissions_in(channel, me) + perms = toolbox.calculate_permissions(me, channel) if not (perms & hikari.Permissions.SEND_MESSAGES): return await report_perms_error(ctx) diff --git a/src/extensions/test.py b/src/extensions/test.py index 93e79aa..5a0f1ab 100644 --- a/src/extensions/test.py +++ b/src/extensions/test.py @@ -1,5 +1,7 @@ import logging +import arc + from src.models.client import SnedClient, SnedPlugin logger = logging.getLogger(__name__) @@ -7,11 +9,13 @@ plugin = SnedPlugin("Test") +@arc.loader def load(client: SnedClient) -> None: # client.add_plugin(test) pass +@arc.unloader def unload(client: SnedClient) -> None: # client.remove_plugin(test) pass diff --git a/src/extensions/userlog.py b/src/extensions/userlog.py index 94a1cbc..e132d33 100644 --- a/src/extensions/userlog.py +++ b/src/extensions/userlog.py @@ -6,6 +6,7 @@ import re import typing as t +import arc import attr import hikari @@ -886,10 +887,12 @@ async def rolebutton_update(event: RoleButtonUpdateEvent) -> None: await plugin.client.userlogger.log(LogEvent.ROLES, log_embed, event.guild_id) +@arc.loader def load(client: SnedClient) -> None: client.add_plugin(plugin) +@arc.unloader def unload(client: SnedClient) -> None: client.remove_plugin(plugin) diff --git a/src/models/mod_actions.py b/src/models/mod_actions.py index e770fa5..ecfe62b 100644 --- a/src/models/mod_actions.py +++ b/src/models/mod_actions.py @@ -299,6 +299,8 @@ async def timeout_extend(self, event: TimerCompleteEvent) -> None: if not event.get_guild(): return + assert isinstance(event.app, hikari.GatewayBot) + member: hikari.Member | None = event.app.cache.get_member(timer.guild_id, timer.user_id) assert timer.notes is not None expiry = int(timer.notes)