Skip to content

Commit

Permalink
fixed Command.get_argument_type bug with UnionType (#110)
Browse files Browse the repository at this point in the history
`Command.get_argument_type` currently crashes when `UnionType` is
encountered. Add special handling for this type
  • Loading branch information
nkvuong authored May 30, 2024
1 parent be215f1 commit 8300dd7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
6 changes: 5 additions & 1 deletion src/databricks/labs/blueprint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import json
import logging
import types
from collections.abc import Callable
from dataclasses import dataclass

Expand Down Expand Up @@ -40,7 +41,10 @@ def get_argument_type(self, argument_name: str) -> str | None:
sig = inspect.signature(self.fn)
if argument_name not in sig.parameters:
return None
return sig.parameters[argument_name].annotation.__name__
annotation = sig.parameters[argument_name].annotation
if isinstance(annotation, types.UnionType):
return str(annotation)
return annotation.__name__


class App:
Expand Down
28 changes: 22 additions & 6 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"address": "",
"is_customer": "true",
"log_level": "disabled",
"optional_arg": "optional",
},
}
)
Expand All @@ -26,27 +27,42 @@ def test_commands():
app = App(inspect.getfile(App))

@app.command(is_unauthenticated=True)
def foo(name: str, age: int, salary: float, is_customer: bool, address: str = "default"):
def foo(
name: str,
age: int,
salary: float,
is_customer: bool,
address: str = "default",
optional_arg: str | None = None,
):
"""Some comment"""
some(name, age, salary, is_customer, address)
some(name, age, salary, is_customer, address, optional_arg)

with mock.patch.object(sys, "argv", [..., FOO_COMMAND]):
app()

some.assert_called_with("y", 100, 100.5, True, "default")
some.assert_called_with("y", 100, 100.5, True, "default", "optional")


def test_injects_prompts():
some = mock.Mock()
app = App(inspect.getfile(App))

@app.command(is_unauthenticated=True)
def foo(name: str, age: int, salary: float, is_customer: bool, prompts: Prompts, address: str = "default"):
def foo(
name: str,
age: int,
salary: float,
is_customer: bool,
prompts: Prompts,
address: str = "default",
optional_arg: str | None = None,
):
"""Some comment"""
assert isinstance(prompts, Prompts)
some(name, age, salary, is_customer, address)
some(name, age, salary, is_customer, address, optional_arg)

with mock.patch.object(sys, "argv", [..., FOO_COMMAND]):
app()

some.assert_called_with("y", 100, 100.5, True, "default")
some.assert_called_with("y", 100, 100.5, True, "default", "optional")

0 comments on commit 8300dd7

Please sign in to comment.