Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
from monai.bundle.utils import ID_SEP_KEY, MACRO_KEY
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, optional_import

Expand Down Expand Up @@ -87,6 +87,8 @@ class ConfigParser:
suffixes = ("json", "yaml", "yml")
suffix_match = rf".*\.({'|'.join(suffixes)})"
path_match = rf"({suffix_match}$)"
# match relative id names, e.g. "@#data", "@##transform#1"
relative_id_prefix = re.compile(rf"(?:{ID_REF_KEY}|{MACRO_KEY}){ID_SEP_KEY}+")
meta_key = "_meta_" # field key to save metadata

def __init__(
Expand Down Expand Up @@ -127,7 +129,7 @@ def __getitem__(self, id: Union[str, int]):
if id == "":
return self.config
config = self.config
for k in str(id).split(self.ref_resolver.sep):
for k in str(id).split(ID_SEP_KEY):
if not isinstance(config, (dict, list)):
raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.")
indexing = k if isinstance(config, dict) else int(k)
Expand All @@ -151,9 +153,9 @@ def __setitem__(self, id: Union[str, int], config: Any):
self.config = config
self.ref_resolver.reset()
return
keys = str(id).split(self.ref_resolver.sep)
keys = str(id).split(ID_SEP_KEY)
# get the last parent level config item and replace it
last_id = self.ref_resolver.sep.join(keys[:-1])
last_id = ID_SEP_KEY.join(keys[:-1])
conf_ = self[last_id]
indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1])
conf_[indexing] = config
Expand Down Expand Up @@ -192,7 +194,7 @@ def parse(self, reset: bool = True):
"""
if reset:
self.ref_resolver.reset()
self.resolve_macro()
self.resolve_macro_and_relative_ids()
self._do_parse(config=self.get())

def get_parsed_content(self, id: str = "", **kwargs):
Expand Down Expand Up @@ -247,28 +249,37 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs):
content.update(self.load_config_files(f, **kwargs))
self.set(config=content)

def _do_resolve(self, config: Any):
def _do_resolve(self, config: Any, id: str = ""):
"""
Recursively resolve the config content to replace the macro tokens with target content.
Recursively resolve `self.config` to replace the relative ids with absolute ids, for example,
`@##A` means `A` in the upper level. and replace the macro tokens with target content,
The macro tokens start with "%", can be from another structured file, like:
``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net"}``.
``"%default_net"``, ``"%/data/config.json#net"``.

Args:
config: input config file to resolve.
id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.

"""
if isinstance(config, (dict, list)):
for k, v in enumerate(config) if isinstance(config, list) else config.items():
config[k] = self._do_resolve(v)
if isinstance(config, str) and config.startswith(MACRO_KEY):
path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :])
parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path))
return self._do_resolve(config=deepcopy(parser[ids]))
sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k
config[k] = self._do_resolve(v, sub_id)
if isinstance(config, str):
config = self.resolve_relative_ids(id, config)
if config.startswith(MACRO_KEY):
path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :])
parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path))
return self._do_resolve(config=deepcopy(parser[ids]))
return config

def resolve_macro(self):
def resolve_macro_and_relative_ids(self):
"""
Recursively resolve `self.config` to replace the macro tokens with target content.
Recursively resolve `self.config` to replace the relative ids with absolute ids, for example,
`@##A` means `A` in the upper level. and replace the macro tokens with target content,
The macro tokens are marked as starting with "%", can be from another structured file, like:
``"%default_net"``, ``"%/data/config.json#net"``.

Expand All @@ -288,9 +299,8 @@ def _do_parse(self, config, id: str = ""):

"""
if isinstance(config, (dict, list)):
subs = enumerate(config) if isinstance(config, list) else config.items()
for k, v in subs:
sub_id = f"{id}{self.ref_resolver.sep}{k}" if id != "" else k
for k, v in enumerate(config) if isinstance(config, list) else config.items():
sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k
self._do_parse(config=v, id=sub_id)

# copy every config item to make them independent and add them to the resolver
Expand Down Expand Up @@ -376,3 +386,41 @@ def split_path_id(cls, src: str) -> Tuple[str, str]:
path_name = result[0][0] # at most one path_name
_, ids = src.rsplit(path_name, 1)
return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else ""

@classmethod
def resolve_relative_ids(cls, id: str, value: str) -> str:
"""
To simplify the reference or macro tokens ID in the nested config content, it's available to use
relative ID name which starts with the `ID_SEP_KEY`, for example, "@#A" means `A` in the same level,
`@##A` means `A` in the upper level.
It resolves the relative ids to absolute ids. For example, if the input data is:

.. code-block:: python

{
"A": 1,
"B": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1"]},
}

It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "%B#value1", "value3": [3, 4, "@B#value3#1"]}`.

Args:
id: id name for current config item to compute relative id.
value: input value to resolve relative ids.

"""
# get the prefixes like: "@####", "%###", "@#"
prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True)
current_id = id.split(ID_SEP_KEY)

for p in prefixes:
sym = ID_REF_KEY if ID_REF_KEY in p else MACRO_KEY
length = p[len(sym) :].count(ID_SEP_KEY)
if length > len(current_id):
raise ValueError(f"the relative id in `{value}` is out of the range of config content.")
if length == len(current_id):
new = "" # root id is `""`
else:
new = ID_SEP_KEY.join(current_id[:-length]) + ID_SEP_KEY
value = value.replace(p, sym + new)
return value
30 changes: 26 additions & 4 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
"_target_": "Compose",
"transforms": [
{"_target_": "LoadImaged", "keys": "image"},
{"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
# test relative id in `keys`
{"_target_": "RandTorchVisiond", "keys": "@##0#keys", "name": "ColorJitter", "brightness": 0.25},
],
},
"dataset": {"_target_": "Dataset", "data": [1, 2], "transform": "@transform"},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@dataset",
# test relative id in `dataset`
"dataset": "@##dataset",
"batch_size": 2,
"collate_fn": "monai.data.list_data_collate",
"collate_fn": "$monai.data.list_data_collate",
},
},
["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"],
Expand Down Expand Up @@ -73,7 +75,17 @@ def __call__(self, a, b):
]


class TestConfigComponent(unittest.TestCase):
TEST_CASE_3 = [
{
"A": 1,
"B": "@A",
"C": "@#A",
"D": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]},
}
]


class TestConfigParser(unittest.TestCase):
def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
Expand Down Expand Up @@ -120,6 +132,16 @@ def test_function(self, config):
continue
self.assertEqual(func(1, 2), 3)

@parameterized.expand([TEST_CASE_3])
def test_relative_id(self, config):
parser = ConfigParser(config=config)
for id in config:
item = parser.get_parsed_content(id=id)
if isinstance(item, int):
self.assertEqual(item, 1)
if isinstance(item, dict):
self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]}))


if __name__ == "__main__":
unittest.main()