Skip to content

Commit

Permalink
support dataclasses.InitVar for python>=3.8 (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
mixilchenko authored and lebrice committed Dec 8, 2022
1 parent e06b016 commit bffb1b8
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 5 deletions.
29 changes: 26 additions & 3 deletions simple_parsing/annotation_utils/get_field_annotations.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import collections
from contextlib import contextmanager
from dataclasses import InitVar
import inspect
import sys
import types
import typing
from logging import getLogger as get_logger
from typing import Any, Dict, Optional, get_type_hints
from typing import Any, Dict, Iterator, Optional, get_type_hints

logger = get_logger(__name__)

Expand All @@ -19,6 +21,25 @@
}


@contextmanager
def _initvar_patcher() -> Iterator[None]:
"""
Patch InitVar to not fail when annotations are postponed.
`TypeVar('Forward references must evaluate to types. Got dataclasses.InitVar[tp].')` is raised
when postponed annotations are enabled and `get_type_hints` is called
Bug is mentioned here https://github.com/python/cpython/issues/88962
In python 3.11 this is fixed, but backport fix is not planned for old releases
Workaround is mentioned here https://stackoverflow.com/q/70400639
"""
if sys.version_info[:2] < (3, 11):
InitVar.__call__ = lambda *args: None
yield
if sys.version_info[:2] < (3, 11):
del InitVar.__call__


def evaluate_string_annotation(annotation: str, containing_class: Optional[type] = None) -> type:
"""Attempts to evaluate the given annotation string, to get a 'live' type annotation back.
Expand Down Expand Up @@ -160,7 +181,8 @@ def get_field_type_from_annotations(some_class: type, field_name: str) -> type:
global_ns = sys.modules[some_class.__module__].__dict__

try:
annotations_dict = get_type_hints(some_class, localns=local_ns, globalns=global_ns)
with _initvar_patcher():
annotations_dict = get_type_hints(some_class, localns=local_ns, globalns=global_ns)
except TypeError:
annotations_dict = collections.ChainMap(
*[getattr(cls, "__annotations__", {}) for cls in some_class.mro()]
Expand Down Expand Up @@ -197,7 +219,8 @@ class Temp_:
pass

Temp_.__annotations__ = {field_name: field_type}
annotations_dict = get_type_hints(Temp_, globalns=global_ns, localns=local_ns)
with _initvar_patcher():
annotations_dict = get_type_hints(Temp_, globalns=global_ns, localns=local_ns)
field_type = annotations_dict[field_name]
except Exception:
logger.warning(
Expand Down
21 changes: 20 additions & 1 deletion simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dataclasses
from dataclasses import MISSING
from logging import getLogger
import sys
from typing import cast

from .. import docstring, utils
Expand Down Expand Up @@ -48,7 +49,25 @@ def __init__(
self.defaults = [default]
self.optional: bool = False

for field in dataclasses.fields(self.dataclass):
# NOTE: `dataclasses.fields` method retrieves only `dataclasses._FIELD`
# NOTE: but we also want to know about `dataclasses._FIELD_INITVAR`
# NOTE: therefore we partly copy-paste its implementation
if sys.version_info[:2] < (3, 8):
# Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type
# therefore we should skip it just to be fully backward compatible
dataclass_fields = dataclasses.fields(self.dataclass)
else:
try:
dataclass_fields_map = getattr(self.dataclass, dataclasses._FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')
dataclass_fields = tuple(
field
for field in dataclass_fields_map.values()
if field._field_type in (dataclasses._FIELD, dataclasses._FIELD_INITVAR)
)

for field in dataclass_fields:
if not field.init or field.metadata.get("cmd", True) is False:
# Don't add arguments for these fields.
continue
Expand Down
2 changes: 2 additions & 0 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ def type(self) -> Type[Any]:

field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name)
self._type = field_type
elif isinstance(self._type, dataclasses.InitVar):
self._type = self._type.type
return self._type

def __str__(self):
Expand Down
18 changes: 17 additions & 1 deletion test/test_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import dataclasses
import typing
from dataclasses import dataclass
from dataclasses import InitVar, dataclass
import sys
from typing import Any, Callable, Generic, TypeVar

import pytest
Expand All @@ -27,6 +28,12 @@ class Foo(TestSetup):

d: list[bool] = field(default_factory=list)

e: InitVar[int] = 5
d: int = field(init=False)

def __post_init__(self, e: int) -> None:
self.d = e + 2


@dataclass
class Bar(TestSetup):
Expand All @@ -44,6 +51,15 @@ def test_future_annotations():
assert foo == Foo(a=2, b="heyo", c=(1, 7.89))


@pytest.mark.skipif(
sys.version_info[:2] < (3, 8),
reason="Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type",
)
def test_future_annotations_initvar():
foo = Foo.setup("--e 6")
assert foo.d == 8


def test_future_annotations_nested():
bar = Bar.setup()
assert bar == Bar()
Expand Down
34 changes: 34 additions & 0 deletions test/test_initvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass, InitVar
import sys
from typing import Any, List, Tuple, Type

import pytest
from typing_extensions import Literal

from .testutils import TestSetup


@pytest.mark.skipif(
sys.version_info[:2] < (3, 8),
reason="Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type",
)
@pytest.mark.parametrize(
'tp, passed_value, expected',
[
(int, '1', 1),
(float, '1.4', 1.4),
(Tuple[int, float], '2 -1.2', (2, -1.2)),
(List[str], '12 abc', ['12', 'abc']),
(Literal[1, 2, 3, '4'], '1', 1),
(Literal[1, 2, 3, '4'], '4', '4'),
],
)
def test_initvar(tp: Type[Any], passed_value: str, expected: Any) -> None:
@dataclass
class Foo(TestSetup):
init_var: InitVar[tp]

def __post_init__(self, init_var: tp) -> None:
assert init_var == expected

Foo.setup(f"--init_var {passed_value}")

0 comments on commit bffb1b8

Please sign in to comment.