Skip to content

Commit

Permalink
New cache again
Browse files Browse the repository at this point in the history
  • Loading branch information
trumully committed Sep 25, 2024
1 parent 767e38b commit b8ec870
Show file tree
Hide file tree
Showing 19 changed files with 308 additions and 165 deletions.
55 changes: 51 additions & 4 deletions dynamo/_typing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
from collections.abc import Callable, Coroutine
from typing import Any, ParamSpec, TypeVar
from collections.abc import Callable, Coroutine, Mapping
from typing import Annotated, Any, ParamSpec, TypeVar

import numpy as np
from discord import app_commands
from discord.ext import commands
from discord.ui import View

P = ParamSpec("P")

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

S = TypeVar("S", bound=object)
S_co = TypeVar("S_co", bound=object, covariant=True)

AsyncCallable = TypeVar("AsyncCallable", bound=Callable[P, Coroutine[Any, Any, T]])
ArrayRGB = Annotated[np.ndarray[Any, Any], tuple[int, int, int]]
AC = TypeVar("AC", bound=Callable[..., Coroutine[Any, Any, Any]])
CogT = TypeVar("CogT", bound=commands.Cog)
CommandT = TypeVar("CommandT", bound=commands.Command[CogT, P, T])
CommandT = TypeVar("CommandT", bound=commands.Command[Any, ..., Any])
ContextT = TypeVar("ContextT", bound=commands.Context[Any], covariant=True)

V = TypeVar("V", bound="View", covariant=True)


class NotFoundWithHelp(commands.CommandError): ...


command_error_messages: Mapping[type[commands.CommandError], str] = {
commands.CommandNotFound: "Command not found: **`{}`**{}",
NotFoundWithHelp: "Command not found: **`{}`**{}",
commands.MissingRequiredArgument: "Missing required argument: `{}`.",
commands.BadArgument: "Bad argument.",
commands.CommandOnCooldown: "You are on cooldown. Try again in `{:.2f}` seconds.",
commands.TooManyArguments: "Too many arguments.",
commands.MissingPermissions: "You are not allowed to use this command.",
commands.BotMissingPermissions: "I am not allowed to use this command.",
commands.NoPrivateMessage: "This command can only be used in a server.",
commands.NotOwner: "You are not the owner of this bot.",
commands.DisabledCommand: "This command is disabled.",
commands.CheckFailure: "You do not have permission to use this command.",
}

app_command_error_messages: Mapping[type[app_commands.AppCommandError], str] = {
app_commands.CommandNotFound: "Command not found: **`{}`**{}",
app_commands.CommandOnCooldown: "You are on cooldown. Try again in `{:.2f}` seconds.",
app_commands.MissingPermissions: "You are not allowed to use this command.",
app_commands.BotMissingPermissions: "I am not allowed to use this command.",
app_commands.NoPrivateMessage: "This command can only be used in a server.",
app_commands.CheckFailure: "You do not have permission to use this command.",
}


class _MISSING:
__slots__ = ()


MISSING: Any = _MISSING()
2 changes: 1 addition & 1 deletion dynamo/core/base_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class DynamoCog(commands.Cog):
__slots__ = ("bot", "log")

def __init__(self, bot: Dynamo) -> None:
def __init__(self, bot: Dynamo, case_insensitive: bool = True) -> None:
self.bot: Dynamo = bot
self.log = logging.getLogger(get_cog(self.__class__.__name__))

Expand Down
41 changes: 24 additions & 17 deletions dynamo/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from discord import app_commands
from discord.ext import commands

from dynamo._typing import CommandT
from dynamo.utils.context import Context
from dynamo.utils.emoji import Emojis
from dynamo.utils.helper import get_cog, platformdir, resolve_path_with_links
Expand All @@ -34,7 +33,7 @@

