Skip to content

Commit

Permalink
Deprecate get_resolver() and add new has_resolver() (#609)
Browse files Browse the repository at this point in the history
* Deprecate `get_resolver()` and add new `has_resolver()`

Fixes #608

* Refactor: move private functions at bottom of class

* Add test for coverage of deprecation warning
  • Loading branch information
odelalleau authored Mar 16, 2021
1 parent cf6c9a5 commit 4968d70
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 108 deletions.
1 change: 1 addition & 0 deletions news/608.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`OmegaConf.get_resolver()` is deprecated: use the new `OmegaConf.has_resolver()` to check for the existence of a resolver.
1 change: 1 addition & 0 deletions news/608.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New function `OmegaConf.has_resolver()` allows checking whether a resolver has already been registered.
2 changes: 1 addition & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def _evaluate_custom_resolver(
) -> Any:
from omegaconf import OmegaConf

resolver = OmegaConf.get_resolver(inter_type)
resolver = OmegaConf._get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
return resolver(root_node, self, inter_args, inter_args_str)
Expand Down
220 changes: 121 additions & 99 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,74 +211,6 @@ def create( # noqa F811
flags=flags,
)

@staticmethod
def _create_impl( # noqa F811
obj: Any = _EMPTY_MARKER_,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
try:
from ._utils import get_yaml_loader
from .dictconfig import DictConfig
from .listconfig import ListConfig

if obj is _EMPTY_MARKER_:
obj = {}
if isinstance(obj, str):
obj = yaml.load(obj, Loader=get_yaml_loader())
if obj is None:
return OmegaConf.create({}, flags=flags)
elif isinstance(obj, str):
return OmegaConf.create({obj: None}, flags=flags)
else:
assert isinstance(obj, (list, dict))
return OmegaConf.create(obj, flags=flags)

else:
if (
is_primitive_dict(obj)
or OmegaConf.is_dict(obj)
or is_structured_config(obj)
or obj is None
):
if isinstance(obj, DictConfig):
key_type = obj._metadata.key_type
element_type = obj._metadata.element_type
else:
obj_type = OmegaConf.get_type(obj)
key_type, element_type = get_dict_key_value_types(obj_type)
return DictConfig(
content=obj,
parent=parent,
ref_type=Any,
key_type=key_type,
element_type=element_type,
flags=flags,
)
elif is_primitive_list(obj) or OmegaConf.is_list(obj):
obj_type = OmegaConf.get_type(obj)
element_type = get_list_element_type(obj_type)
return ListConfig(
element_type=element_type,
ref_type=Any,
content=obj,
parent=parent,
flags=flags,
)
else:
if isinstance(obj, type):
raise ValidationError(
f"Input class '{obj.__name__}' is not a structured config. "
"did you forget to decorate it as a dataclass?"
)
else:
raise ValidationError(
f"Object of unsupported type: '{type(obj).__name__}'"
)
except OmegaConfBaseException as e:
format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
assert False

@staticmethod
def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
from ._utils import get_yaml_loader
Expand Down Expand Up @@ -529,17 +461,26 @@ def resolver_wrapper(
# noinspection PyProtectedMember
BaseContainer._resolvers[name] = resolver_wrapper

@staticmethod
@classmethod
def has_resolver(cls, name: str) -> bool:
return cls._get_resolver(name) is not None

# DEPRECATED: remove in 2.2
@classmethod
def get_resolver(
cls,
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
# noinspection PyProtectedMember
return (
BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
warnings.warn(
"`OmegaConf.get_resolver()` is deprecated (see https://github.com/omry/omegaconf/issues/608)",
UserWarning,
stacklevel=2,
)

return cls._get_resolver(name)

# noinspection PyProtectedMember
@staticmethod
def clear_resolvers() -> None:
Expand Down Expand Up @@ -720,33 +661,6 @@ def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]:
c = obj
return OmegaConf._get_obj_type(c)

@staticmethod
def _get_obj_type(c: Any) -> Optional[Type[Any]]:
if is_structured_config(c):
return get_type_of(c)
elif c is None:
return None
elif isinstance(c, DictConfig):
if c._is_none():
return None
elif c._is_missing():
return None
else:
if is_structured_config(c._metadata.object_type):
return c._metadata.object_type
else:
return dict
elif isinstance(c, ListConfig):
return list
elif isinstance(c, ValueNode):
return type(c._value())
elif isinstance(c, dict):
return dict
elif isinstance(c, (list, tuple)):
return list
else:
return get_type_of(c)

@staticmethod
def select(
cfg: Container,
Expand Down Expand Up @@ -870,6 +784,114 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
Dumper=get_omega_conf_dumper(),
)

# === private === #

@staticmethod
def _create_impl( # noqa F811
obj: Any = _EMPTY_MARKER_,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
try:
from ._utils import get_yaml_loader
from .dictconfig import DictConfig
from .listconfig import ListConfig

if obj is _EMPTY_MARKER_:
obj = {}
if isinstance(obj, str):
obj = yaml.load(obj, Loader=get_yaml_loader())
if obj is None:
return OmegaConf.create({}, flags=flags)
elif isinstance(obj, str):
return OmegaConf.create({obj: None}, flags=flags)
else:
assert isinstance(obj, (list, dict))
return OmegaConf.create(obj, flags=flags)

else:
if (
is_primitive_dict(obj)
or OmegaConf.is_dict(obj)
or is_structured_config(obj)
or obj is None
):
if isinstance(obj, DictConfig):
key_type = obj._metadata.key_type
element_type = obj._metadata.element_type
else:
obj_type = OmegaConf.get_type(obj)
key_type, element_type = get_dict_key_value_types(obj_type)
return DictConfig(
content=obj,
parent=parent,
ref_type=Any,
key_type=key_type,
element_type=element_type,
flags=flags,
)
elif is_primitive_list(obj) or OmegaConf.is_list(obj):
obj_type = OmegaConf.get_type(obj)
element_type = get_list_element_type(obj_type)
return ListConfig(
element_type=element_type,
ref_type=Any,
content=obj,
parent=parent,
flags=flags,
)
else:
if isinstance(obj, type):
raise ValidationError(
f"Input class '{obj.__name__}' is not a structured config. "
"did you forget to decorate it as a dataclass?"
)
else:
raise ValidationError(
f"Object of unsupported type: '{type(obj).__name__}'"
)
except OmegaConfBaseException as e:
format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
assert False

@staticmethod
def _get_obj_type(c: Any) -> Optional[Type[Any]]:
if is_structured_config(c):
return get_type_of(c)
elif c is None:
return None
elif isinstance(c, DictConfig):
if c._is_none():
return None
elif c._is_missing():
return None
else:
if is_structured_config(c._metadata.object_type):
return c._metadata.object_type
else:
return dict
elif isinstance(c, ListConfig):
return list
elif isinstance(c, ValueNode):
return type(c._value())
elif isinstance(c, dict):
return dict
elif isinstance(c, (list, tuple)):
return list
else:
return get_type_of(c)

@staticmethod
def _get_resolver(
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
# noinspection PyProtectedMember
return (
BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
)


# register all default resolvers
register_default_resolvers()
Expand Down
23 changes: 15 additions & 8 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,20 +363,27 @@ def foo() -> int:
OmegaConf.register_new_resolver("foo", lambda: 10)


def test_clear_resolvers(restore_resolvers: Any) -> None:
assert OmegaConf.get_resolver("foo") is None
def test_clear_resolvers_and_has_resolver(restore_resolvers: Any) -> None:
assert not OmegaConf.has_resolver("foo")
OmegaConf.register_new_resolver("foo", lambda x: x + 10)
assert OmegaConf.get_resolver("foo") is not None
assert OmegaConf.has_resolver("foo")
OmegaConf.clear_resolvers()
assert OmegaConf.get_resolver("foo") is None
assert not OmegaConf.has_resolver("foo")


def test_clear_resolvers_legacy(restore_resolvers: Any) -> None:
assert OmegaConf.get_resolver("foo") is None
def test_clear_resolvers_and_has_resolver_legacy(restore_resolvers: Any) -> None:
assert not OmegaConf.has_resolver("foo")
OmegaConf.legacy_register_resolver("foo", lambda x: int(x) + 10)
assert OmegaConf.get_resolver("foo") is not None
assert OmegaConf.has_resolver("foo")
OmegaConf.clear_resolvers()
assert OmegaConf.get_resolver("foo") is None
assert not OmegaConf.has_resolver("foo")


def test_get_resolver_deprecation() -> None:
with pytest.warns(
UserWarning, match=re.escape("https://github.com/omry/omegaconf/issues/608")
):
assert OmegaConf.get_resolver("foo") is None


def test_register_resolver_1(restore_resolvers: Any) -> None:
Expand Down

0 comments on commit 4968d70

Please sign in to comment.