Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stricter typing with mypy #19

Merged
merged 6 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ jobs:
cache: 'poetry'
- run: poetry install --with=dev
- run: poetry run pytest tests --asyncio-mode=strict -n logical

type-checks:
name: Type check with mypy
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Install poetry
run: pipx install poetry
- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'poetry'
- run: poetry install --with=dev
- run: poetry run mypy --pretty --config-file ./pyproject.toml dynamo
10 changes: 4 additions & 6 deletions dynamo/_evt_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@


def get_event_loop_policy() -> asyncio.AbstractEventLoopPolicy:
policy = asyncio.DefaultEventLoopPolicy

if sys.platform in ("win32", "cygwin", "cli"):
try:
import winloop
except ImportError:
policy = asyncio.WindowsSelectorEventLoopPolicy
return asyncio.WindowsSelectorEventLoopPolicy()
else:
policy = winloop.EventLoopPolicy
return winloop.EventLoopPolicy()

else:
try:
import uvloop
except ImportError:
pass
else:
policy = uvloop.EventLoopPolicy
return uvloop.EventLoopPolicy()

return policy()
return asyncio.DefaultEventLoopPolicy()
59 changes: 35 additions & 24 deletions dynamo/bot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import logging
from collections.abc import AsyncGenerator, Coroutine, Generator
from typing import Any
from collections.abc import AsyncGenerator, Generator
from typing import Any, Generic, TypeVar, cast

import aiohttp
import discord
Expand All @@ -27,12 +27,21 @@
Quantum entanglement.
"""

CogT = TypeVar("CogT", bound=commands.Cog)
CommandT = TypeVar(
"CommandT",
bound=commands.Command[Any, ..., Any] | app_commands.AppCommand | commands.HybridCommand,
)


class VersionableTree(app_commands.CommandTree["Dynamo"], Generic[CommandT]):
application_commands: dict[int | None, list[app_commands.AppCommand]]
cache: dict[int | None, dict[CommandT | str, str]]

class VersionableTree(app_commands.CommandTree["Dynamo"]):
def __init__(self, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.application_commands: dict[int | None, list[app_commands.AppCommand]] = {}
self.cache: dict[int | None, dict[app_commands.AppCommand | commands.HybridCommand | str, str]] = {}
self.application_commands = {}
self.cache = {}

async def get_hash(self, tree: app_commands.CommandTree[Dynamo]) -> bytes:
"""Get the hash of the command tree.
Expand All @@ -57,9 +66,7 @@ async def get_hash(self, tree: app_commands.CommandTree[Dynamo]) -> bytes:
return xxhash.xxh64_digest(msgspec.msgpack.encode(payload), seed=0)

