Skip to content

Commit

Permalink
🐛 Fix support for UnionType (e.g. str | None) with Python 3.11 (#548
Browse files Browse the repository at this point in the history
)

Co-authored-by: svlandeg <svlandeg@github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 17, 2024
1 parent ad421bd commit 218bf89
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
22 changes: 22 additions & 0 deletions tests/test_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import typer
from typer.testing import CliRunner

from .utils import needs_py310

runner = CliRunner()


Expand All @@ -29,6 +31,26 @@ def opt(user: Optional[str] = None):
assert "User: Camila" in result.output


@needs_py310
def test_union_type_optional():
app = typer.Typer()

@app.command()
def opt(user: str | None = None):
if user:
print(f"User: {user}")
else:
print("No user")

result = runner.invoke(app)
assert result.exit_code == 0
assert "No user" in result.output

result = runner.invoke(app, ["--user", "Camila"])
assert result.exit_code == 0
assert "User: Camila" in result.output


def test_optional_tuple():
app = typer.Typer()

Expand Down
28 changes: 15 additions & 13 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import click

from ._typing import get_args, get_origin, is_union
from .completion import get_completion_inspect_parameters
from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption
from .models import (
Expand Down Expand Up @@ -825,30 +826,31 @@ def get_click_param(
is_tuple = False
parameter_type: Any = None
is_flag = None
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)

if origin is not None:
# Handle Optional[SomeType]
if origin is Union:
# Handle SomeType | None and Optional[SomeType]
if is_union(origin):
types = []
for type_ in main_type.__args__:
for type_ in get_args(main_type):
if type_ is NoneType:
continue
types.append(type_)
assert len(types) == 1, "Typer Currently doesn't support Union types"
main_type = types[0]
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)
# Handle Tuples and Lists
if lenient_issubclass(origin, List):
main_type = main_type.__args__[0]
assert not getattr(
main_type, "__origin__", None
main_type = get_args(main_type)[0]
assert not get_origin(
main_type
), "List types with complex sub-types are not currently supported"
is_list = True
elif lenient_issubclass(origin, Tuple): # type: ignore
types = []
for type_ in main_type.__args__:
assert not getattr(
type_, "__origin__", None
for type_ in get_args(main_type):
assert not get_origin(
type_
), "Tuple types with complex sub-types are not currently supported"
types.append(
get_click_type(annotation=type_, parameter_info=parameter_info)
Expand All @@ -865,7 +867,7 @@ def get_click_param(
convertor=convertor, default_value=default_value
)
if is_tuple:
convertor = generate_tuple_convertor(main_type.__args__)
convertor = generate_tuple_convertor(get_args(main_type))
if isinstance(parameter_info, OptionInfo):
if main_type is bool and parameter_info.is_flag is not False:
is_flag = True
Expand Down Expand Up @@ -1020,7 +1022,7 @@ def get_param_completion(
incomplete_name = None
unassigned_params = list(parameters.values())
for param_sig in unassigned_params[:]:
origin = getattr(param_sig.annotation, "__origin__", None)
origin = get_origin(param_sig.annotation)
if lenient_issubclass(param_sig.annotation, click.Context):
ctx_name = param_sig.name
unassigned_params.remove(param_sig)
Expand Down

0 comments on commit 218bf89

Please sign in to comment.