Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Sep 7, 2023
1 parent 8896b22 commit b34f025
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 25 deletions.
26 changes: 10 additions & 16 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def __getitem__(self, id: str | int) -> Any:
if id == "":
return self.config
config = self.config
id = str(id).replace("#", ID_SEP_KEY)
for k in str(id).split(ID_SEP_KEY):
for k in ReferenceResolver.split_id(id):
if not isinstance(config, (dict, list)):
raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.")
try:
Expand Down Expand Up @@ -179,13 +178,11 @@ def __setitem__(self, id: str | int, config: Any) -> None:
self.config = config
self.ref_resolver.reset()
return
id = str(id).replace("#", ID_SEP_KEY)
keys = id.split(ID_SEP_KEY)
last_id, base_id = ReferenceResolver.split_id(id, last=True)
# get the last parent level config item and replace it
last_id = ID_SEP_KEY.join(keys[:-1])
conf_ = self[last_id]

indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1])
indexing = base_id if isinstance(conf_, dict) else int(base_id)
conf_[indexing] = config
self.ref_resolver.reset()
return
Expand Down Expand Up @@ -215,8 +212,7 @@ def set(self, config: Any, id: str = "", recursive: bool = True) -> None:
default to `True`. for the nested id, only support `dict` for the missing section.
"""
id = str(id).replace("#", ID_SEP_KEY)
keys = id.split(ID_SEP_KEY)
keys = ReferenceResolver.split_id(id)
conf_ = self.get()
if recursive:
if conf_ is None:
Expand All @@ -225,7 +221,7 @@ def set(self, config: Any, id: str = "", recursive: bool = True) -> None:
if isinstance(conf_, dict) and k not in conf_:
conf_[k] = {}
conf_ = conf_[k if isinstance(conf_, dict) else int(k)]
self[id] = config
self[ReferenceResolver.normalize_id(id)] = config

def update(self, pairs: dict[str, Any]) -> None:
"""
Expand Down Expand Up @@ -340,9 +336,8 @@ def _do_resolve(self, config: Any, id: str = "") -> Any:
"""
if isinstance(config, (dict, list)):
for k, v in enumerate(config) if isinstance(config, list) else config.items():
sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k
config[k] = self._do_resolve(v, sub_id)
for k, sub_id, v in self.ref_resolver.iter_subconfigs(id=id, config=config):
config[k] = self._do_resolve(v, sub_id) # type: ignore
if isinstance(config, str):
config = self.resolve_relative_ids(id, config)
if config.startswith(MACRO_KEY):
Expand Down Expand Up @@ -375,8 +370,7 @@ def _do_parse(self, config: Any, id: str = "") -> None:
"""
if isinstance(config, (dict, list)):
for k, v in enumerate(config) if isinstance(config, list) else config.items():
sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k
for _, sub_id, v in self.ref_resolver.iter_subconfigs(id=id, config=config):
self._do_parse(config=v, id=sub_id)

if ConfigComponent.is_instantiable(config):
Expand Down Expand Up @@ -461,7 +455,7 @@ def split_path_id(cls, src: str) -> tuple[str, str]:
src: source string to split.
"""
src = str(src).replace("#", ID_SEP_KEY)
src = ReferenceResolver.normalize_id(src)
result = re.compile(rf"({cls.suffix_match}(?=(?:{ID_SEP_KEY}.*)|$))", re.IGNORECASE).findall(src)
if not result:
return "", src # the src is a pure id
Expand Down Expand Up @@ -492,7 +486,7 @@ def resolve_relative_ids(cls, id: str, value: str) -> str:
"""
# get the prefixes like: "@####", "%###", "@#"
value = str(value).replace("#", ID_SEP_KEY)
value = ReferenceResolver.normalize_id(value)
prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True)
current_id = id.split(ID_SEP_KEY)

Expand Down
54 changes: 45 additions & 9 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
import warnings
from collections.abc import Sequence
from typing import Any
from typing import Any, Iterator

