Skip to content

Commit

Permalink
Add lazy loading
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Oct 5, 2024
1 parent 3408b1f commit 6bfbccc
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 229 deletions.
42 changes: 37 additions & 5 deletions dynamo/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
import re
from collections.abc import AsyncGenerator, Generator
from importlib import import_module
from pathlib import Path
from typing import Any, cast

import aiohttp
Expand Down Expand Up @@ -199,12 +201,10 @@ async def setup_hook(self) -> None:
self.owner_id = self.bot_app_info.owner.id

self.app_emojis = Emojis(await self.fetch_application_emojis())
self.cog_file_times: dict[str, float] = {}

for extension in initial_extensions:
try:
await self.load_extension(extension)
except commands.ExtensionError:
log.exception("Failed to load extension %s", extension)
await self.load_extension_with_timestamp(extension)

tree_path = resolve_path_with_links(platformdir.user_cache_path / "tree.hash")
tree_hash = await self.tree.get_hash(self.tree)
Expand All @@ -217,6 +217,38 @@ async def setup_hook(self) -> None:
fp.seek(0)
fp.write(tree_hash)

def get_cog_path(self, cog: str) -> Path:
module = import_module(cog)
if module.__file__ is None:
error = f"Could not determine file path for cog {cog}"
log.exception(error)
raise RuntimeError(error)
return Path(module.__file__)

async def load_extension_with_timestamp(self, extension: str) -> None:
try:
await self.load_extension(extension)
self.cog_file_times[extension] = self.get_cog_path(extension).lstat().st_mtime
except commands.ExtensionError:
log.exception("Failed to load extension %s", extension)

async def lazy_load_cog(self, cog_name: str) -> None:
"""Lazily load a cog if it has been modified."""
cog_path = self.get_cog_path(cog_name)
current_mtime = cog_path.lstat().st_mtime
if current_mtime > self.cog_file_times.get(cog_name, 0):
try:
await self.reload_extension(cog_name)
self.cog_file_times[cog_name] = current_mtime
log.info("Reloaded modified cog: %s", cog_name)
except commands.ExtensionError:
log.exception("Failed to reload cog: %s", cog_name)

@commands.Cog.listener()
async def on_command_completion(self, ctx: Context) -> None:
if ctx.cog:
await self.lazy_load_cog(ctx.cog.__module__)

@property
def owner(self) -> discord.User:
return self.bot_app_info.owner
Expand All @@ -238,7 +270,7 @@ async def start(self, token: str, *, reconnect: bool = True) -> None:

async def close(self) -> None:
await self.session.close()
return await super().close()
await super().close()

async def on_ready(self) -> None:
if not hasattr(self, "uptime"):
Expand Down
156 changes: 67 additions & 89 deletions dynamo/extensions/cogs/dev.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import contextlib
import importlib
import sys
from collections.abc import Callable
from functools import partial
from typing import Literal
from collections.abc import AsyncGenerator, Callable
from typing import Literal, cast

import discord
from discord.ext import commands
Expand All @@ -23,124 +23,85 @@ class Dev(Cog):

def __init__(self, bot: Dynamo) -> None:
super().__init__(bot)
self.try_load_extension = partial(self._cog_try, self.bot.load_extension)
self.try_unload_extension = partial(self._cog_try, self.bot.unload_extension)

async def _cog_try(self, coro: Callable[[str], Coro[None]], cog: str) -> bool:
async def _execute_extension_action(self, action: Callable[[str], Coro[None]], cog: str) -> bool:
try:
await coro(cog)
await action(get_cog(cog))
except commands.ExtensionError:
self.log.exception("Coroutine %s failed for cog %s", coro.__name__, cog)
self.log.exception("Extension %s failed for cog %s", action.__name__, cog)
return False
return True

