Skip to content

Commit

Permalink
Add support for simple Literal fields (#152)
Browse files Browse the repository at this point in the history
This adds support for simple fields, like:

```python
from typing import Literal
from dataclasses import dataclass

@DataClass
class Config:
    model: Literal["resnet18", "resnet50"] = "resnet18
```

Also adds support for:
- Literal[ints]
- Literal[enums]

Doesn't yet add support for:
- Optional[Literal[X]]
- List[Literal[X]]
- Tuple[Literal[X], ...]

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice authored Aug 3, 2022
1 parent e33a730 commit 94c3171
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 40 deletions.
35 changes: 22 additions & 13 deletions simple_parsing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TypeVar,
Union,
)
from typing_extensions import Literal, get_args

# from typing_inspect import get_origin, is_typevar, get_bound, is_forward_ref, get_forward_arg
NEW_TYPING = sys.version_info[:3] >= (3, 7, 0) # PEP 560
Expand Down Expand Up @@ -65,19 +66,6 @@ def get_forward_arg(fr):
return getattr(fr, "__forward_arg__", None)


try:
from typing import get_args
except ImportError:
# try:
# # TODO: Not sure we should depend on typing_inspect, results appear to vary
# # greatly
# # between python versions.
# from typing_inspect import get_args
# except ImportError:
def get_args(some_type: Type) -> Tuple[Type, ...]:
return getattr(some_type, "__args__", ())


logger = getLogger(__name__)