from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY
Expand Down Expand Up @@ -101,7 +101,7 @@ def get_item(self, id: str, resolve: bool = False, **kwargs: Any) -> ConfigItem
"""
if resolve and id not in self.resolved_content:
self._resolve_one_item(id=id, **kwargs)
id = str(id).replace("#", self.sep)
id = self.normalize_id(id)
return self.items.get(id)

def _resolve_one_item(
Expand All @@ -122,7 +122,7 @@ def _resolve_one_item(
if the `id` is not in the config content, must be a `ConfigItem` object.
"""
id = str(id).replace("#", self.sep)
id = self.normalize_id(id)
if id in self.resolved_content:
return self.resolved_content[id]
try:
Expand Down Expand Up @@ -192,6 +192,44 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
"""
return self._resolve_one_item(id=id, **kwargs)

@classmethod
def normalize_id(cls, id: str | int) -> str:
"""
Normalize the id string to consistently use `cls.sep`.
Args:
id: id string to be normalized.
"""
return str(id).replace("#", cls.sep) # backward compatibility `#` is the old separator

@classmethod
def split_id(cls, id: str | int, last: bool = False) -> list[str]:
"""
Split the id string into a tuple of strings.
Args:
id: id string to be split.
last: whether to split the rightmost part of the id. default is False (split all parts).
"""
if not last:
return cls.normalize_id(id).split(cls.sep)
res = cls.normalize_id(id).rsplit(cls.sep, 1)
return ["".join(res[:-1]), res[-1]]

@classmethod
def iter_subconfigs(cls, id: str, config: Any) -> Iterator[tuple[str, str, Any]]:
"""
Iterate over the sub-configs of the input config.
Args:
id: id string of the current input config.
config: input config to be iterated.
"""
for k, v in config.items() if isinstance(config, dict) else enumerate(config):
sub_id = f"{id}{cls.sep}{k}" if id != "" else f"{k}"
yield k, sub_id, v

@classmethod
def match_refs_pattern(cls, value: str) -> dict[str, int]:
"""
Expand All @@ -204,7 +242,7 @@ def match_refs_pattern(cls, value: str) -> dict[str, int]:
"""
refs: dict[str, int] = {}
# regular expression pattern to match "@XXX" or "@XXX#YYY"
value = str(value).replace("#", cls.sep)
value = cls.normalize_id(value)
result = cls.id_matcher.findall(value)
value_is_expr = ConfigExpression.is_expression(value)
for item in result:
Expand All @@ -227,7 +265,7 @@ def update_refs_pattern(cls, value: str, refs: dict) -> str:
"""
# regular expression pattern to match "@XXX" or "@XXX#YYY"
value = str(value).replace("#", cls.sep)
value = cls.normalize_id(value)
result = cls.id_matcher.findall(value)
# reversely sort the matched references by length
# and handle the longer first in case a reference item is substring of another longer item
Expand Down Expand Up @@ -273,8 +311,7 @@ def find_refs_in_config(cls, config: Any, id: str, refs: dict[str, int] | None =
refs_[id] = refs_.get(id, 0) + count
if not isinstance(config, (list, dict)):
return refs_
for k, v in config.items() if isinstance(config, dict) else enumerate(config):
sub_id = f"{id}{cls.sep}{k}" if id != "" else f"{k}"
for _, sub_id, v in cls.iter_subconfigs(id, config):
if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v) and sub_id not in refs_:
refs_[sub_id] = 1
refs_ = cls.find_refs_in_config(v, sub_id, refs_)
Expand All @@ -298,8 +335,7 @@ def update_config_with_refs(cls, config: Any, id: str, refs: dict | None = None)
if not isinstance(config, (list, dict)):
return config
ret = type(config)()
for idx, v in config.items() if isinstance(config, dict) else enumerate(config):
sub_id = f"{id}{cls.sep}{idx}" if id != "" else f"{idx}"
for idx, sub_id, v in cls.iter_subconfigs(id, config):
if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v):
updated = refs_[sub_id]
if ConfigComponent.is_instantiable(v) and updated is None:
Expand Down

0 comments on commit b34f025

Please sign in to comment.