From 48f9c914a0110616b2e2d8a2e381ea6d52c52034 Mon Sep 17 00:00:00 2001 From: khazhyk Date: Thu, 24 Jan 2019 21:40:42 -0800 Subject: [PATCH] [commands] custom default arguments Modeled slightly after Converters, allow specifying Converter-like class for context-based default parameters. e.g. ```py class Author(ParamDefault): async def default(self, ctx): return ctx.author async def my_command(ctx, user: discord.Member=Author): ... ``` Also adds a few common cases (Author, Channel, Guild) for current author, ... --- discord/ext/commands/__init__.py | 1 + discord/ext/commands/core.py | 22 ++++++-- discord/ext/commands/default.py | 88 ++++++++++++++++++++++++++++++++ docs/ext/commands/commands.rst | 54 ++++++++++++++++++++ 4 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 discord/ext/commands/default.py diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py index b14fd655953d..30e05dba6f01 100644 --- a/discord/ext/commands/__init__.py +++ b/discord/ext/commands/__init__.py @@ -18,3 +18,4 @@ from .converter import * from .cooldowns import * from .cog import * +from .default import CustomDefault diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f58902021dbf..5aa9787aabd3 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -35,6 +35,7 @@ from .errors import * from .cooldowns import Cooldown, BucketType, CooldownMapping from . import converter as converters +from . import default as defaults from ._types import _BaseCommand from .cog import Cog @@ -407,12 +408,25 @@ async def do_conversion(self, ctx, converter, argument, param): def _get_converter(self, param): converter = param.annotation if converter is param.empty: - if param.default is not param.empty: - converter = str if param.default is None else type(param.default) - else: + if param.default is param.empty or param.default is None or (inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault)): converter = str + else: + converter = type(param.default) return converter + async def _resolve_default(self, ctx, param): + try: + if inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault): + instance = param.default() + return await instance.default(ctx=ctx, param=param) + elif isinstance(param.default, converters.CustomDefault): + return await param.default.default(ctx=ctx, param=param) + except CommandError as e: + raise e + except Exception as e: + raise ConversionError(param.default, e) from e + return param.default + async def transform(self, ctx, param): required = param.default is param.empty converter = self._get_converter(param) @@ -440,7 +454,7 @@ async def transform(self, ctx, param): if self._is_typing_optional(param.annotation): return None raise MissingRequiredArgument(param) - return param.default + return await self._resolve_default(ctx, param) previous = view.index if consume_rest_is_special: diff --git a/discord/ext/commands/default.py b/discord/ext/commands/default.py new file mode 100644 index 000000000000..401622659345 --- /dev/null +++ b/discord/ext/commands/default.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2019 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .errors import MissingRequiredArgument + +__all__ = ( + 'CustomDefault', + 'Author', + 'Channel', + 'Guild', + 'Call', +) + +class CustomDefault: + """The base class of custom defaults that require the :class:`.Context`. + + Classes that derive from this should override the :meth:`~.CustomDefault.default` + method to do its conversion logic. This method must be a coroutine. + """ + + async def default(self, ctx, param): + """|coro| + + The method to override to do conversion logic. + + If an error is found while converting, it is recommended to + raise a :exc:`.CommandError` derived exception as it will + properly propagate to the error handlers. + + Parameters + ----------- + ctx: :class:`.Context` + The invocation context that the argument is being used in. + """ + raise NotImplementedError('Derived classes need to implement this.') + + +class Author(CustomDefault): + """Default parameter which returns the author for this context.""" + + async def default(self, ctx, param): + return ctx.author + +class Channel(CustomDefault): + """Default parameter which returns the channel for this context.""" + + async def default(self, ctx, param): + return ctx.channel + +class Guild(CustomDefault): + """Default parameter which returns the guild for this context.""" + + async def default(self, ctx, param): + if ctx.guild: + return ctx.guild + raise MissingRequiredArgument(param) + +class Call(CustomDefault): + """Easy wrapper for lambdas/inline defaults.""" + + def __init__(self, callback): + self._callback = callback + + async def default(self, ctx, param): + return self._callback(ctx, param) diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index 19e2bf4a1688..7ca22982d6aa 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -585,6 +585,60 @@ handlers that allow us to do just that. First we decorate an error handler funct The first parameter of the error handler is the :class:`.Context` while the second one is an exception that is derived from :exc:`~ext.commands.CommandError`. A list of errors is found in the :ref:`ext_commands_api_errors` page of the documentation. + +Custom Defaults +--------------- + +Custom defaults allow us to specify :class:`.Context`-based defaults. Custom defaults are always classes which inherit from +:class:`.CustomDefault`. + +The library provides some simple default implementations in the :module:`default` - :class:`default.Author`, :class:`default.Channel`, +and :class:`default.Guild` returning the corresponding properties from the Context. These can be used along with Converters to +simplify your individual commands. You can also use :class:`default.Call` to quickly wrap existing functions. + +A DefaultParam returning `None` is valid - if this should be an error, raise :class:`.MissingRequiredArgument`. + +.. code-block:: python3 + + class AnyImage(Converter): + """Find images associated with the message.""" + + async def convert(self, ctx, argument): + if argument.startswith("http://") or argument.startswith("https://"): + return argument + + member = await UserMemberConverter().convert(ctx, argument) + if member: + return str(member.avatar_url_as(format="png")) + + raise errors.BadArgument(f"{argument} isn't a member or url.") + + class LastImage(CustomDefault): + """Default param which finds the last image in chat.""" + + async def default(self, ctx, param): + for attachment in message.attachments: + if attachment.proxy_url: + return attachment.proxy_url + async for message in ctx.history(ctx, limit=100): + for embed in message.embeds: + if embed.thumbnail and embed.thumbnail.proxy_url: + return embed.thumbnail.proxy_url + for attachment in message.attachments: + if attachment.proxy_url: + return attachment.proxy_url + + raise errors.MissingRequiredArgument(param) + + @bot.command() + async def echo_image(ctx, *, image: Image = LastImage): + async with aiohttp.ClientSession() as sess: + async with sess.get(image) as resp: + resp.raise_for_status() + my_bytes = io.BytesIO(await resp.content.read()) + await ctx.send(file=discord.File(name="your_image", fp=my_bytes)) + + Checks -------