Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for forward references #93

Merged
merged 5 commits into from
Apr 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Documentation = "https://typer.tiangolo.com/"
[tool.flit.metadata.requires-extra]
test = [
"shellingham",
"pytest >=4.4.0",
"pytest >=4.4.0,< 5.4",
"pytest-cov",
"coverage",
"pytest-xdist",
Expand Down
23 changes: 23 additions & 0 deletions tests/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,26 @@ def main(name: str = typer.Option(..., autocompletion=name_callback)):
with pytest.raises(click.ClickException) as exc_info:
runner.invoke(app, ["--name", "Camila"])
assert exc_info.value.message == "Invalid autocompletion callback parameters: val2"


def test_forward_references():
app = typer.Typer()

@app.command()
def main(arg1, arg2: int, arg3: "int", arg4: bool = False, arg5: "bool" = False):
typer.echo(f"arg1: {type(arg1)} {arg1}")
typer.echo(f"arg2: {type(arg2)} {arg2}")
typer.echo(f"arg3: {type(arg3)} {arg3}")
typer.echo(f"arg4: {type(arg4)} {arg4}")
typer.echo(f"arg5: {type(arg5)} {arg5}")

result = runner.invoke(app, ["Hello", "2", "invalid"])
assert (
"Error: Invalid value for 'ARG3': invalid is not a valid integer"
in result.stdout
)
result = runner.invoke(app, ["Hello", "2", "3", "--arg4", "--arg5"])
assert (
"arg1: <class 'str'> Hello\narg2: <class 'int'> 2\narg3: <class 'int'> 3\narg4: <class 'bool'> True\narg5: <class 'bool'> True\n"
in result.stdout
)
17 changes: 10 additions & 7 deletions typer/completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import os
import re
import subprocess
Expand All @@ -10,7 +9,9 @@
import click
import click._bashcomplete

from .models import ParamMeta
from .params import Option
from .utils import get_params_from_function

try:
import shellingham
Expand All @@ -21,14 +22,16 @@
_click_patched = False


def get_completion_inspect_parameters() -> Tuple[inspect.Parameter, inspect.Parameter]:
def get_completion_inspect_parameters() -> Tuple[ParamMeta, ParamMeta]:
completion_init()
test_disable_detection = os.getenv("_TYPER_COMPLETE_TEST_DISABLE_SHELL_DETECTION")
if shellingham and not test_disable_detection:
signature = inspect.signature(_install_completion_placeholder_function)
parameters = get_params_from_function(_install_completion_placeholder_function)
else:
signature = inspect.signature(_install_completion_no_auto_placeholder_function)
install_param, show_param = signature.parameters.values()
parameters = get_params_from_function(
_install_completion_no_auto_placeholder_function
)
install_param, show_param = parameters.values()
return install_param, show_param


Expand Down Expand Up @@ -204,7 +207,7 @@ def install_bash(*, prog_name: str, complete_var: str, shell: str) -> Path:
rc_content = rc_path.read_text()
completion_init_lines = [f"source {completion_path}"]
for line in completion_init_lines:
if line not in rc_content:
if line not in rc_content: # pragma: nocover
rc_content += f"\n{line}"
rc_content += "\n"
rc_path.write_text(rc_content)
Expand All @@ -231,7 +234,7 @@ def install_zsh(*, prog_name: str, complete_var: str, shell: str) -> Path:
"fpath+=~/.zfunc",
]
for line in completion_init_lines:
if line not in zshrc_content:
if line not in zshrc_content: # pragma: nocover
zshrc_content += f"\n{line}"
zshrc_content += "\n"
zshrc_path.write_text(zshrc_content)
Expand Down
20 changes: 11 additions & 9 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
NoneType,
OptionInfo,
ParameterInfo,
ParamMeta,
Required,
TyperInfo,
)
from .utils import get_params_from_function


def get_install_completion_arguments() -> Tuple[click.Parameter, click.Parameter]:
Expand Down Expand Up @@ -393,8 +395,8 @@ def get_params_convertors_ctx_param_name_from_function(
convertors = {}
context_param_name = None
if callback:
signature = inspect.signature(callback)
for param_name, param in signature.parameters.items():
parameters = get_params_from_function(callback)
for param_name, param in parameters.items():
if lenient_issubclass(param.annotation, click.Context):
context_param_name = param_name
continue
Expand Down Expand Up @@ -476,9 +478,9 @@ def get_callback(
) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
use_params: Dict[str, Any] = {}
for param_name, param_sig in signature.parameters.items():
for param_name in parameters:
use_params[param_name] = None
for param in params:
use_params[param.name] = param.default
Expand Down Expand Up @@ -591,7 +593,7 @@ def lenient_issubclass(


def get_click_param(
param: inspect.Parameter,
param: ParamMeta,
) -> Tuple[Union[click.Argument, click.Option], Any]:
# First, find out what will be:
# * ParamInfo (ArgumentInfo or OptionInfo)
Expand Down Expand Up @@ -744,12 +746,12 @@ def get_param_callback(
) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
ctx_name = None
click_param_name = None
value_name = None
untyped_names: List[str] = []
for param_name, param_sig in signature.parameters.items():
for param_name, param_sig in parameters.items():
if lenient_issubclass(param_sig.annotation, click.Context):
ctx_name = param_name
elif lenient_issubclass(param_sig.annotation, click.Parameter):
Expand Down Expand Up @@ -792,11 +794,11 @@ def wrapper(ctx: click.Context, param: click.Parameter, value: Any) -> Any:
def get_param_completion(callback: Optional[Callable] = None) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
ctx_name = None
args_name = None
incomplete_name = None
unassigned_params = [param for param in signature.parameters.values()]
unassigned_params = [param for param in parameters.values()]
for param_sig in unassigned_params[:]:
origin = getattr(param_sig.annotation, "__origin__", None)
if lenient_issubclass(param_sig.annotation, click.Context):
Expand Down
16 changes: 16 additions & 0 deletions typer/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import io
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -388,3 +389,18 @@ def __init__(
allow_dash=allow_dash,
path_type=path_type,
)


class ParamMeta:
empty = inspect.Parameter.empty

def __init__(
self,
*,
name: str,
default: Any = inspect.Parameter.empty,
annotation: Any = inspect.Parameter.empty,
) -> None:
self.name = name
self.default = default
self.annotation = annotation
18 changes: 18 additions & 0 deletions typer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import inspect
from typing import Callable, Dict, get_type_hints

from .models import ParamMeta


def get_params_from_function(func: Callable) -> Dict[str, ParamMeta]:
signature = inspect.signature(func)
type_hints = get_type_hints(func)
params = {}
for param in signature.parameters.values():
annotation = param.annotation
if param.name in type_hints:
annotation = type_hints[param.name]
params[param.name] = ParamMeta(
name=param.name, default=param.default, annotation=annotation
)
return params