Skip to content

Commit

Permalink
Make things neater
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Oct 4, 2024
1 parent 6705836 commit 3408b1f
Show file tree
Hide file tree
Showing 18 changed files with 176 additions and 201 deletions.
6 changes: 5 additions & 1 deletion dynamo/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable, Coroutine, Mapping
from typing import Any, ParamSpec, Protocol, TypeVar

import discord.abc
from discord import Interaction as DInter
from discord import app_commands
from discord.ext import commands
Expand All @@ -16,7 +17,10 @@

CogT = TypeVar("CogT", bound=commands.Cog)
CommandT = TypeVar("CommandT", bound=commands.Command[Any, ..., Any])
ContextT = TypeVar("ContextT", bound=commands.Context[Any], covariant=True)
ContextT_co = TypeVar("ContextT_co", bound=commands.Context[Any], covariant=True)

type AppCommandT[CogT: commands.Cog, **P, T] = app_commands.Command[CogT, P, T]
type MaybeSnowflake = discord.abc.Snowflake | None


type Coro[T] = Coroutine[Any, Any, T]
Expand Down
6 changes: 3 additions & 3 deletions dynamo/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dynamo.core.base_cog import DynamoCog
from dynamo.core.bot import Dynamo
from dynamo.core.logging_context import setup_logging
from dynamo.core.cog import Cog
from dynamo.core.logger import setup_logging

__all__ = (
"Cog",
"Dynamo",
"DynamoCog",
"setup_logging",
)
94 changes: 49 additions & 45 deletions dynamo/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
from collections.abc import AsyncGenerator, Generator
from typing import Any, Self, cast
from typing import Any, cast

import aiohttp
import apsw
Expand All @@ -13,22 +13,13 @@
from discord import app_commands
from discord.ext import commands

from dynamo._types import RawSubmittable
from dynamo._types import AppCommandT, MaybeSnowflake, RawSubmittable
from dynamo.utils.context import Context
from dynamo.utils.emoji import Emojis
from dynamo.utils.helper import get_cog, platformdir, resolve_path_with_links

log = logging.getLogger(__name__)

initial_extensions = (
get_cog("errors"),
get_cog("help"),
get_cog("dev"),
get_cog("events"),
get_cog("general"),
get_cog("info"),
get_cog("tags"),
)
initial_extensions = tuple(get_cog(e) for e in ("errors", "help", "dev", "events", "general", "info", "tags"))