@commands.hybrid_command(name="sync", aliases=("s",))
@commands.guild_only()
@is_owner()
async def sync(self, ctx: Context, guilds: commands.Greedy[discord.Object], spec: SyncSpec | None = None) -> None:
"""Sync application commands globally or with guilds
Parameters
----------
guilds: commands.Greedy[discord.Object]
The guilds to sync the commands to
spec: SyncSpec | None, optional
The sync specification, by default None
See
---
- https://about.abstractumbra.dev/discord.py/2023/01/29/sync-command-example.html
"""
"""Sync application commands globally or with guilds"""
if not ctx.guild:
return

if not guilds:
synced = await self._sync_commands(ctx.guild, spec)
scope = "globally" if spec is None else "to the current guild"
await ctx.send(f"Synced {len(synced)} commands {scope}.")
return

success = await self._sync_to_guilds(guilds)
await ctx.send(f"Synced the tree to {success}/{len(guilds)} guilds.")
else:
success = await self._sync_to_guilds(guilds)
await ctx.send(f"Synced the tree to {success}/{len(guilds)} guilds.")

async def _sync_commands(
self, guild: discord.Guild, spec: SyncSpec | None
) -> list[discord.app_commands.AppCommand]:
# This will sync all guild commands for the current context's guild.
if spec == "~":
return await self.bot.tree.sync(guild=guild)
# This will copy all global commands to the current guild (within the CommandTree) and syncs.
if spec == "*":
self.bot.tree.copy_global_to(guild=guild)
return await self.bot.tree.sync(guild=guild)
# This command will remove all guild commands from the CommandTree and syncs,
# which effectively removes all commands from the guild.
if spec == "^":
self.bot.tree.clear_commands(guild=guild)
await self.bot.tree.sync(guild=guild)
return []
# This takes all global commands within the CommandTree and sends them to Discord
return await self.bot.tree.sync()

async def _sync_to_guilds(self, guilds: commands.Greedy[discord.Object]) -> int:
success = 0
for guild in guilds:
try:
await self.bot.tree.sync(guild=guild)
success += 1
except discord.HTTPException:
self.log.exception("Failed to sync guild %s", guild.id)
return success
async with contextlib.aclosing(cast(AsyncGenerator[discord.Guild], guilds)) as gen:
results: list[bool] = [await self._sync_guild(guild) async for guild in gen]
return sum(results)

async def _sync_guild(self, guild: discord.Guild) -> bool:
try:
await self.bot.tree.sync(guild=guild)
except discord.HTTPException:
self.log.exception("Failed to sync guild %s", guild.id)
return False
return True

@commands.hybrid_command(name="load", aliases=("l",))
@is_owner()
async def load(self, ctx: Context, *, module: str) -> None:
"""Load a cog
Parameters
----------
module: str
The name of the cog to load.
"""
success = await self.try_load_extension(get_cog(module))
"""Load a cog"""
success = await self._execute_extension_action(self.bot.load_extension, module)
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

@commands.hybrid_command(aliases=("ul",))
@is_owner()
async def unload(self, ctx: Context, *, module: str) -> None:
"""Unload a cog
Parameters
----------
module: str
The name of the cog to unload.
"""
success = await self.try_unload_extension(get_cog(module))
"""Unload a cog"""
success = await self._execute_extension_action(self.bot.unload_extension, module)
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

@commands.hybrid_group(name="reload", aliases=("r",), invoke_without_command=True)
@is_owner()
async def _reload(self, ctx: Context, *, module: str) -> None:
"""Reload a cog.
Parameters
----------
module: str
The name of the cog to reload.
"""
success = await self.try_reload_extension(get_cog(module))
"""Reload a cog."""
success = await self._reload_extension(module)
await ctx.message.add_reaction(ctx.Status.OK if success else ctx.Status.FAILURE)

async def try_reload_extension(self, module: str) -> bool:
async def _reload_extension(self, module: str) -> bool:
try:
await self.bot.reload_extension(module)
except commands.ExtensionNotLoaded:
self.log.exception("%s is not loaded. Attempting to load...", module)
try:
await self.bot.load_extension(module)
except commands.ExtensionError:
self.log.exception("Failed to load %s", module)
return False
self.log.warning("Extension %s is not loaded. Attempting to load...", module)
return await self._execute_extension_action(self.bot.load_extension, module)
return True