# See: https://gist.github.com/LeoCx1000/021dc52981299b95ea7790416e4f5ca4#file-mentionable_tree-py
async def sync(
self, *, guild: discord.abc.Snowflake | None = None
) -> Coroutine[Any, Any, list[app_commands.AppCommand]]:
async def sync(self, *, guild: discord.abc.Snowflake | None = None) -> list[app_commands.AppCommand]:
result = await super().sync(guild=guild)
guild_id = guild.id if guild else None
self.application_commands[guild_id] = result
Expand All @@ -80,10 +87,7 @@ async def get_or_fetch_commands(self, guild: discord.abc.Snowflake | None = None
return await self.fetch_commands(guild=guild)

async def find_mention_for(
self,
command: app_commands.AppCommand | commands.HybridCommand | str,
*,
guild: discord.abc.Snowflake | None = None,
self, command: CommandT | str, *, guild: discord.abc.Snowflake | None = None
) -> str | None:
guild_id = guild.id if guild else None
try:
Expand All @@ -99,7 +103,7 @@ async def find_mention_for(
if check_global and not _command:
_command = discord.utils.get(self.walk_commands(), qualified_name=command)
else:
_command = command
_command = cast(app_commands.Command, command)

if not _command:
return None
Expand Down Expand Up @@ -130,14 +134,14 @@ def _walk_children(

async def walk_mentions(
self, *, guild: discord.abc.Snowflake | None = None
) -> AsyncGenerator[tuple[app_commands.Command, str], None, None]:
) -> AsyncGenerator[tuple[app_commands.Command, str], None]:
for command in self._walk_children(self.get_commands(guild=guild, type=discord.AppCommandType.chat_input)):
mention = await self.find_mention_for(command, guild=guild)
mention = await self.find_mention_for(cast(CommandT, command), guild=guild)
if mention:
yield command, mention
if guild and self.fallback_to_global is True:
for command in self._walk_children(self.get_commands(guild=None, type=discord.AppCommandType.chat_input)):
mention = await self.find_mention_for(command, guild=guild)
mention = await self.find_mention_for(cast(CommandT, command), guild=guild)
if mention:
yield command, mention
else:
Expand All @@ -157,11 +161,11 @@ def _prefix_callable(bot: Dynamo, msg: discord.Message) -> list[str]:
class Dynamo(commands.AutoShardedBot):
session: aiohttp.ClientSession
user: discord.ClientUser
context: Context
logging_handler: Any
bot_app_info: discord.AppInfo
tree: VersionableTree

def __init__(self, session: aiohttp.ClientSession, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> None:
def __init__(self, connector: aiohttp.TCPConnector, session: aiohttp.ClientSession) -> None:
self.session = session
allowed_mentions = discord.AllowedMentions(roles=False, everyone=False, users=True)
intents = discord.Intents(
Expand All @@ -172,7 +176,7 @@ def __init__(self, session: aiohttp.ClientSession, *args: tuple[Any, ...], **kwa
presences=True,
)
super().__init__(
*args,
connector=connector,
command_prefix=_prefix_callable,
description=description,
pm_help=None,
Expand All @@ -183,7 +187,6 @@ def __init__(self, session: aiohttp.ClientSession, *args: tuple[Any, ...], **kwa
intents=intents,
enable_debug_events=True,
tree_cls=VersionableTree,
**kwargs,
)

async def setup_hook(self) -> None:
Expand Down Expand Up @@ -213,13 +216,17 @@ async def setup_hook(self) -> None:
fp.seek(0)
fp.write(tree_hash)

@property
def tree(self) -> VersionableTree:
return self.tree

@property
def owner(self) -> discord.User:
return self.bot_app_info.owner

@property
def dev_guild(self) -> discord.Guild:
return discord.Object(id=681408104495448088, type=discord.Guild)
return cast(discord.Guild, discord.Object(id=681408104495448088, type=discord.Guild))

async def start(self, token: str, *, reconnect: bool = True) -> None:
return await super().start(token, reconnect=reconnect)
Expand All @@ -234,7 +241,11 @@ async def on_ready(self) -> None:

log.info("Ready: %s (ID: %s)", self.user, self.user.id)

async def get_context(
self, origin: discord.Interaction | discord.Message, /, *, cls: type[Context] = Context
async def get_context( # type: ignore
self,
origin: discord.Message | discord.Interaction[Dynamo],
/,
*,
cls: type[Context] = Context,
) -> Context:
return await super().get_context(origin, cls=cls)
32 changes: 19 additions & 13 deletions dynamo/extensions/cogs/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from dynamo.bot import Dynamo
from dynamo.utils.cache import cached_functions
from dynamo.utils.context import Status
from dynamo.utils.context import Context, Status
from dynamo.utils.converter import GuildConverter
from dynamo.utils.helper import ROOT, get_cog

log = logging.getLogger(__name__)
Expand All @@ -25,11 +26,16 @@ class Dev(commands.GroupCog, group_name="dev"):
def __init__(self, bot: Dynamo) -> None:
self.bot: Dynamo = bot

async def cog_check(self, ctx: commands.Context) -> bool:
async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override]
return await self.bot.is_owner(ctx.author)

@commands.hybrid_group(invoke_without_command=True, name="sync", aliases=("s",))
async def sync(self, ctx: commands.Context, guild_id: int | None, copy: bool = False) -> None:
async def sync(
self,
ctx: commands.Context,
guild: discord.Guild = commands.param(converter=GuildConverter, default=None, displayed_name="guild_id"),
copy: bool = False,
) -> None:
"""Sync slash commands

Parameters
Expand All @@ -39,8 +45,6 @@ async def sync(self, ctx: commands.Context, guild_id: int | None, copy: bool = F
copy: bool
Copy global commands to the specified guild. (Default: False)
"""
guild: discord.Guild = discord.Object(id=guild_id, type=discord.Guild) if guild_id else ctx.guild

if copy:
self.bot.tree.copy_global_to(guild=guild)

Expand All @@ -54,7 +58,11 @@ async def sync_global(self, ctx: commands.Context) -> None:
await ctx.send(f"Successfully synced {len(commands)} commands")

@sync.command(name="clear", aliases=("c",))
async def clear_commands(self, ctx: commands.Context, guild_id: int | None) -> None:
async def clear_commands(
self,
ctx: Context,
guild: discord.Guild = commands.param(converter=GuildConverter, default=None, displayed_name="guild_id"),
) -> None:
"""Clear all slash commands

Parameters
Expand All @@ -65,8 +73,6 @@ async def clear_commands(self, ctx: commands.Context, guild_id: int | None) -> N
if not await ctx.prompt("Are you sure you want to clear all commands?"):
return

guild: discord.Guild | None = discord.Object(id=guild_id, type=discord.Guild) if guild_id else None

self.bot.tree.clear_commands(guild=guild)
await ctx.send("Successfully cleared all commands")

Expand Down Expand Up @@ -125,7 +131,7 @@ async def reload_or_load_extension(self, module: str) -> None:
await self.bot.load_extension(module)

@_reload.command(name="all")
async def _reload_all(self, ctx: commands.Context) -> None:
async def _reload_all(self, ctx: Context) -> None:
"""Reload all cogs"""
confirm = await ctx.prompt("Are you sure you want to reload all cogs?")
if not confirm:
Expand All @@ -143,16 +149,16 @@ async def _reload_all(self, ctx: commands.Context) -> None:
log.exception("Failed to reload %s", module)
log.debug("Reloaded %d/%d utilities", len(utils_modules), len(all_utils))

extensions = self.bot.extensions.copy()
statuses: set[tuple[Status, str]] = []
extensions = set(self.bot.extensions)
statuses: set[tuple[Status, str]] = set()
for ext in extensions:
try:
await self.reload_or_load_extension(ext)
except commands.ExtensionError:
log.exception("Failed to reload extension %s", ext)
statuses.append((Status.FAILURE, ext))
statuses.add((Status.FAILURE, ext))
else:
statuses.append((Status.SUCCESS, ext))
statuses.add((Status.SUCCESS, ext))

success_count = sum(1 for status, _ in statuses if status == Status.SUCCESS)
log.debug("Reloaded %d/%d extensions", success_count, len(extensions))
Expand Down
18 changes: 9 additions & 9 deletions dynamo/extensions/cogs/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any

import discord
from discord.ext import commands
Expand All @@ -10,12 +11,12 @@


class Dropdown(discord.ui.Select):
def __init__(self, events: list[discord.ScheduledEvent]) -> None:
def __init__(self, events: list[discord.ScheduledEvent], *args: Any, **kwargs: Any) -> None:
self.events: list[discord.ScheduledEvent] = events

options = [discord.SelectOption(label=e.name, value=str(e.id), description="An event") for e in events]

super().__init__(placeholder="Select an event", min_values=1, max_values=1, options=options)
super().__init__(*args, placeholder="Select an event", min_values=1, max_values=1, options=options, **kwargs)

async def callback(self, interaction: discord.Interaction) -> None:
if (event := next((e for e in self.events if str(e.id) == self.values[0]), None)) is None:
Expand All @@ -26,9 +27,9 @@ async def callback(self, interaction: discord.Interaction) -> None:


class DropdownView(discord.ui.View):
def __init__(self, events: list[discord.ScheduledEvent]) -> None:
def __init__(self, events: list[discord.ScheduledEvent], *args: Any, **kwargs: Any) -> None:
super().__init__()
self.add_item(Dropdown(events))
self.add_item(Dropdown(events, *args, **kwargs))


@async_lru_cache()
Expand All @@ -52,7 +53,7 @@ class Events(commands.Cog, name="Events"):
def __init__(self, bot: Dynamo) -> None:
self.bot: Dynamo = bot

async def cog_check(self, ctx: commands.Context) -> bool:
def cog_check(self, ctx: commands.Context) -> bool:
return ctx.guild is not None

@commands.hybrid_command(name="event")
Expand All @@ -66,20 +67,19 @@ async def event(self, ctx: commands.Context, event: int | None) -> None:
"""
if event is not None:
try:
ev = await ctx.guild.fetch_scheduled_event(event)
ev = await ctx.guild.fetch_scheduled_event(event) # type: ignore
except discord.NotFound:
await ctx.send(f"No event with id: {event}", ephemeral=True)
return
interested = await get_interested(ev)
await ctx.send(interested, ephemeral=True)
return

if not (events := await fetch_events(ctx.guild)):
if not (events := await fetch_events(ctx.guild)): # type: ignore
await ctx.send("No events found!", ephemeral=True)
return
view = DropdownView(events)
view.message = await ctx.send("Select an event", ephemeral=True, view=view)
await view.wait()
await ctx.send("Select an event", ephemeral=True, view=view)


async def setup(bot: Dynamo) -> None:
Expand Down
Loading