builtin_types = [
Expand Down Expand Up @@ -248,6 +236,27 @@ def _mro(t: Type) -> List[Type]:
return []


def is_literal(t: type) -> bool:
"""Returns True with `t` is a Literal type.
>>> from typing_extensions import Literal
>>> from typing import *
>>> is_literal(list)
False
>>> is_literal("foo")
False
>>> is_literal(Literal[True, False])
True
>>> is_literal(Literal[1,2,3])
True
>>> is_literal(Literal["foo", "bar"])
True
>>> is_literal(Optional[Literal[1,2]])
False
"""
return get_origin(t) is Literal


def is_list(t: Type) -> bool:
"""returns True when `t` is a List type.
Expand Down
60 changes: 35 additions & 25 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from enum import Enum, auto
from logging import getLogger
from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast

from simple_parsing.help_formatter import TEMPORARY_TOKEN

from .. import docstring, utils
Expand Down Expand Up @@ -234,9 +233,15 @@ def get_arg_options(self) -> Dict[str, Any]:
# automatically adds the (default: '123'). We then remove it.
_arg_options["help"] = TEMPORARY_TOKEN

if utils.is_choice(self.field):
_arg_options["type"] = str
_arg_options["choices"] = list(self.choices)
# TODO: Possible duplication between utils.is_foo(Field) and self.is_foo where foo in
# [choice, optional, list, tuple, dataclass, etc.]
if self.is_choice:
choices = self.choices
assert choices
item_type = str
_arg_options["type"] = item_type
_arg_options["choices"] = choices
# TODO: Refactor this. is_choice and is_list are both contributing, so it's unclear.
if utils.is_list(self.type):
_arg_options["nargs"] = argparse.ZERO_OR_MORE
# We use the default 'metavar' generated by argparse.
Expand Down Expand Up @@ -305,6 +310,7 @@ def get_arg_options(self) -> Dict[str, Any]:
# we actually parse enums as string, and convert them back to enums
# in the `process` method.
logger.debug(f"self.choices = {self.choices}")
assert issubclass(self.type, Enum)
_arg_options["choices"] = list(e.name for e in self.type)
_arg_options["type"] = str
# if the default value is an Enum, we convert it to a string.
Expand Down Expand Up @@ -438,7 +444,7 @@ def postprocess(self, raw_parsed_value: Any) -> Any:
return raw_parsed_value

elif self.is_choice:
choice_dict = self.field.metadata.get("choice_dict")
choice_dict = self.choice_dict
if choice_dict:
key_type = type(next(iter(choice_dict.keys())))
if self.is_list and isinstance(raw_parsed_value[0], key_type):
Expand Down Expand Up @@ -784,13 +790,6 @@ def type(self) -> Type[Any]:

field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name)
self._type = field_type

if self.is_choice and self.choice_dict:
if utils.is_list(self.field.type):
self._type = List[str]
else:
self._type = str

return self._type

def __str__(self):
Expand All @@ -802,19 +801,30 @@ def is_choice(self) -> bool:

@property
def choices(self) -> Optional[List]:
choices = self.custom_arg_options.get("choices")
choice_dict = self.choice_dict
if not choices:
return None
if len(choices) == 1 and isinstance(choices[0], dict):
choice_dict = choices[0]
assert False, "HERE" # FIXME: Remove this, was debugging something?
return choice_dict
return choices

@property
def choice_dict(self) -> Optional[Dict]:
return self.field.metadata.get("choice_dict")
"""The list of possible values that can be passed on the command-line for this field, or None."""
if "choices" in self.custom_arg_options:
return self.custom_arg_options["choices"]
if "choice_dict" in self.field.metadata:
return list(self.field.metadata["choice_dict"].keys())
if utils.is_literal(self.type):
literal_values = list(utils.get_args(self.type))
literal_value_names = [
v.name if isinstance(v, Enum) else str(v) for v in literal_values
]
return literal_value_names
return None

@property
def choice_dict(self) -> Optional[Dict[str, Any]]:
if "choice_dict" in self.field.metadata:
return self.field.metadata["choice_dict"]
if utils.is_literal(self.type):
literal_values = list(utils.get_args(self.type))
assert literal_values, "Literal always has at least one argument."
# We map from literal values (as strings) to the actual values.
# e.g. from BLUE -> Color.Blue
return {(v.name if isinstance(v, Enum) else str(v)): v for v in literal_values}
return None

@property
def help(self) -> Optional[str]:
Expand Down
100 changes: 100 additions & 0 deletions test/test_literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from .testutils import (
TestSetup,
raises_expected_n_args,
raises_missing_required_arg,
raises_invalid_choice,
xfail_param,
)
from dataclasses import dataclass
from typing_extensions import Literal
from typing import NamedTuple, Any
from typing import Any, List, NamedTuple, Type, Optional
import enum
import pytest


class FieldComponents(NamedTuple):
field_annotation: Any
passed_value: Any
parsed_value: Any
incorrect_value: Any


class Color(enum.Enum):
RED = enum.auto()
BLUE = enum.auto()
GREEN = enum.auto()


Fingers = Literal[0, 1, 2, 3, 4, 5]


@pytest.fixture(
params=[
FieldComponents(Literal["bob", "alice"], "bob", "bob", "clarice"),
FieldComponents(Literal[True, False], "True", True, "bob"),
xfail_param(
[FieldComponents(Literal[True, False], "true", True, "bob")],
reason="The support for boolean literals currently assumes just 'True' and 'False'.",
),
FieldComponents(Literal[1, 2, 3], "1", 1, "foobar"),
FieldComponents(Literal[1, 2, 3], "2", 2, "9"),
FieldComponents(Literal["bob", "alice"], "bob", "bob", "clarice"),
FieldComponents(Literal[Color.BLUE, Color.GREEN], "BLUE", Color.BLUE, "red"),
FieldComponents(Literal[Color.BLUE, Color.GREEN], "BLUE", Color.BLUE, "foobar"),
FieldComponents(Fingers, "1", 1, "foobar"),
FieldComponents("Fingers", "1", 1, "foobar"),
]
)
def literal_field(request: pytest.FixtureRequest):
field = request.param # type: ignore
return field


def test_literal(literal_field: FieldComponents):
field_annotation, passed_value, parsed_value, incorrect_value = literal_field

@dataclass
class Foo(TestSetup):
bar: field_annotation # type: ignore

with raises_missing_required_arg():
Foo.setup("")

assert Foo.setup(f"--bar {passed_value}") == Foo(bar=parsed_value)

with raises_invalid_choice():
assert Foo.setup(f"--bar {incorrect_value}")


@pytest.mark.xfail(reason="TODO: add support for optional literals")
def test_optional_literal(literal_field: FieldComponents):
field_annotation, passed_value, parsed_value, incorrect_value = literal_field

@dataclass
class Foo(TestSetup):
bar: Optional[field_annotation] = None # type: ignore

assert Foo.setup("") == Foo(bar=None)
assert Foo.setup(f"--bar {passed_value}") == Foo(bar=parsed_value)

with raises_invalid_choice():
assert Foo.setup(f"--bar {incorrect_value}")


@pytest.mark.xfail(reason="TODO: Support lists of literals.")
def test_list_of_literal(literal_field: FieldComponents):
field_annotation, passed_value, parsed_value, incorrect_value = literal_field

@dataclass
class Foo(TestSetup):
values: List[field_annotation] # type: ignore

with raises_missing_required_arg():
Foo.setup(f"")

assert Foo.setup(f"--values {passed_value} {passed_value}") == Foo(
values=[parsed_value, parsed_value]
)
with raises_invalid_choice():
assert Foo.setup(f"--values {incorrect_value}")
10 changes: 8 additions & 2 deletions test/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
import sys
from contextlib import contextmanager
from typing import Any, Callable, Generic, List, Optional, Tuple, Type, TypeVar, cast
from typing import Any, Callable, Generic, List, Optional, Tuple, Type, TypeVar, cast, Union

import pytest

Expand Down Expand Up @@ -54,7 +54,13 @@ def raises_missing_required_arg():


@contextmanager
def raises_expected_n_args(n: int):
def raises_invalid_choice():
with exits_and_writes_to_stderr("invalid choice:"):
yield


@contextmanager
def raises_expected_n_args(n: Union[int, str]):
with exits_and_writes_to_stderr(f"expected {n} arguments"):
yield

Expand Down

0 comments on commit 94c3171

Please sign in to comment.