@_reload.command(name="all")
Expand All @@ -150,26 +111,43 @@ async def _reload_all(self, ctx: Context) -> None:
if not await ctx.prompt("Are you sure you want to reload all cogs?"):
return

# Reload all pre-existing modules from the utils folder
utils_modules: frozenset[str] = frozenset(mod for mod in sys.modules if mod.startswith("dynamo.utils."))
fail = 0
for module in utils_modules:
try:
importlib.reload(sys.modules[module])
except (KeyError, ModuleNotFoundError):
fail += 1
self.log.exception("Failed to reload %s", module)
self.log.debug("Reloaded %d/%d utilities", len(utils_modules) - fail, len(utils_modules))

extensions = frozenset(self.bot.extensions)
statuses: set[tuple[ctx.Status, str]] = set()
for ext in extensions:
success = await self.try_reload_extension(ext)
statuses.add((ctx.Status.SUCCESS if success else ctx.Status.FAILURE, ext))

success_count = sum(1 for status, _ in statuses if status == ctx.Status.SUCCESS)
self.log.debug("Reloaded %d/%d extensions", success_count, len(extensions))
await ctx.send("\n".join(f"{status} `{ext}`" for status, ext in statuses))
utils_modules = self._reload_utils_modules()
extensions_status = await self._reload_all_extensions()

await ctx.send(self._format_reload_results(utils_modules, extensions_status))

def _reload_utils_modules(self) -> tuple[int, int]:
utils_modules = [mod for mod in sys.modules if mod.startswith("dynamo.utils.")]
success = sum(self._reload_module(mod) for mod in utils_modules)
return success, len(utils_modules)

def _reload_module(self, module: str) -> bool:
try:
importlib.reload(sys.modules[module])
except (KeyError, ModuleNotFoundError, NameError):
self.log.exception("Failed to reload %s", module)
return False
return True

async def _reload_all_extensions(self) -> list[tuple[Context.Status, str]]:
extensions = list(self.bot.extensions)
success = Context.Status.SUCCESS
failure = Context.Status.FAILURE
return [(success if await self._reload_extension(ext) else failure, ext) for ext in extensions]

def _format_reload_results(
self, utils_result: tuple[int, int], extensions_status: list[tuple[Context.Status, str]]
) -> str:
utils_success, utils_total = utils_result
extensions_success = sum(1 for status, _ in extensions_status if status == Context.Status.SUCCESS)
extensions_total = len(extensions_status)

result = [
f"Reloaded {utils_success}/{utils_total} utilities",
f"Reloaded {extensions_success}/{extensions_total} extensions",
"\n".join(f"{status} `{ext}`" for status, ext in extensions_status),
]
return "\n".join(result)

@commands.hybrid_command(name="quit", aliases=("exit", "shutdown", "q"))
@is_owner()
Expand Down
9 changes: 7 additions & 2 deletions dynamo/extensions/cogs/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from dynamo.core import Cog, Dynamo
from dynamo.utils.context import Context
from dynamo.utils.converter import MemberLikeConverter
from dynamo.utils.time_utils import format_relative, human_timedelta

PYTHON = "https://s3.dualstack.us-east-2.amazonaws.com/pythondotorg-assets/media/community/logos/python-logo-only.png"
Expand Down Expand Up @@ -75,12 +76,16 @@ async def about(self, ctx: Context) -> None:
await ctx.send(embed=embed, ephemeral=True)

@commands.hybrid_command(name="avatar")
async def avatar(self, ctx: Context, user: discord.Member | discord.User | None = None) -> None:
async def avatar(
self,
ctx: Context,
user: discord.Member | discord.User | None = commands.param(default=None, converter=MemberLikeConverter),
) -> None:
"""Get the avatar of a user
Parameters
----------
user: discord.Member | discord.User
user: discord.Member | discord.User | None
The user to get the avatar of
"""
await ctx.send(embed=embed_from_user(ctx.author if user is None else user), ephemeral=True)
Expand Down
Loading

0 comments on commit 6bfbccc

Please sign in to comment.