Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support dataclasses.InitVar for python>=3.8 #171

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions simple_parsing/annotation_utils/get_field_annotations.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import collections
from contextlib import contextmanager
from dataclasses import InitVar
import inspect
import sys
import types
import typing
from logging import getLogger as get_logger
from typing import Any, Dict, Optional, get_type_hints
from typing import Any, Dict, Iterator, Optional, get_type_hints

logger = get_logger(__name__)

Expand All @@ -19,6 +21,25 @@
}


@contextmanager
def _initvar_patcher() -> Iterator[None]:
"""
Patch InitVar to not fail when annotations are postponed.

`TypeVar('Forward references must evaluate to types. Got dataclasses.InitVar[tp].')` is raised
when postponed annotations are enabled and `get_type_hints` is called
Bug is mentioned here https://github.com/python/cpython/issues/88962
In python 3.11 this is fixed, but backport fix is not planned for old releases

Workaround is mentioned here https://stackoverflow.com/q/70400639
"""
if sys.version_info[:2] < (3, 11):
InitVar.__call__ = lambda *args: None
yield
if sys.version_info[:2] < (3, 11):
del InitVar.__call__


def evaluate_string_annotation(annotation: str, containing_class: Optional[type] = None) -> type:
"""Attempts to evaluate the given annotation string, to get a 'live' type annotation back.

Expand Down Expand Up @@ -160,7 +181,8 @@ def get_field_type_from_annotations(some_class: type, field_name: str) -> type:
global_ns = sys.modules[some_class.__module__].__dict__

try:
annotations_dict = get_type_hints(some_class, localns=local_ns, globalns=global_ns)
with _initvar_patcher():
annotations_dict = get_type_hints(some_class, localns=local_ns, globalns=global_ns)
except TypeError:
annotations_dict = collections.ChainMap(
*[getattr(cls, "__annotations__", {}) for cls in some_class.mro()]
Expand Down Expand Up @@ -197,7 +219,8 @@ class Temp_:
pass

Temp_.__annotations__ = {field_name: field_type}
annotations_dict = get_type_hints(Temp_, globalns=global_ns, localns=local_ns)
with _initvar_patcher():
annotations_dict = get_type_hints(Temp_, globalns=global_ns, localns=local_ns)
field_type = annotations_dict[field_name]
except Exception:
logger.warning(
Expand Down
21 changes: 20 additions & 1 deletion simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dataclasses
from dataclasses import MISSING
from logging import getLogger
import sys
from typing import cast

from .. import docstring, utils
Expand Down Expand Up @@ -48,7 +49,25 @@ def __init__(
self.defaults = [default]
self.optional: bool = False

for field in dataclasses.fields(self.dataclass):
# NOTE: `dataclasses.fields` method retrieves only `dataclasses._FIELD`
# NOTE: but we also want to know about `dataclasses._FIELD_INITVAR`
# NOTE: therefore we partly copy-paste its implementation
if sys.version_info[:2] < (3, 8):
# Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type
# therefore we should skip it just to be fully backward compatible
dataclass_fields = dataclasses.fields(self.dataclass)
else:
try:
dataclass_fields_map = getattr(self.dataclass, dataclasses._FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')
dataclass_fields = tuple(
field
for field in dataclass_fields_map.values()
if field._field_type in (dataclasses._FIELD, dataclasses._FIELD_INITVAR)
)

for field in dataclass_fields:
if not field.init or field.metadata.get("cmd", True) is False:
# Don't add arguments for these fields.
continue
Expand Down
2 changes: 2 additions & 0 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ def type(self) -> Type[Any]:

field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name)
self._type = field_type
elif isinstance(self._type, dataclasses.InitVar):
self._type = self._type.type
return self._type

def __str__(self):
Expand Down
18 changes: 17 additions & 1 deletion test/test_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import dataclasses
import typing
from dataclasses import dataclass
from dataclasses import InitVar, dataclass
import sys
from typing import Any, Callable, Generic, TypeVar

import pytest
Expand All @@ -27,6 +28,12 @@ class Foo(TestSetup):

d: list[bool] = field(default_factory=list)

e: InitVar[int] = 5
d: int = field(init=False)

def __post_init__(self, e: int) -> None:
self.d = e + 2


@dataclass
class Bar(TestSetup):
Expand All @@ -44,6 +51,15 @@ def test_future_annotations():
assert foo == Foo(a=2, b="heyo", c=(1, 7.89))


@pytest.mark.skipif(
sys.version_info[:2] < (3, 8),
reason="Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type",
)
def test_future_annotations_initvar():
foo = Foo.setup("--e 6")
mixilchenko marked this conversation as resolved.
Show resolved Hide resolved
assert foo.d == 8


def test_future_annotations_nested():
bar = Bar.setup()
assert bar == Bar()
Expand Down
34 changes: 34 additions & 0 deletions test/test_initvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass, InitVar
import sys
from typing import Any, List, Tuple, Type

import pytest
from typing_extensions import Literal

from .testutils import TestSetup


@pytest.mark.skipif(
sys.version_info[:2] < (3, 8),
reason="Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type",
)
@pytest.mark.parametrize(
'tp, passed_value, expected',
[
(int, '1', 1),
(float, '1.4', 1.4),
(Tuple[int, float], '2 -1.2', (2, -1.2)),
(List[str], '12 abc', ['12', 'abc']),
(Literal[1, 2, 3, '4'], '1', 1),
(Literal[1, 2, 3, '4'], '4', '4'),
],
)
def test_initvar(tp: Type[Any], passed_value: str, expected: Any) -> None:
@dataclass
class Foo(TestSetup):
init_var: InitVar[tp]

def __post_init__(self, init_var: tp) -> None:
assert init_var == expected

Foo.setup(f"--init_var {passed_value}")