From 94c31718b013aa136c33c6b0604c87a0157eb28c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 3 Aug 2022 14:46:15 -0400 Subject: [PATCH] Add support for simple Literal fields (#152) This adds support for simple fields, like: ```python from typing import Literal from dataclasses import dataclass @dataclass class Config: model: Literal["resnet18", "resnet50"] = "resnet18 ``` Also adds support for: - Literal[ints] - Literal[enums] Doesn't yet add support for: - Optional[Literal[X]] - List[Literal[X]] - Tuple[Literal[X], ...] Signed-off-by: Fabrice Normandin --- simple_parsing/utils.py | 35 +++++--- simple_parsing/wrappers/field_wrapper.py | 60 ++++++++------ test/test_literal.py | 100 +++++++++++++++++++++++ test/testutils.py | 10 ++- 4 files changed, 165 insertions(+), 40 deletions(-) create mode 100644 test/test_literal.py diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 25bc3c71..97c20a8d 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -32,6 +32,7 @@ TypeVar, Union, ) +from typing_extensions import Literal, get_args # from typing_inspect import get_origin, is_typevar, get_bound, is_forward_ref, get_forward_arg NEW_TYPING = sys.version_info[:3] >= (3, 7, 0) # PEP 560 @@ -65,19 +66,6 @@ def get_forward_arg(fr): return getattr(fr, "__forward_arg__", None) -try: - from typing import get_args -except ImportError: - # try: - # # TODO: Not sure we should depend on typing_inspect, results appear to vary - # # greatly - # # between python versions. - # from typing_inspect import get_args - # except ImportError: - def get_args(some_type: Type) -> Tuple[Type, ...]: - return getattr(some_type, "__args__", ()) - - logger = getLogger(__name__) builtin_types = [ @@ -248,6 +236,27 @@ def _mro(t: Type) -> List[Type]: return [] +def is_literal(t: type) -> bool: + """Returns True with `t` is a Literal type. + + >>> from typing_extensions import Literal + >>> from typing import * + >>> is_literal(list) + False + >>> is_literal("foo") + False + >>> is_literal(Literal[True, False]) + True + >>> is_literal(Literal[1,2,3]) + True + >>> is_literal(Literal["foo", "bar"]) + True + >>> is_literal(Optional[Literal[1,2]]) + False + """ + return get_origin(t) is Literal + + def is_list(t: Type) -> bool: """returns True when `t` is a List type. diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 9a541e75..d6430f0c 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -6,7 +6,6 @@ from enum import Enum, auto from logging import getLogger from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast - from simple_parsing.help_formatter import TEMPORARY_TOKEN from .. import docstring, utils @@ -234,9 +233,15 @@ def get_arg_options(self) -> Dict[str, Any]: # automatically adds the (default: '123'). We then remove it. _arg_options["help"] = TEMPORARY_TOKEN - if utils.is_choice(self.field): - _arg_options["type"] = str - _arg_options["choices"] = list(self.choices) + # TODO: Possible duplication between utils.is_foo(Field) and self.is_foo where foo in + # [choice, optional, list, tuple, dataclass, etc.] + if self.is_choice: + choices = self.choices + assert choices + item_type = str + _arg_options["type"] = item_type + _arg_options["choices"] = choices + # TODO: Refactor this. is_choice and is_list are both contributing, so it's unclear. if utils.is_list(self.type): _arg_options["nargs"] = argparse.ZERO_OR_MORE # We use the default 'metavar' generated by argparse. @@ -305,6 +310,7 @@ def get_arg_options(self) -> Dict[str, Any]: # we actually parse enums as string, and convert them back to enums # in the `process` method. logger.debug(f"self.choices = {self.choices}") + assert issubclass(self.type, Enum) _arg_options["choices"] = list(e.name for e in self.type) _arg_options["type"] = str # if the default value is an Enum, we convert it to a string. @@ -438,7 +444,7 @@ def postprocess(self, raw_parsed_value: Any) -> Any: return raw_parsed_value elif self.is_choice: - choice_dict = self.field.metadata.get("choice_dict") + choice_dict = self.choice_dict if choice_dict: key_type = type(next(iter(choice_dict.keys()))) if self.is_list and isinstance(raw_parsed_value[0], key_type): @@ -784,13 +790,6 @@ def type(self) -> Type[Any]: field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name) self._type = field_type - - if self.is_choice and self.choice_dict: - if utils.is_list(self.field.type): - self._type = List[str] - else: - self._type = str - return self._type def __str__(self): @@ -802,19 +801,30 @@ def is_choice(self) -> bool: @property def choices(self) -> Optional[List]: - choices = self.custom_arg_options.get("choices") - choice_dict = self.choice_dict - if not choices: - return None - if len(choices) == 1 and isinstance(choices[0], dict): - choice_dict = choices[0] - assert False, "HERE" # FIXME: Remove this, was debugging something? - return choice_dict - return choices - - @property - def choice_dict(self) -> Optional[Dict]: - return self.field.metadata.get("choice_dict") + """The list of possible values that can be passed on the command-line for this field, or None.""" + if "choices" in self.custom_arg_options: + return self.custom_arg_options["choices"] + if "choice_dict" in self.field.metadata: + return list(self.field.metadata["choice_dict"].keys()) + if utils.is_literal(self.type): + literal_values = list(utils.get_args(self.type)) + literal_value_names = [ + v.name if isinstance(v, Enum) else str(v) for v in literal_values + ] + return literal_value_names + return None + + @property + def choice_dict(self) -> Optional[Dict[str, Any]]: + if "choice_dict" in self.field.metadata: + return self.field.metadata["choice_dict"] + if utils.is_literal(self.type): + literal_values = list(utils.get_args(self.type)) + assert literal_values, "Literal always has at least one argument." + # We map from literal values (as strings) to the actual values. + # e.g. from BLUE -> Color.Blue + return {(v.name if isinstance(v, Enum) else str(v)): v for v in literal_values} + return None @property def help(self) -> Optional[str]: diff --git a/test/test_literal.py b/test/test_literal.py new file mode 100644 index 00000000..49df9f85 --- /dev/null +++ b/test/test_literal.py @@ -0,0 +1,100 @@ +from .testutils import ( + TestSetup, + raises_expected_n_args, + raises_missing_required_arg, + raises_invalid_choice, + xfail_param, +) +from dataclasses import dataclass +from typing_extensions import Literal +from typing import NamedTuple, Any +from typing import Any, List, NamedTuple, Type, Optional +import enum +import pytest + + +class FieldComponents(NamedTuple): + field_annotation: Any + passed_value: Any + parsed_value: Any + incorrect_value: Any + + +class Color(enum.Enum): + RED = enum.auto() + BLUE = enum.auto() + GREEN = enum.auto() + + +Fingers = Literal[0, 1, 2, 3, 4, 5] + + +@pytest.fixture( + params=[ + FieldComponents(Literal["bob", "alice"], "bob", "bob", "clarice"), + FieldComponents(Literal[True, False], "True", True, "bob"), + xfail_param( + [FieldComponents(Literal[True, False], "true", True, "bob")], + reason="The support for boolean literals currently assumes just 'True' and 'False'.", + ), + FieldComponents(Literal[1, 2, 3], "1", 1, "foobar"), + FieldComponents(Literal[1, 2, 3], "2", 2, "9"), + FieldComponents(Literal["bob", "alice"], "bob", "bob", "clarice"), + FieldComponents(Literal[Color.BLUE, Color.GREEN], "BLUE", Color.BLUE, "red"), + FieldComponents(Literal[Color.BLUE, Color.GREEN], "BLUE", Color.BLUE, "foobar"), + FieldComponents(Fingers, "1", 1, "foobar"), + FieldComponents("Fingers", "1", 1, "foobar"), + ] +) +def literal_field(request: pytest.FixtureRequest): + field = request.param # type: ignore + return field + + +def test_literal(literal_field: FieldComponents): + field_annotation, passed_value, parsed_value, incorrect_value = literal_field + + @dataclass + class Foo(TestSetup): + bar: field_annotation # type: ignore + + with raises_missing_required_arg(): + Foo.setup("") + + assert Foo.setup(f"--bar {passed_value}") == Foo(bar=parsed_value) + + with raises_invalid_choice(): + assert Foo.setup(f"--bar {incorrect_value}") + + +@pytest.mark.xfail(reason="TODO: add support for optional literals") +def test_optional_literal(literal_field: FieldComponents): + field_annotation, passed_value, parsed_value, incorrect_value = literal_field + + @dataclass + class Foo(TestSetup): + bar: Optional[field_annotation] = None # type: ignore + + assert Foo.setup("") == Foo(bar=None) + assert Foo.setup(f"--bar {passed_value}") == Foo(bar=parsed_value) + + with raises_invalid_choice(): + assert Foo.setup(f"--bar {incorrect_value}") + + +@pytest.mark.xfail(reason="TODO: Support lists of literals.") +def test_list_of_literal(literal_field: FieldComponents): + field_annotation, passed_value, parsed_value, incorrect_value = literal_field + + @dataclass + class Foo(TestSetup): + values: List[field_annotation] # type: ignore + + with raises_missing_required_arg(): + Foo.setup(f"") + + assert Foo.setup(f"--values {passed_value} {passed_value}") == Foo( + values=[parsed_value, parsed_value] + ) + with raises_invalid_choice(): + assert Foo.setup(f"--values {incorrect_value}") diff --git a/test/testutils.py b/test/testutils.py index 790b9086..d5209726 100644 --- a/test/testutils.py +++ b/test/testutils.py @@ -2,7 +2,7 @@ import string import sys from contextlib import contextmanager -from typing import Any, Callable, Generic, List, Optional, Tuple, Type, TypeVar, cast +from typing import Any, Callable, Generic, List, Optional, Tuple, Type, TypeVar, cast, Union import pytest @@ -54,7 +54,13 @@ def raises_missing_required_arg(): @contextmanager -def raises_expected_n_args(n: int): +def raises_invalid_choice(): + with exits_and_writes_to_stderr("invalid choice:"): + yield + + +@contextmanager +def raises_expected_n_args(n: Union[int, str]): with exits_and_writes_to_stderr(f"expected {n} arguments"): yield