diff --git a/RELEASE.md b/RELEASE.md index 2fcf553bcf..ae0714b8d9 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,7 @@ # Upcoming Release 0.18.13 ## Major features and improvements +* Allowed registering of custom resolvers to `OmegaConfigLoader` through `CONFIG_LOADER_ARGS`. ## Bug fixes and other changes diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index d7d9bd245b..4d2ace59d4 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -7,7 +7,7 @@ import logging import mimetypes from pathlib import Path -from typing import Any, Iterable +from typing import Any, Callable, Iterable import fsspec from omegaconf import OmegaConf @@ -82,6 +82,7 @@ def __init__( # noqa: too-many-arguments config_patterns: dict[str, list[str]] = None, base_env: str = "base", default_run_env: str = "local", + custom_resolvers: dict[str, Callable] = None, ): """Instantiates a ``OmegaConfigLoader``. @@ -97,6 +98,8 @@ def __init__( # noqa: too-many-arguments the configuration paths. default_run_env: Name of the default run environment. Defaults to `"local"`. Can be overridden by supplying the `env` argument. + custom_resolvers: A dictionary of custom resolvers to be registered. For more information, + see here: https://omegaconf.readthedocs.io/en/2.3_branch/custom_resolvers.html#custom-resolvers """ self.base_env = base_env self.default_run_env = default_run_env @@ -111,6 +114,9 @@ def __init__( # noqa: too-many-arguments # Deactivate oc.env built-in resolver for OmegaConf OmegaConf.clear_resolver("oc.env") + # Register user provided custom resolvers + if custom_resolvers: + self._register_new_resolvers(custom_resolvers) file_mimetype, _ = mimetypes.guess_type(conf_source) if file_mimetype == "application/x-tar": @@ -302,6 +308,15 @@ def _is_valid_config_path(self, path): ".json", ] + @staticmethod + def _register_new_resolvers(resolvers: dict[str, Callable]): + """Register custom resolvers""" + for name, resolver in resolvers.items(): + if not OmegaConf.has_resolver(name): + msg = f"Registering new custom resolver: {name}" + _config_logger.debug(msg) + OmegaConf.register_new_resolver(name=name, resolver=resolver) + @staticmethod def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]): duplicates = [] diff --git a/tests/config/test_omegaconf_config.py b/tests/config/test_omegaconf_config.py index dd49292019..af57b52224 100644 --- a/tests/config/test_omegaconf_config.py +++ b/tests/config/test_omegaconf_config.py @@ -649,3 +649,25 @@ def test_variable_interpolation_in_catalog_with_separate_templates_file( conf = OmegaConfigLoader(str(tmp_path)) conf.default_run_env = "" assert conf["catalog"]["companies"]["type"] == "pandas.CSVDataSet" + + def test_custom_resolvers(self, tmp_path): + base_params = tmp_path / _BASE_ENV / "parameters.yml" + param_config = { + "model_options": { + "param1": "${add: 3, 4}", + "param2": "${plus_2: 1}", + "param3": "${oc.env: VAR}", + } + } + _write_yaml(base_params, param_config) + custom_resolvers = { + "add": lambda *x: sum(x), + "plus_2": lambda x: x + 2, + "oc.env": oc.env, + } + os.environ["VAR"] = "my_env_variable" + conf = OmegaConfigLoader(tmp_path, custom_resolvers=custom_resolvers) + conf.default_run_env = "" + assert conf["parameters"]["model_options"]["param1"] == 7 + assert conf["parameters"]["model_options"]["param2"] == 3 + assert conf["parameters"]["model_options"]["param3"] == "my_env_variable"