description = """
Quantum entanglement.
Expand All @@ -38,9 +29,13 @@
button_regex = re.compile(r"^b:(.{1,10}):(.*)$", flags=re.DOTALL)


class VersionableTree(app_commands.CommandTree["Dynamo"]):
class DynamoTree(app_commands.CommandTree["Dynamo"]):
"""Versionable and mentionable command tree"""

type CommandT = commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str

application_commands: dict[int | None, list[app_commands.AppCommand]]
cache: dict[int | None, dict[commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str, str]]
cache: dict[int | None, dict[CommandT | str, str]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -59,47 +54,42 @@ async def get_hash(self, tree: app_commands.CommandTree) -> bytes:
return xxhash.xxh64_digest(msgspec.msgpack.encode(payload), seed=0)

# See: https://gist.github.com/LeoCx1000/021dc52981299b95ea7790416e4f5ca4#file-mentionable_tree-py
async def sync(self, *, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
async def sync(self, *, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
result = await super().sync(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
self.cache.pop(guild_id, None)
return result

async def fetch_commands(self, *, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
async def fetch_commands(self, *, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
result = await super().fetch_commands(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
self.cache.pop(guild_id, None)
return result

async def get_or_fetch_commands(self, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
async def get_or_fetch_commands(self, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
try:
return self.application_commands[guild.id if guild else None]
except KeyError:
return await self.fetch_commands(guild=guild)

async def find_mention_for(
self,
command: commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str,
*,
guild: discord.abc.Snowflake | None = None,
) -> str | None:
async def find_mention_for(self, command: CommandT, *, guild: discord.abc.Snowflake | None = None) -> str | None:
guild_id = guild.id if guild else None
try:
return self.cache[guild_id][command]
except KeyError:
pass

check_global = self.fallback_to_global is True or guild is not None
check_global = self.fallback_to_global is True or guild is None

if isinstance(command, str):
# Workaround: discord.py doesn't return children from tree.get_command
_command = discord.utils.get(self.walk_commands(guild=guild), qualified_name=command)
if check_global and not _command:
_command = discord.utils.get(self.walk_commands(), qualified_name=command)
else:
_command = cast(app_commands.Command[Any, ..., Any], command)
_command = cast(AppCommandT[Any, ..., Any], command)

if not _command:
return None
Expand All @@ -114,30 +104,36 @@ async def find_mention_for(
if not app_command_found:
return None

mention = f"</{_command.qualified_name}:{app_command_found.id}>"
self.cache.setdefault(guild_id, {})
self.cache[guild_id][command] = mention
self.cache[guild_id][command] = mention = f"</{_command.qualified_name}:{app_command_found.id}>"
return mention

def _walk_children[AppCommand: (app_commands.Command[Any, ..., Any], app_commands.Group)](
self, commands: list[AppCommand | app_commands.Group]
) -> Generator[AppCommand, None, None]:
def _walk_children[CogT: commands.Cog, **P, T](
self,
commands: list[AppCommandT[CogT, P, T]],
) -> Generator[AppCommandT[CogT, P, T], None, None]:
for command in commands:
if isinstance(command, app_commands.Group):
cmds: list[AppCommand] = cast(list[AppCommand], command.commands)
cmds: list[AppCommandT[CogT, P, T]] = cast(list[AppCommandT[CogT, P, T]], command.commands)
yield from self._walk_children(cmds)
else:
yield command

async def walk_mentions(
self, *, guild: discord.abc.Snowflake | None = None
) -> AsyncGenerator[tuple[app_commands.Command[Any, ..., Any], str], None]:
for command in self._walk_children(self.get_commands(guild=guild, type=discord.AppCommandType.chat_input)):
async def walk_mentions[CogT: commands.Cog, **P, T](
self, *, guild: MaybeSnowflake = None
) -> AsyncGenerator[tuple[AppCommandT[CogT, P, T], str], None]:
commands = cast(
list[AppCommandT[CogT, P, T]], self.get_commands(guild=guild, type=discord.AppCommandType.chat_input)
)
for command in self._walk_children(commands):
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
if guild and self.fallback_to_global is True:
for command in self._walk_children(self.get_commands(guild=None, type=discord.AppCommandType.chat_input)):
commands = cast(
list[AppCommandT[CogT, P, T]], self.get_commands(guild=None, type=discord.AppCommandType.chat_input)
)
for command in self._walk_children(commands):
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
Expand All @@ -155,6 +151,14 @@ def _prefix_callable(bot: Dynamo, msg: discord.Message) -> list[str]:
return base


class Emojis(dict[str, str]):
def __init__(self, emojis: list[discord.Emoji], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

for emoji in emojis:
self[emoji.name] = f"<{"a" if emoji.animated else ""}:{emoji.name}:{emoji.id}>"


type Interaction = discord.Interaction[Dynamo]


Expand Down Expand Up @@ -182,7 +186,7 @@ def __init__(self, connector: aiohttp.TCPConnector, conn: apsw.Connection, sessi
allowed_mentions=allowed_mentions,
intents=intents,
enable_debug_events=True,
tree_cls=VersionableTree,
tree_cls=DynamoTree,
activity=discord.Activity(name="The Cursed Apple", type=discord.ActivityType.watching),
)
self.raw_modal_submits: dict[str, RawSubmittable] = {}
Expand All @@ -196,11 +200,11 @@ async def setup_hook(self) -> None:

self.app_emojis = Emojis(await self.fetch_application_emojis())

for ext in initial_extensions:
for extension in initial_extensions:
try:
await self.load_extension(ext)
await self.load_extension(extension)
except commands.ExtensionError:
log.exception("Failed to load extension %s", ext)
log.exception("Failed to load extension %s", extension)

tree_path = resolve_path_with_links(platformdir.user_cache_path / "tree.hash")
tree_hash = await self.tree.get_hash(self.tree)
Expand All @@ -222,8 +226,8 @@ def user(self) -> discord.ClientUser:
return cast(discord.ClientUser, super().user)

@property
def tree(self) -> VersionableTree:
return cast(VersionableTree, super().tree)
def tree(self) -> DynamoTree:
return cast(DynamoTree, super().tree)

@property
def dev_guild(self) -> discord.Guild:
Expand All @@ -242,12 +246,12 @@ async def on_ready(self) -> None:

log.info("Ready: %s (ID: %s)", self.user, self.user.id)

async def on_interaction(self, interaction: discord.Interaction[Self]) -> None:
for typ, regex, mapping in (
async def on_interaction(self, interaction: Interaction) -> None:
for relevant_type, regex, mapping in (
(discord.InteractionType.modal_submit, modal_regex, self.raw_modal_submits),
(discord.InteractionType.component, button_regex, self.raw_button_submits),
):
if interaction.type is typ:
if interaction.type is relevant_type:
assert interaction.data is not None
custom_id = interaction.data.get("custom_id", "")
if match := regex.match(custom_id):
Expand Down
4 changes: 3 additions & 1 deletion dynamo/core/base_cog.py → dynamo/core/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from dynamo.core import Dynamo


class DynamoCog(commands.Cog):
class Cog(commands.Cog):
"""Dynamo cog. Sets up logging and any existing raw submittables."""

__slots__ = ("bot", "log")

def __init__(
Expand Down
File renamed without changes.
70 changes: 31 additions & 39 deletions dynamo/extensions/cogs/dev.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
import importlib
import sys
from collections.abc import Callable
from functools import partial
from typing import Literal

import discord
from discord.ext import commands

from dynamo.core import Dynamo, DynamoCog
from dynamo._types import Coro
from dynamo.core import Cog, Dynamo
from dynamo.core.bot import Emojis
from dynamo.utils.checks import is_owner
from dynamo.utils.context import Context
from dynamo.utils.emoji import Emojis
from dynamo.utils.format import code_block
from dynamo.utils.helper import get_cog

type SyncSpec = Literal["~", "*", "^"]


class Dev(DynamoCog):
class Dev(Cog):
"""Dev-only commands"""

def __init__(self, bot: Dynamo) -> None:
super().__init__(bot)
self.try_load_extension = partial(self._cog_try, self.bot.load_extension)
self.try_unload_extension = partial(self._cog_try, self.bot.unload_extension)

async def _cog_try(self, coro: Callable[[str], Coro[None]], cog: str) -> bool:
try:
await coro(cog)
except commands.ExtensionError:
self.log.exception("Coroutine %s failed for cog %s", coro.__name__, cog)
return False
return True

@commands.hybrid_command(name="sync", aliases=("s",))
@commands.guild_only()
@is_owner()
Expand Down Expand Up @@ -86,15 +102,8 @@ async def load(self, ctx: Context, *, module: str) -> None:
module: str
The name of the cog to load.
"""
message = ctx.message
cog = get_cog(module)
try:
await self.bot.load_extension(cog)
except commands.ExtensionError as ex:
await ctx.send(f"{ex.__class__.__name__}: {ex}")
self.log.exception("Failed to load %s", cog)
else:
await message.add_reaction(ctx.Status.OK)
success = await self.try_load_extension(get_cog(module))
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

