Skip to content

Commit

Permalink
Shuffle out tree and context into core
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Oct 7, 2024
1 parent 7e9d1d9 commit df277a6
Show file tree
Hide file tree
Showing 29 changed files with 444 additions and 380 deletions.
17 changes: 17 additions & 0 deletions dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dynamo import utils
from dynamo.core.logger import setup_logging

__all__ = (
"Cog",
"Dynamo",
"Context",
"Interaction",
"Tree",
"setup_logging",
"utils",
)

from dynamo.core.bot import Dynamo, Interaction
from dynamo.core.cog import Cog
from dynamo.core.context import Context
from dynamo.core.tree import Tree
9 changes: 0 additions & 9 deletions dynamo/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from dynamo.core.bot import Dynamo
from dynamo.core.cog import Cog
from dynamo.core.logger import setup_logging

__all__ = (
"Cog",
"Dynamo",
"setup_logging",
)
187 changes: 26 additions & 161 deletions dynamo/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,22 @@

import logging
import re
from collections.abc import AsyncGenerator, Generator
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from typing import Any, cast

import aiohttp
import apsw
import discord
import msgspec
import xxhash
from discord import app_commands
from discord.ext import commands

from dynamo._types import AppCommandT, MaybeSnowflake, RawSubmittable
from dynamo.utils.context import Context
from dynamo.utils.helper import get_cog, platformdir, resolve_path_with_links
from dynamo.core.context import Context
from dynamo.core.tree import Tree
from dynamo.typedefs import RawSubmittable
from dynamo.utils.helper import platformdir, resolve_path_with_links

log = logging.getLogger(__name__)

initial_extensions = tuple(get_cog(e) for e in ("errors", "help", "dev", "events", "general", "info", "tags"))

description = """
Quantum entanglement.
"""
Expand All @@ -31,118 +26,6 @@
button_regex = re.compile(r"^b:(.{1,10}):(.*)$", flags=re.DOTALL)


class DynamoTree(app_commands.CommandTree["Dynamo"]):
"""Versionable and mentionable command tree"""

type CommandT = commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str

