diff --git a/src/rez/config.py b/src/rez/config.py index 964200ff59..6e63720aa4 100644 --- a/src/rez/config.py +++ b/src/rez/config.py @@ -3,6 +3,7 @@ from rez.utils.data_utils import AttrDictWrapper, RO_AttrDictWrapper, \ convert_dicts, cached_property, cached_class_property, LazyAttributeMeta, \ deep_update, ModifyList +from rez.utils.datatypes import HashableDict from rez.utils.formatting import expandvars, expanduser from rez.utils.logging_ import get_debug_printer from rez.utils.scope import scoped_format @@ -819,7 +820,7 @@ def _replace_config(other): @lru_cache() -def _load_config_py(filepath): +def _load_config_py(filepath, fallback_platform_map): from rez.utils.data_utils import Conditional, PlatformDependent, \ InConfigArchDependent, InConfigOsDependent reserved = dict( @@ -828,13 +829,14 @@ def _load_config_py(filepath): # and later excluded from the `Config` class __name__=os.path.splitext(os.path.basename(filepath))[0], __file__=filepath, + __fallback_platform_map=fallback_platform_map, rez_version=__version__, ModifyList=ModifyList, Conditional=Conditional, PlatformDependent=PlatformDependent, ArchDependent=InConfigArchDependent, - OsDependent=InConfigOsDependent + OsDependent=InConfigOsDependent, ) g = reserved.copy() @@ -851,14 +853,15 @@ def _load_config_py(filepath): for k, v in g.items(): if k != '__builtins__' \ and not ismodule(v) \ - and k not in reserved: + and k not in reserved \ + and k != "__fallback_platform_map": result[k] = v return result @lru_cache() -def _load_config_yaml(filepath): +def _load_config_yaml(filepath, _): with open(filepath) as f: content = f.read() try: @@ -890,7 +893,11 @@ def _load_config_from_filepaths(filepaths): if not os.path.isfile(filepath_with_ext): continue - data_ = loader(filepath_with_ext) + previous_platform_map = data.get("platform_map", None) + if previous_platform_map is not None: + previous_platform_map = HashableDict(previous_platform_map) + + data_ = loader(filepath_with_ext, previous_platform_map) deep_update(data, data_) sourced_filepaths.append(filepath_with_ext) break diff --git a/src/rez/tests/data/config/test_conditional_dependee.py b/src/rez/tests/data/config/test_conditional_dependee.py new file mode 100644 index 0000000000..71fcc3860c --- /dev/null +++ b/src/rez/tests/data/config/test_conditional_dependee.py @@ -0,0 +1,4 @@ + +dot_image_format = ArchDependent({ + "IMPOSSIBLE_ARCH": "hello", +}) diff --git a/src/rez/tests/test_config.py b/src/rez/tests/test_config.py index 011139d8ce..bde32d2f1c 100644 --- a/src/rez/tests/test_config.py +++ b/src/rez/tests/test_config.py @@ -139,6 +139,19 @@ def test_conditional_in_file(self): self.assertEqual(c.prune_failed_graph, True) self.assertEqual(c.warn_all, True) + def test_conidition_nested(self): + conf = os.path.join(self.config_path, "test_conditional.py") + conf_dependee = os.path.join(self.config_path, "test_conditional_dependee.py") + c = Config([conf, conf_dependee]) + self.assertEqual(c.dot_image_format, "hello") + + def test_conidition_nested_inbeween(self): + conf = os.path.join(self.config_path, "test_conditional.py") + conf_middle = os.path.join(self.config_path, "test2.py") + conf_dependee = os.path.join(self.config_path, "test_conditional_dependee.py") + c = Config([conf, conf_middle, conf_dependee]) + self.assertEqual(c.dot_image_format, "hello") + def test_1(self): """Test just the root config file.""" diff --git a/src/rez/utils/data_utils.py b/src/rez/utils/data_utils.py index 395ccf3534..f1ce0f8805 100644 --- a/src/rez/utils/data_utils.py +++ b/src/rez/utils/data_utils.py @@ -683,7 +683,10 @@ def __new__(cls, base, frame, options, default=ConditionalConfigurationError): # a global config. platform_map = frame.f_locals.get( "platform_map", - None + frame.f_locals.get( + "__fallback_platform_map", + None + ) ) return base(options, default, platform_map=platform_map) diff --git a/src/rez/utils/datatypes.py b/src/rez/utils/datatypes.py new file mode 100644 index 0000000000..09ba538fb7 --- /dev/null +++ b/src/rez/utils/datatypes.py @@ -0,0 +1,26 @@ + + +class HashableDict(dict): + """ + Hashable dict (immutable by definition) + """ + + def __init__(self, other=None): + super(HashableDict, self).__init__() + if other: + for k, v in other.items(): + if isinstance(v, dict): + self[k] = HashableDict(v) + else: + self[k] = v + + def __key(self): + return tuple((k,self[k]) for k in sorted(self)) + + def __hash__(self): + return hash(self.__key()) + + def __eq__(self, other): + return self.__key() == other.__key() + +