Skip to content

Commit

Permalink
Add Client.walk_commands
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 8, 2024
1 parent 2f4ff2e commit f10fddc
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 5 deletions.
86 changes: 84 additions & 2 deletions arc/abc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import alluka
import hikari

from arc.abc.command import _CommandSettings
from arc.abc.command import CallableCommandProto, CommandProto, _CommandSettings
from arc.abc.plugin import PluginBase
from arc.command.message import MessageCommand
from arc.command.slash import SlashCommand, SlashGroup
from arc.command.slash import SlashCommand, SlashGroup, SlashSubCommand
from arc.command.user import UserCommand
from arc.context import AutodeferMode, Context
from arc.errors import ExtensionLoadError, ExtensionUnloadError
Expand Down Expand Up @@ -380,6 +380,88 @@ async def on_autocomplete_interaction(

return await command._on_autocomplete(interaction)

@t.overload
def walk_commands(
self,
*,
callable_only: t.Literal[True] = True,
types: set[hikari.CommandType] = {
hikari.CommandType.SLASH,
hikari.CommandType.MESSAGE,
hikari.CommandType.USER,
},
) -> t.Iterator[CallableCommandProto[te.Self]]:
...

@t.overload
def walk_commands(
self,
*,
callable_only: t.Literal[False],
types: set[hikari.CommandType] = {
hikari.CommandType.SLASH,
hikari.CommandType.MESSAGE,
hikari.CommandType.USER,
},
) -> t.Iterator[CommandProto]:
...

def walk_commands( # noqa: C901
self,
*,
callable_only: bool = True,
types: set[hikari.CommandType] = {
hikari.CommandType.SLASH,
hikari.CommandType.MESSAGE,
hikari.CommandType.USER,
},
) -> t.Iterator[CommandProto | CallableCommandProto[te.Self]]:
"""Iterate over all commands added to this client.
Parameters
----------
callable_only : bool
Whether to only return commands that are directly callable.
If True, command groups and subgroups will be skipped.
types : set[hikari.CommandType]
The types of commands to return.
Yields
------
CommandBase[te.Self, t.Any]
The next command that matches the given criteria.
"""
if hikari.CommandType.SLASH in types:
for command in self.slash_commands.values():
if isinstance(command, SlashCommand):
yield command
continue

if not callable_only:
yield command

for sub in command.children.values():
if isinstance(sub, SlashSubCommand):
print("Yielded slash subcommand")
yield sub
continue

if not callable_only:
yield sub

for subsub in sub.children.values():
yield subsub

if hikari.CommandType.MESSAGE in types:
for command in self.message_commands.values():
print("Yielded message command")
yield command

if hikari.CommandType.USER in types:
for command in self.user_commands.values():
print("Yielded user command")
yield command

@t.overload
def include(self) -> t.Callable[[CommandBase[te.Self, BuilderT]], CommandBase[te.Self, BuilderT]]:
...
Expand Down
2 changes: 1 addition & 1 deletion arc/abc/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ class CallableCommandBase(CommandBase[ClientT, BuilderT], CallableCommandProto[C
callback: CommandCallbackT[ClientT]
"""The callback to invoke when this command is called."""

_invoke_task: asyncio.Task[t.Any] | None = attr.field(init=False, default=None)
_invoke_task: asyncio.Task[t.Any] | None = attr.field(init=False, default=None, repr=False)

def reset_all_limiters(self, context: Context[ClientT]) -> None:
"""Reset all limiters for this command.
Expand Down
4 changes: 3 additions & 1 deletion arc/command/slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,9 @@ def qualified_name(self) -> t.Sequence[str]:
@property
def parent(self) -> SlashGroup[ClientT] | SlashSubGroup[ClientT]:
"""The parent of this subcommand."""
return self.parent
if self._parent is None:
raise ValueError("Cannot get parent of subcommand without parent.")
return self._parent

@property
def command_type(self) -> hikari.CommandType:
Expand Down
3 changes: 2 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Here you can find all the changelogs for `hikari-arc`.

## Unreleased

- Add `Command.display_name`, `SlashCommand.make_mention`, `SlashSubCommand.make_mention`.
- Add `Client.walk_commands()`.
- Add `Command.display_name`, `SlashCommand.make_mention()`, `SlashSubCommand.make_mention()`.

## v0.5.0

Expand Down
65 changes: 65 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import hikari

import arc

bot = hikari.GatewayBot("...", banner=None)
client = arc.GatewayClient(bot)

group = client.include_slash_group(
"my_group", "My group description", default_permissions=hikari.Permissions.ADMINISTRATOR
)

subgroup = group.include_subgroup("my_subgroup", "My subgroup description")


@client.include
@arc.message_command("Message Command")
async def among_us(ctx: arc.GatewayContext, message: hikari.Message) -> None:
pass


@group.include
@arc.slash_subcommand("test_subcommand", "My subcommand description")
async def my_subcommand(
ctx: arc.GatewayContext,
a: arc.Option[int, arc.IntParams(description="foo", min=10)],
b: arc.Option[str, arc.StrParams(description="bar", min_length=100)],
) -> None:
pass


@subgroup.include()
@arc.slash_subcommand("test_subsubcommand", "My subsubcommand description")
async def my_subsubcommand(
ctx: arc.GatewayContext,
a: arc.Option[int, arc.IntParams(description="foo", min=10)],
b: arc.Option[str, arc.StrParams(description="bar", min_length=100)],
) -> None:
pass


def test_walk_commands() -> None:
cmds = list(client.walk_commands(callable_only=False))

assert len(cmds) == 5

assert among_us in cmds
assert my_subcommand in cmds
assert my_subsubcommand in cmds
assert group in cmds
assert subgroup in cmds

cmds = list(client.walk_commands(callable_only=True))

assert len(cmds) == 3

assert among_us in cmds
assert my_subcommand in cmds
assert my_subsubcommand in cmds

cmds = list(client.walk_commands(callable_only=True, types={hikari.CommandType.SLASH}))

assert len(cmds) == 2

assert my_subcommand in cmds
assert my_subsubcommand in cmds

0 comments on commit f10fddc

Please sign in to comment.