Skip to content

Commit

Permalink
[commands] custom default arguments
Browse files Browse the repository at this point in the history
Modeled slightly after Converters, allow specifying Converter-like class
for context-based default parameters. e.g.

```py
class Author(ParamDefault):
  async def default(self, ctx):
    return ctx.author

async def my_command(ctx, user: discord.Member=Author):
  ...
```

Also adds a few common cases (Author, Channel, Guild) for current
author, ...
  • Loading branch information
khazhyk committed May 8, 2019
1 parent 6dcd68b commit 48f9c91
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 4 deletions.
1 change: 1 addition & 0 deletions discord/ext/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .converter import *
from .cooldowns import *
from .cog import *
from .default import CustomDefault
22 changes: 18 additions & 4 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping
from . import converter as converters
from . import default as defaults
from ._types import _BaseCommand
from .cog import Cog

Expand Down Expand Up @@ -407,12 +408,25 @@ async def do_conversion(self, ctx, converter, argument, param):
def _get_converter(self, param):
converter = param.annotation
if converter is param.empty:
if param.default is not param.empty:
converter = str if param.default is None else type(param.default)
else:
if param.default is param.empty or param.default is None or (inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault)):
converter = str
else:
converter = type(param.default)
return converter

async def _resolve_default(self, ctx, param):
try:
if inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault):
instance = param.default()
return await instance.default(ctx=ctx, param=param)
elif isinstance(param.default, converters.CustomDefault):
return await param.default.default(ctx=ctx, param=param)
except CommandError as e:
raise e
except Exception as e:
raise ConversionError(param.default, e) from e
return param.default

async def transform(self, ctx, param):
required = param.default is param.empty
converter = self._get_converter(param)
Expand Down Expand Up @@ -440,7 +454,7 @@ async def transform(self, ctx, param):
if self._is_typing_optional(param.annotation):
return None
raise MissingRequiredArgument(param)
return param.default
return await self._resolve_default(ctx, param)

previous = view.index
if consume_rest_is_special:
Expand Down
88 changes: 88 additions & 0 deletions discord/ext/commands/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-

"""
The MIT License (MIT)
Copyright (c) 2015-2019 Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

from .errors import MissingRequiredArgument

__all__ = (
'CustomDefault',
'Author',
'Channel',
'Guild',
'Call',
)

class CustomDefault:
"""The base class of custom defaults that require the :class:`.Context`.
Classes that derive from this should override the :meth:`~.CustomDefault.default`
method to do its conversion logic. This method must be a coroutine.
"""

async def default(self, ctx, param):
"""|coro|
The method to override to do conversion logic.
If an error is found while converting, it is recommended to
raise a :exc:`.CommandError` derived exception as it will
properly propagate to the error handlers.
Parameters
-----------
ctx: :class:`.Context`
The invocation context that the argument is being used in.
"""
raise NotImplementedError('Derived classes need to implement this.')


class Author(CustomDefault):
"""Default parameter which returns the author for this context."""

async def default(self, ctx, param):
return ctx.author

class Channel(CustomDefault):
"""Default parameter which returns the channel for this context."""

async def default(self, ctx, param):
return ctx.channel

class Guild(CustomDefault):
"""Default parameter which returns the guild for this context."""

async def default(self, ctx, param):
if ctx.guild:
return ctx.guild
raise MissingRequiredArgument(param)

class Call(CustomDefault):
"""Easy wrapper for lambdas/inline defaults."""

def __init__(self, callback):
self._callback = callback

async def default(self, ctx, param):
return self._callback(ctx, param)
54 changes: 54 additions & 0 deletions docs/ext/commands/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,60 @@ handlers that allow us to do just that. First we decorate an error handler funct
The first parameter of the error handler is the :class:`.Context` while the second one is an exception that is derived from
:exc:`~ext.commands.CommandError`. A list of errors is found in the :ref:`ext_commands_api_errors` page of the documentation.


Custom Defaults
---------------

Custom defaults allow us to specify :class:`.Context`-based defaults. Custom defaults are always classes which inherit from
:class:`.CustomDefault`.

The library provides some simple default implementations in the :module:`default` - :class:`default.Author`, :class:`default.Channel`,
and :class:`default.Guild` returning the corresponding properties from the Context. These can be used along with Converters to
simplify your individual commands. You can also use :class:`default.Call` to quickly wrap existing functions.

A DefaultParam returning `None` is valid - if this should be an error, raise :class:`.MissingRequiredArgument`.

.. code-block:: python3
class AnyImage(Converter):
"""Find images associated with the message."""
async def convert(self, ctx, argument):
if argument.startswith("http://") or argument.startswith("https://"):
return argument
member = await UserMemberConverter().convert(ctx, argument)
if member:
return str(member.avatar_url_as(format="png"))
raise errors.BadArgument(f"{argument} isn't a member or url.")
class LastImage(CustomDefault):
"""Default param which finds the last image in chat."""
async def default(self, ctx, param):
for attachment in message.attachments:
if attachment.proxy_url:
return attachment.proxy_url
async for message in ctx.history(ctx, limit=100):
for embed in message.embeds:
if embed.thumbnail and embed.thumbnail.proxy_url:
return embed.thumbnail.proxy_url
for attachment in message.attachments:
if attachment.proxy_url:
return attachment.proxy_url
raise errors.MissingRequiredArgument(param)
@bot.command()
async def echo_image(ctx, *, image: Image = LastImage):
async with aiohttp.ClientSession() as sess:
async with sess.get(image) as resp:
resp.raise_for_status()
my_bytes = io.BytesIO(await resp.content.read())
await ctx.send(file=discord.File(name="your_image", fp=my_bytes))
Checks
-------

Expand Down

0 comments on commit 48f9c91

Please sign in to comment.