application_commands: dict[int | None, list[app_commands.AppCommand]]
cache: dict[int | None, dict[CommandT | str, str]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.application_commands = {}
self.cache = {}

async def get_hash(self, tree: app_commands.CommandTree) -> bytes:
"""Get the hash of the command tree."""
commands = sorted(self._get_all_commands(guild=None), key=lambda c: c.qualified_name)

if translator := self.translator:
payload = [await command.get_translated_payload(tree, translator) for command in commands]
else:
payload = [command.to_dict(tree) for command in commands]

return xxhash.xxh64_digest(msgspec.msgpack.encode(payload), seed=0)

# See: https://gist.github.com/LeoCx1000/021dc52981299b95ea7790416e4f5ca4#file-mentionable_tree-py
async def sync(self, *, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
result = await super().sync(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
self.cache.pop(guild_id, None)
return result

async def fetch_commands(self, *, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
result = await super().fetch_commands(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
self.cache.pop(guild_id, None)
return result

async def get_or_fetch_commands(self, guild: MaybeSnowflake = None) -> list[app_commands.AppCommand]:
try:
return self.application_commands[guild.id if guild else None]
except KeyError:
return await self.fetch_commands(guild=guild)

async def find_mention_for(self, command: CommandT, *, guild: discord.abc.Snowflake | None = None) -> str | None:
guild_id = guild.id if guild else None
try:
return self.cache[guild_id][command]
except KeyError:
pass

check_global = self.fallback_to_global is True or guild is None

if isinstance(command, str):
# Workaround: discord.py doesn't return children from tree.get_command
_command = discord.utils.get(self.walk_commands(guild=guild), qualified_name=command)
if check_global and not _command:
_command = discord.utils.get(self.walk_commands(), qualified_name=command)
else:
_command = cast(AppCommandT[Any, ..., Any], command)

if not _command:
return None

local_commands = await self.get_or_fetch_commands(guild=guild)
app_command_found = discord.utils.get(local_commands, name=(_command.root_parent or _command).name)

if check_global and not app_command_found:
global_commands = await self.get_or_fetch_commands(guild=None)
app_command_found = discord.utils.get(global_commands, name=(_command.root_parent or _command).name)

if not app_command_found:
return None

self.cache.setdefault(guild_id, {})
self.cache[guild_id][command] = mention = f"</{_command.qualified_name}:{app_command_found.id}>"
return mention

def _walk_children[CogT: commands.Cog, **P, T](
self,
commands: list[AppCommandT[CogT, P, T]],
) -> Generator[AppCommandT[CogT, P, T], None, None]:
for command in commands:
if isinstance(command, app_commands.Group):
cmds: list[AppCommandT[CogT, P, T]] = cast(list[AppCommandT[CogT, P, T]], command.commands)
yield from self._walk_children(cmds)
else:
yield command

async def walk_mentions[CogT: commands.Cog, **P, T](
self, *, guild: MaybeSnowflake = None
) -> AsyncGenerator[tuple[AppCommandT[CogT, P, T], str], None]:
commands = cast(
list[AppCommandT[CogT, P, T]], self.get_commands(guild=guild, type=discord.AppCommandType.chat_input)
)
for command in self._walk_children(commands):
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
if guild and self.fallback_to_global is True:
commands = cast(
list[AppCommandT[CogT, P, T]], self.get_commands(guild=None, type=discord.AppCommandType.chat_input)
)
for command in self._walk_children(commands):
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
else:
log.warning("Could not find a mention for command %s in the API. Are you out of sync?", command)


def _prefix_callable(bot: Dynamo, msg: discord.Message) -> list[str]:
user_id = bot.user.id
base = [f"<@{user_id}> ", f"<@!{user_id}> "]
Expand All @@ -161,7 +44,7 @@ def __init__(self, emojis: list[discord.Emoji], *args: Any, **kwargs: Any) -> No
self[emoji.name] = f"<{"a" if emoji.animated else ""}:{emoji.name}:{emoji.id}>"


type Interaction = discord.Interaction[Dynamo]
type Interaction = discord.Interaction["Dynamo"]


class Dynamo(commands.AutoShardedBot):
Expand All @@ -188,7 +71,7 @@ def __init__(self, connector: aiohttp.TCPConnector, conn: apsw.Connection, sessi
allowed_mentions=allowed_mentions,
intents=intents,
enable_debug_events=True,
tree_cls=DynamoTree,
tree_cls=Tree,
activity=discord.Activity(name="The Cursed Apple", type=discord.ActivityType.watching),
)
self.raw_modal_submits: dict[str, RawSubmittable] = {}
Expand All @@ -201,10 +84,21 @@ async def setup_hook(self) -> None:
self.owner_id = self.bot_app_info.owner.id

self.app_emojis = Emojis(await self.fetch_application_emojis())
self.cog_file_times: dict[str, float] = {}

for extension in initial_extensions:
await self.load_extension_with_timestamp(extension)
self.extension_files: dict[Path, float] = {}

self.cog_spec = find_spec("dynamo.extensions.cogs")
if self.cog_spec is None or self.cog_spec.origin is None:
log.critical("Failed to find cog spec! Loading without cogs.")
return

all_cogs = Path(self.cog_spec.origin).parent
for cog_path in all_cogs.rglob("**/*.py"):
if cog_path.is_file() and not cog_path.name.startswith("_"):
cog_name = self.get_cog_name(cog_path.stem)
try:
await self.load_extension(cog_name)
except commands.ExtensionError:
log.exception("Failed to load cog %s", cog_name)

tree_path = resolve_path_with_links(platformdir.user_cache_path / "tree.hash")
tree_hash = await self.tree.get_hash(self.tree)
Expand All @@ -217,37 +111,8 @@ async def setup_hook(self) -> None:
fp.seek(0)
fp.write(tree_hash)

def get_cog_path(self, cog: str) -> Path:
module = import_module(cog)
if module.__file__ is None:
error = f"Could not determine file path for cog {cog}"
log.exception(error)
raise RuntimeError(error)
return Path(module.__file__)

async def load_extension_with_timestamp(self, extension: str) -> None:
try:
await self.load_extension(extension)
self.cog_file_times[extension] = self.get_cog_path(extension).lstat().st_mtime
except commands.ExtensionError:
log.exception("Failed to load extension %s", extension)

async def lazy_load_cog(self, cog_name: str) -> None:
"""Lazily load a cog if it has been modified."""
cog_path = self.get_cog_path(cog_name)
current_mtime = cog_path.lstat().st_mtime
if current_mtime > self.cog_file_times.get(cog_name, 0):
try:
await self.reload_extension(cog_name)
self.cog_file_times[cog_name] = current_mtime
log.info("Reloaded modified cog: %s", cog_name)
except commands.ExtensionError:
log.exception("Failed to reload cog: %s", cog_name)

@commands.Cog.listener()
async def on_command_completion(self, ctx: Context) -> None:
if ctx.cog:
await self.lazy_load_cog(ctx.cog.__module__)
def get_cog_name(self, name: str) -> str:
return name.lower() if self.cog_spec is None else f"{self.cog_spec.name}.{name.lower()}"

@property
def owner(self) -> discord.User:
Expand All @@ -258,8 +123,8 @@ def user(self) -> discord.ClientUser:
return cast(discord.ClientUser, super().user)

@property
def tree(self) -> DynamoTree:
return cast(DynamoTree, super().tree)
def tree(self) -> Tree:
return cast(Tree, super().tree)

@property
def dev_guild(self) -> discord.Guild:
Expand Down
7 changes: 3 additions & 4 deletions dynamo/core/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from discord.ext import commands

from dynamo._types import RawSubmittable
from dynamo.utils.helper import get_cog
from dynamo.typedefs import RawSubmittable

if TYPE_CHECKING:
from dynamo.core import Dynamo
from dynamo import Dynamo

type Submittables = dict[str, type[RawSubmittable]]

Expand All @@ -26,7 +25,7 @@ def __init__(
raw_button_submits: Submittables | None = None,
) -> None:
self.bot: Dynamo = bot
self.log = logging.getLogger(get_cog(self.__class__.__name__))
self.log = logging.getLogger(self.bot.get_cog_name(self.__class__.__name__))
if raw_modal_submits is not None:
self.bot.raw_modal_submits.update(raw_modal_submits)
if raw_button_submits is not None:
Expand Down
10 changes: 5 additions & 5 deletions dynamo/utils/context.py → dynamo/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from discord.ui import View

if TYPE_CHECKING:
from dynamo.core.bot import Dynamo
from dynamo.core.bot import Dynamo, Interaction # noqa: F401


class ConfirmationView(View):
Expand All @@ -28,7 +28,7 @@ async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Check if the interaction is from the author of the view"""
return bool(interaction.user and interaction.user.id == self.author_id)

async def _defer_and_stop(self, interaction: discord.Interaction[Dynamo]) -> None:
async def _defer_and_stop(self, interaction: Interaction) -> None:
"""Defer the interaction and stop the view."""
await interaction.response.defer()
if self.delete_after and self.message:
Expand All @@ -45,19 +45,19 @@ async def on_timeout(self) -> None:
await self.message.delete()

@discord.ui.button(label="Confirm", style=discord.ButtonStyle.green)
async def confirm[V: View](self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
async def confirm[V: View](self, interaction: Interaction, button: discord.ui.Button[V]) -> None:
"""Confirm the action"""
self.value = True
await self._defer_and_stop(interaction)

@discord.ui.button(label="Cancel", style=discord.ButtonStyle.red)
async def cancel[V: View](self, interaction: discord.Interaction[Dynamo], button: discord.ui.Button[V]) -> None:
async def cancel[V: View](self, interaction: Interaction, button: discord.ui.Button[V]) -> None:
"""Cancel the action"""
await self._defer_and_stop(interaction)


class Context(commands.Context["Dynamo"]):
bot: Dynamo
interaction: Interaction | None

class Status(StrEnum):
"""Status emojis for the bot"""
Expand Down
Loading

0 comments on commit df277a6

Please sign in to comment.