From 61f3af5dd6ab0fbe9a93946fe446ab3ba7b2ee57 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Mar 2022 23:35:37 -0500 Subject: [PATCH 1/2] implement OmegaConf.to_object handling of init=False fields --- news/789.bugfix | 1 + omegaconf/_utils.py | 18 +++++++++++------- omegaconf/dictconfig.py | 20 +++++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) create mode 100644 news/789.bugfix diff --git a/news/789.bugfix b/news/789.bugfix new file mode 100644 index 000000000..2f20ac81d --- /dev/null +++ b/news/789.bugfix @@ -0,0 +1 @@ +`OmegaConf.to_object` now works properly with structured configs that have `init=False` fields diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 891e72b49..c6be19274 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -268,10 +268,14 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any] return None -def get_attr_class_field_names(obj: Any) -> List[str]: +def get_attr_class_init_field_names(obj: Any) -> List[str]: is_type = isinstance(obj, type) obj_type = obj if is_type else type(obj) - return list(attr.fields_dict(obj_type)) + return [ + fieldname + for fieldname, attribute in attr.fields_dict(obj_type).items() + if attribute.init + ] def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]: @@ -321,8 +325,8 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A return d -def get_dataclass_field_names(obj: Any) -> List[str]: - return [field.name for field in dataclasses.fields(obj)] +def get_dataclass_init_field_names(obj: Any) -> List[str]: + return [field.name for field in dataclasses.fields(obj) if field.init] def get_dataclass_data( @@ -421,11 +425,11 @@ def is_structured_config_frozen(obj: Any) -> bool: return False -def get_structured_config_field_names(obj: Any) -> List[str]: +def get_structured_config_init_field_names(obj: Any) -> List[str]: if is_dataclass(obj): - return get_dataclass_field_names(obj) + return get_dataclass_init_field_names(obj) elif is_attr_class(obj): - return get_attr_class_field_names(obj) + return get_attr_class_init_field_names(obj) else: raise ValueError(f"Unsupported type: {type(obj).__name__}") diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index c8f121911..08d23ea41 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -28,7 +28,7 @@ _valid_dict_key_annotation_type, format_and_raise, get_structured_config_data, - get_structured_config_field_names, + get_structured_config_init_field_names, get_type_of, get_value_kind, is_container_annotation, @@ -727,10 +727,10 @@ def _to_object(self) -> Any: object_type = self._metadata.object_type assert is_structured_config(object_type) - object_type_field_names = set(get_structured_config_field_names(object_type)) + init_field_names = set(get_structured_config_init_field_names(object_type)) - field_items: Dict[str, Any] = {} - nonfield_items: Dict[str, Any] = {} + init_field_items: Dict[str, Any] = {} + non_init_field_items: Dict[str, Any] = {} for k in self.keys(): assert isinstance(k, str) node = self._get_node(k) @@ -740,6 +740,8 @@ def _to_object(self) -> Any: except InterpolationResolutionError as e: self._format_and_raise(key=k, value=None, cause=e) if node._is_missing(): + if k not in init_field_names: + continue # MISSING is ignored for init=False fields self._format_and_raise( key=k, value=None, @@ -752,12 +754,12 @@ def _to_object(self) -> Any: else: v = node._value() - if k in object_type_field_names: - field_items[k] = v + if k in init_field_names: + init_field_items[k] = v else: - nonfield_items[k] = v + non_init_field_items[k] = v - result = object_type(**field_items) - for k, v in nonfield_items.items(): + result = object_type(**init_field_items) + for k, v in non_init_field_items.items(): setattr(result, k, v) return result From 7f47d53ab3f94a64dc98b48b7293cd9b2ea7910a Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Mar 2022 23:37:02 -0500 Subject: [PATCH 2/2] tests for OmegaConf.to_object handling of init=False fields --- tests/structured_conf/data/attr_classes.py | 10 ++++++ tests/structured_conf/data/dataclasses.py | 10 ++++++ tests/test_to_container.py | 36 ++++++++++++++++++++++ tests/test_utils.py | 6 ++-- 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index a834e4e02..fb7b0999f 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -607,3 +607,13 @@ class ParentContainers: class ChildContainers(ParentContainers): list1: List[int] = [1, 2, 3] dict: Dict[str, Any] = {"a": 5, "b": 6} + + +@attr.s(auto_attribs=True) +class HasInitFalseFields: + post_initialized: str = attr.field(init=False) + without_default: str = attr.field(init=False) + with_default: str = attr.field(init=False, default="default") + + def __attrs_post_init__(self) -> None: + self.post_initialized = "set_by_post_init" diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 146a82442..442cd06d4 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -628,3 +628,13 @@ class ParentContainers: class ChildContainers(ParentContainers): list1: List[int] = field(default_factory=lambda: [1, 2, 3]) dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6}) + + +@dataclass +class HasInitFalseFields: + post_initialized: str = field(init=False) + without_default: str = field(init=False) + with_default: str = field(init=False, default="default") + + def __post_init__(self) -> None: + self.post_initialized = "set_by_post_init" diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 5d5db7ad2..518f2a1fc 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -430,6 +430,42 @@ def test_setattr_for_user_with_extra_field(self, module: Any) -> None: assert type(user) is module.User assert user.extra_field == 123 + def test_init_false_with_default(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + assert cfg.with_default == "default" + data = self.round_trip_to_object(cfg) + assert data.with_default == "default" + + def test_init_false_with_default_overridden(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + cfg.with_default = "default_overridden" + data = self.round_trip_to_object(cfg) + assert data.with_default == "default_overridden" + + def test_init_false_without_default(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + assert OmegaConf.is_missing(cfg, "without_default") + data = self.round_trip_to_object(cfg) + assert not hasattr(data, "without_default") + + def test_init_false_without_default_overridden(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + cfg.with_default = "default_overridden" + data = self.round_trip_to_object(cfg) + assert data.with_default == "default_overridden" + + def test_init_false_post_initialized(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + assert OmegaConf.is_missing(cfg, "post_initialized") + data = self.round_trip_to_object(cfg) + assert data.post_initialized == "set_by_post_init" + + def test_init_false_post_initialized_overridden(self, module: Any) -> None: + cfg = OmegaConf.structured(module.HasInitFalseFields) + cfg.post_initialized = "overridden" + data = self.round_trip_to_object(cfg) + assert data.post_initialized == "overridden" + class TestEnumToStr: """Test the `enum_to_str` argument to the `OmegaConf.to_container function`""" diff --git a/tests/test_utils.py b/tests/test_utils.py index be266f3e5..b3fafbbe5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -151,6 +151,7 @@ class _TestDataclass: e: _TestEnum = _TestEnum.A list1: List[int] = field(default_factory=list) dict1: Dict[str, int] = field(default_factory=dict) + init_false: str = field(init=False, default="foo") @attr.s(auto_attribs=True) @@ -163,6 +164,7 @@ class _TestAttrsClass: e: _TestEnum = _TestEnum.A list1: List[int] = [] dict1: Dict[str, int] = {} + init_false: str = attr.field(init=False, default="foo") @dataclass @@ -226,12 +228,12 @@ def test_get_structured_config_data_throws_ValueError(self) -> None: [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], ) def test_get_structured_config_field_names(self, test_cls_or_obj: Any) -> None: - field_names = _utils.get_structured_config_field_names(test_cls_or_obj) + field_names = _utils.get_structured_config_init_field_names(test_cls_or_obj) assert field_names == ["x", "s", "b", "d", "f", "e", "list1", "dict1"] def test_get_structured_config_field_names_throws_ValueError(self) -> None: with raises(ValueError): - _utils.get_structured_config_field_names("invalid") + _utils.get_structured_config_init_field_names("invalid") @mark.parametrize(