@commands.hybrid_command(aliases=("ul",))
@is_owner()
Expand All @@ -106,15 +115,8 @@ async def unload(self, ctx: Context, *, module: str) -> None:
module: str
The name of the cog to unload.
"""
message = ctx.message
cog = get_cog(module)
try:
await self.bot.unload_extension(cog)
except commands.ExtensionError as ex:
await ctx.send(f"{ex.__class__.__name__}: {ex}")
self.log.exception("Failed to unload %s", cog)
else:
await message.add_reaction(ctx.Status.OK)
success = await self.try_unload_extension(get_cog(module))
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

@commands.hybrid_group(name="reload", aliases=("r",), invoke_without_command=True)
@is_owner()
Expand All @@ -126,17 +128,10 @@ async def _reload(self, ctx: Context, *, module: str) -> None:
module: str
The name of the cog to reload.
"""
message: discord.Message | None = ctx.message
cog = get_cog(module)
try:
await self.bot.reload_extension(cog)
except commands.ExtensionError as ex:
await ctx.send(f"{ex.__class__.__name__}: {ex}")
self.log.exception("Failed to reload %s", cog)
else:
await message.add_reaction(ctx.Status.OK)

async def reload_or_load_extension(self, module: str) -> None:
success = await self.try_reload_extension(get_cog(module))
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

async def try_reload_extension(self, module: str) -> bool:
try:
await self.bot.reload_extension(module)
except commands.ExtensionNotLoaded:
Expand All @@ -145,6 +140,8 @@ async def reload_or_load_extension(self, module: str) -> None:
await self.bot.load_extension(module)
except commands.ExtensionError:
self.log.exception("Failed to load %s", module)
return False
return True

@_reload.command(name="all")
@is_owner()
Expand All @@ -167,13 +164,8 @@ async def _reload_all(self, ctx: Context) -> None:
extensions = frozenset(self.bot.extensions)
statuses: set[tuple[ctx.Status, str]] = set()
for ext in extensions:
try:
await self.reload_or_load_extension(ext)
except commands.ExtensionError:
self.log.exception("Failed to reload extension %s", ext)
statuses.add((ctx.Status.FAILURE, ext))
else:
statuses.add((ctx.Status.SUCCESS, ext))
success = await self.try_reload_extension(ext)
statuses.add((ctx.Status.SUCCESS if success else ctx.Status.FAILURE, ext))

success_count = sum(1 for status, _ in statuses if status == ctx.Status.SUCCESS)
self.log.debug("Reloaded %d/%d extensions", success_count, len(extensions))
Expand Down
Loading

0 comments on commit 3408b1f

Please sign in to comment.