class VersionableTree(app_commands.CommandTree["Dynamo"]):
application_commands: dict[int | None, list[app_commands.AppCommand]]
cache: dict[int | None, dict[CommandT | str, str]]
cache: dict[int | None, dict[commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str, str]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -60,7 +59,7 @@ async def sync(self, *, guild: discord.abc.Snowflake | None = None) -> list[app_
self.cache.pop(guild_id, None)
return result

async def fetch_commands(self, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
async def fetch_commands(self, *, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
result = await super().fetch_commands(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
Expand All @@ -74,7 +73,10 @@ async def get_or_fetch_commands(self, guild: discord.abc.Snowflake | None = None
return await self.fetch_commands(guild=guild)

async def find_mention_for(
self, command: CommandT | str, *, guild: discord.abc.Snowflake | None = None
self,
command: commands.Command[Any, ..., Any] | app_commands.Command[Any, ..., Any] | str,
*,
guild: discord.abc.Snowflake | None = None,
) -> str | None:
guild_id = guild.id if guild else None
try:
Expand All @@ -90,7 +92,7 @@ async def find_mention_for(
if check_global and not _command:
_command = discord.utils.get(self.walk_commands(), qualified_name=command)
else:
_command = cast(app_commands.Command, command)
_command = cast(app_commands.Command[Any, ..., Any], command)

if not _command:
return None
Expand All @@ -110,25 +112,26 @@ async def find_mention_for(
self.cache[guild_id][command] = mention
return mention

def _walk_children(
self, commands: list[app_commands.Group | app_commands.Command]
) -> Generator[app_commands.Command, None, None]:
def _walk_children[AppCommand: (app_commands.Command[Any, ..., Any], app_commands.Group)](
self, commands: list[AppCommand | app_commands.Group]
) -> Generator[AppCommand, None, None]:
for command in commands:
if isinstance(command, app_commands.Group):
yield from self._walk_children(command.commands)
cmds: list[AppCommand] = cast(list[AppCommand], command.commands)
yield from self._walk_children(cmds)
else:
yield command

async def walk_mentions(
self, *, guild: discord.abc.Snowflake | None = None
) -> AsyncGenerator[tuple[app_commands.Command, str], None]:
) -> AsyncGenerator[tuple[app_commands.Command[Any, ..., Any], str], None]:
for command in self._walk_children(self.get_commands(guild=guild, type=discord.AppCommandType.chat_input)):
mention = await self.find_mention_for(cast(CommandT, command), guild=guild)
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
if guild and self.fallback_to_global is True:
for command in self._walk_children(self.get_commands(guild=None, type=discord.AppCommandType.chat_input)):
mention = await self.find_mention_for(cast(CommandT, command), guild=guild)
mention = await self.find_mention_for(command, guild=guild)
if mention:
yield command, mention
else:
Expand All @@ -147,7 +150,6 @@ def _prefix_callable(bot: Dynamo, msg: discord.Message) -> list[str]:

class Dynamo(commands.AutoShardedBot):
session: aiohttp.ClientSession
user: discord.ClientUser
context: Context
logging_handler: Any
bot_app_info: discord.AppInfo
Expand Down Expand Up @@ -186,9 +188,6 @@ async def setup_hook(self) -> None:
self.bot_app_info = await self.application_info()
self.owner_id = self.bot_app_info.owner.id

# Case insensitive cogs for help commands.
self._BotBase__cogs = commands.core._CaseInsensitiveDict()

self.app_emojis = Emojis(await self.fetch_application_emojis())

for ext in initial_extensions:
Expand All @@ -212,6 +211,14 @@ async def setup_hook(self) -> None:
def owner(self) -> discord.User:
return self.bot_app_info.owner

@property
def user(self) -> discord.ClientUser:
return cast(discord.ClientUser, super().user)

@property
def tree(self) -> VersionableTree:
return cast(VersionableTree, super().tree)

@property
def dev_guild(self) -> discord.Guild:
return cast(discord.Guild, discord.Object(id=681408104495448088, type=discord.Guild))
Expand All @@ -231,7 +238,7 @@ async def on_ready(self) -> None:

async def get_context(
self,
origin: discord.Message | discord.Interaction[Dynamo],
origin: discord.Message | discord.Interaction,
/,
*,
cls: type[Context] = Context,
Expand Down
1 change: 1 addition & 0 deletions dynamo/core/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import logging.handlers
import queue
from collections.abc import Generator
from contextlib import contextmanager
Expand Down
8 changes: 2 additions & 6 deletions dynamo/extensions/cogs/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def __init__(self, bot: Dynamo) -> None:
async def sync(
self,
ctx: Context,
guild: discord.Guild = commands.param(
converter=GuildConverter, default=lambda ctx: ctx.guild, displayed_name="guild_id"
),
guild: discord.Guild = commands.param(converter=GuildConverter, displayed_name="guild_id"),
copy: bool = False,
) -> None:
"""Sync slash commands
Expand Down Expand Up @@ -56,9 +54,7 @@ async def sync_global(self, ctx: Context) -> None:
async def clear_commands(
self,
ctx: Context,
guild: discord.Guild = commands.param(
converter=GuildConverter, default=lambda ctx: ctx.guild, displayed_name="guild_id"
),
guild: discord.Guild = commands.param(converter=GuildConverter, displayed_name="guild_id"),
) -> None:
"""Clear all slash commands
Expand Down
17 changes: 9 additions & 8 deletions dynamo/extensions/cogs/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from discord.ext import commands
from rapidfuzz import fuzz

from dynamo._typing import NotFoundWithHelp, app_command_error_messages, command_error_messages
from dynamo.core import Dynamo, DynamoCog
from dynamo.utils.context import Context
from dynamo.utils.error_types import NotFoundWithHelp, app_command_error_messages, command_error_messages


class Errors(DynamoCog):
Expand Down Expand Up @@ -74,7 +74,8 @@ async def on_command_error(self, ctx: Context, error: commands.CommandError) ->
error_message = self.get_command_error_message(error)

if isinstance(error, (commands.CommandNotFound, NotFoundWithHelp)):
trigger = ctx.invoked_with if isinstance(error, commands.CommandNotFound) else error.args[0]
invoked = ctx.invoked_with
trigger: str = invoked if invoked and isinstance(error, commands.CommandNotFound) else error.args[0]

matches = [
f"**{command.qualified_name}** - {command.short_doc or 'No description provided'}"
Expand Down Expand Up @@ -107,11 +108,11 @@ async def on_app_command_error(self, interaction: Interaction, error: app_comman
error : app_commands.AppCommandError
The exception.
"""
if interaction.command is None:
self.log.error("Command not found: %s.", interaction.data)
command_name = interaction.data.get("name", "")
if (command := interaction.command) is None:
command_name: str = "Unknown" if interaction.data is None else interaction.data.get("name", "Unknown")
self.log.error("Command not found: %s.", command_name)
matches = [
command for command in self.bot.tree.get_commands() if fuzz.ratio(command_name, command.name) > 70
str(command) for command in self.bot.tree.get_commands() if fuzz.ratio(command_name, command.name) > 70
]
msg = f"Command not found: '{command_name}'"
if matches:
Expand All @@ -123,12 +124,12 @@ async def on_app_command_error(self, interaction: Interaction, error: app_comman
)
return

self.log.error("%s called by %s raised an exception: %s.", interaction.command.name, interaction.user, error)
self.log.error("%s called by %s raised an exception: %s.", command.name, interaction.user, error)

error_message = self.get_app_command_error_message(error)

if isinstance(error, app_commands.CommandNotFound):
error_message = error_message.format(interaction.command.name)
error_message = error_message.format(command.name)

elif isinstance(error, app_commands.CommandOnCooldown):
error_message = error_message.format(error.retry_after)
Expand Down
30 changes: 23 additions & 7 deletions dynamo/extensions/cogs/events.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from __future__ import annotations

import contextlib
from copy import copy
from typing import Any

import discord
from discord.ext import commands

from dynamo._typing import V
from dynamo.core import Dynamo, DynamoCog
from dynamo.utils.cache import async_cache
from dynamo.utils.context import Context
from dynamo.utils.format import shorten_string


class EventsDropdown(discord.ui.Select):
class EventsDropdown(discord.ui.Select[V]):
"""Base dropdown for selecting an event. Functionality can be defined with callback."""

def __init__(self, events: list[discord.ScheduledEvent], *args: Any, **kwargs: Any) -> None:
self.events: list[discord.ScheduledEvent] = events

options = [
discord.SelectOption(label=e.name, value=str(e.id), description=shorten_string(e.description))
discord.SelectOption(label=e.name, value=str(e.id), description=shorten_string(e.description or "..."))
for e in events
]

Expand All @@ -27,28 +31,36 @@ def __init__(self, events: list[discord.ScheduledEvent], *args: Any, **kwargs: A
class EventsView(discord.ui.View):
"""View for selecting an event"""

message: discord.Message

def __init__(
self,
author_id: int,
events: list[discord.ScheduledEvent],
dropdown: type[EventsDropdown],
dropdown: type[EventsDropdown[EventsView]],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.author_id: int = author_id
self.add_item(dropdown(events))

@property
def __children(self) -> list[EventsDropdown[EventsView]]:
return getattr(self, "_children", []).copy()

async def interaction_check(self, interaction: discord.Interaction) -> bool:
return bool(interaction.user and interaction.user.id == self.author_id)

async def on_timeout(self) -> None:
for item in self.children:
item.disabled = True
for item in self.__children:
new_item = copy(item)
new_item.disabled = True
self.add_item(item)
await self.message.edit(view=self)


class InterestedDropdown(EventsDropdown):
class InterestedDropdown(EventsDropdown[EventsView]):
async def callback(self, interaction: discord.Interaction) -> None:
event = next((e for e in self.events if e.id == int(self.values[0])), None)
await interaction.response.send_message(await get_interested(event) or "No users found", ephemeral=True)
Expand All @@ -69,6 +81,7 @@ def __init__(self, bot: Dynamo) -> None:
super().__init__(bot)

async def fetch_events(self, guild: discord.Guild) -> list[discord.ScheduledEvent]:
events: list[discord.ScheduledEvent] = []
try:
events = await guild.fetch_scheduled_events(with_counts=False)
except discord.HTTPException:
Expand Down Expand Up @@ -97,9 +110,12 @@ async def event(self, ctx: Context, event: int | None = None) -> None:
event: int | None, optional
The event ID to get attendees of
"""
if ctx.guild is None:
return

message = await ctx.send(f"{self.bot.app_emojis.get('loading2', '⏳')}\tFetching events...")

event_check = await self.event_check(ctx.guild, event)
event_check: str | list[discord.ScheduledEvent] = await self.event_check(ctx.guild, event)
if isinstance(event_check, str):
await message.edit(content=event_check)
await message.delete(delay=10)
Expand Down
Loading

0 comments on commit b8ec870

Please sign in to comment.