Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OmegaConf.to_object handling of init=False fields #879

Merged
merged 2 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/789.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`OmegaConf.to_object` now works properly with structured configs that have `init=False` fields
18 changes: 11 additions & 7 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,14 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
return None


def get_attr_class_field_names(obj: Any) -> List[str]:
def get_attr_class_init_field_names(obj: Any) -> List[str]:
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
return list(attr.fields_dict(obj_type))
return [
fieldname
for fieldname, attribute in attr.fields_dict(obj_type).items()
if attribute.init
]


def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -321,8 +325,8 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
return d


def get_dataclass_field_names(obj: Any) -> List[str]:
return [field.name for field in dataclasses.fields(obj)]
def get_dataclass_init_field_names(obj: Any) -> List[str]:
return [field.name for field in dataclasses.fields(obj) if field.init]


def get_dataclass_data(
Expand Down Expand Up @@ -421,11 +425,11 @@ def is_structured_config_frozen(obj: Any) -> bool:
return False


def get_structured_config_field_names(obj: Any) -> List[str]:
def get_structured_config_init_field_names(obj: Any) -> List[str]:
if is_dataclass(obj):
return get_dataclass_field_names(obj)
return get_dataclass_init_field_names(obj)
elif is_attr_class(obj):
return get_attr_class_field_names(obj)
return get_attr_class_init_field_names(obj)
else:
raise ValueError(f"Unsupported type: {type(obj).__name__}")

Expand Down
20 changes: 11 additions & 9 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_valid_dict_key_annotation_type,
format_and_raise,
get_structured_config_data,
get_structured_config_field_names,
get_structured_config_init_field_names,
get_type_of,
get_value_kind,
is_container_annotation,
Expand Down Expand Up @@ -727,10 +727,10 @@ def _to_object(self) -> Any:

object_type = self._metadata.object_type
assert is_structured_config(object_type)
object_type_field_names = set(get_structured_config_field_names(object_type))
init_field_names = set(get_structured_config_init_field_names(object_type))

field_items: Dict[str, Any] = {}
nonfield_items: Dict[str, Any] = {}
init_field_items: Dict[str, Any] = {}
non_init_field_items: Dict[str, Any] = {}
for k in self.keys():
assert isinstance(k, str)
node = self._get_node(k)
Expand All @@ -740,6 +740,8 @@ def _to_object(self) -> Any:
except InterpolationResolutionError as e:
self._format_and_raise(key=k, value=None, cause=e)
if node._is_missing():
if k not in init_field_names:
continue # MISSING is ignored for init=False fields
self._format_and_raise(
key=k,
value=None,
Expand All @@ -752,12 +754,12 @@ def _to_object(self) -> Any:
else:
v = node._value()

if k in object_type_field_names:
field_items[k] = v
if k in init_field_names:
init_field_items[k] = v
else:
nonfield_items[k] = v
non_init_field_items[k] = v

result = object_type(**field_items)
for k, v in nonfield_items.items():
result = object_type(**init_field_items)
for k, v in non_init_field_items.items():
setattr(result, k, v)
return result
10 changes: 10 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,13 @@ class ParentContainers:
class ChildContainers(ParentContainers):
list1: List[int] = [1, 2, 3]
dict: Dict[str, Any] = {"a": 5, "b": 6}


@attr.s(auto_attribs=True)
class HasInitFalseFields:
post_initialized: str = attr.field(init=False)
without_default: str = attr.field(init=False)
with_default: str = attr.field(init=False, default="default")

def __attrs_post_init__(self) -> None:
self.post_initialized = "set_by_post_init"
10 changes: 10 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,3 +628,13 @@ class ParentContainers:
class ChildContainers(ParentContainers):
list1: List[int] = field(default_factory=lambda: [1, 2, 3])
dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6})


@dataclass
class HasInitFalseFields:
post_initialized: str = field(init=False)
without_default: str = field(init=False)
with_default: str = field(init=False, default="default")

def __post_init__(self) -> None:
self.post_initialized = "set_by_post_init"
36 changes: 36 additions & 0 deletions tests/test_to_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,42 @@ def test_setattr_for_user_with_extra_field(self, module: Any) -> None:
assert type(user) is module.User
assert user.extra_field == 123

def test_init_false_with_default(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
assert cfg.with_default == "default"
data = self.round_trip_to_object(cfg)
assert data.with_default == "default"

def test_init_false_with_default_overridden(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
cfg.with_default = "default_overridden"
data = self.round_trip_to_object(cfg)
assert data.with_default == "default_overridden"

def test_init_false_without_default(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
assert OmegaConf.is_missing(cfg, "without_default")
data = self.round_trip_to_object(cfg)
assert not hasattr(data, "without_default")

def test_init_false_without_default_overridden(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
cfg.with_default = "default_overridden"
data = self.round_trip_to_object(cfg)
assert data.with_default == "default_overridden"

def test_init_false_post_initialized(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
assert OmegaConf.is_missing(cfg, "post_initialized")
data = self.round_trip_to_object(cfg)
assert data.post_initialized == "set_by_post_init"

def test_init_false_post_initialized_overridden(self, module: Any) -> None:
cfg = OmegaConf.structured(module.HasInitFalseFields)
cfg.post_initialized = "overridden"
data = self.round_trip_to_object(cfg)
assert data.post_initialized == "overridden"


class TestEnumToStr:
"""Test the `enum_to_str` argument to the `OmegaConf.to_container function`"""
Expand Down
6 changes: 4 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class _TestDataclass:
e: _TestEnum = _TestEnum.A
list1: List[int] = field(default_factory=list)
dict1: Dict[str, int] = field(default_factory=dict)
init_false: str = field(init=False, default="foo")


@attr.s(auto_attribs=True)
Expand All @@ -163,6 +164,7 @@ class _TestAttrsClass:
e: _TestEnum = _TestEnum.A
list1: List[int] = []
dict1: Dict[str, int] = {}
init_false: str = attr.field(init=False, default="foo")


@dataclass
Expand Down Expand Up @@ -226,12 +228,12 @@ def test_get_structured_config_data_throws_ValueError(self) -> None:
[_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()],
)
def test_get_structured_config_field_names(self, test_cls_or_obj: Any) -> None:
field_names = _utils.get_structured_config_field_names(test_cls_or_obj)
field_names = _utils.get_structured_config_init_field_names(test_cls_or_obj)
assert field_names == ["x", "s", "b", "d", "f", "e", "list1", "dict1"]

def test_get_structured_config_field_names_throws_ValueError(self) -> None:
with raises(ValueError):
_utils.get_structured_config_field_names("invalid")
_utils.get_structured_config_init_field_names("invalid")


@mark.parametrize(
Expand Down