diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 23d4ac7c55..800e18ade0 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -17,7 +17,7 @@ from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver -from monai.bundle.utils import ID_SEP_KEY, MACRO_KEY +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import @@ -87,6 +87,8 @@ class ConfigParser: suffixes = ("json", "yaml", "yml") suffix_match = rf".*\.({'|'.join(suffixes)})" path_match = rf"({suffix_match}$)" + # match relative id names, e.g. "@#data", "@##transform#1" + relative_id_prefix = re.compile(rf"(?:{ID_REF_KEY}|{MACRO_KEY}){ID_SEP_KEY}+") meta_key = "_meta_" # field key to save metadata def __init__( @@ -127,7 +129,7 @@ def __getitem__(self, id: Union[str, int]): if id == "": return self.config config = self.config - for k in str(id).split(self.ref_resolver.sep): + for k in str(id).split(ID_SEP_KEY): if not isinstance(config, (dict, list)): raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.") indexing = k if isinstance(config, dict) else int(k) @@ -151,9 +153,9 @@ def __setitem__(self, id: Union[str, int], config: Any): self.config = config self.ref_resolver.reset() return - keys = str(id).split(self.ref_resolver.sep) + keys = str(id).split(ID_SEP_KEY) # get the last parent level config item and replace it - last_id = self.ref_resolver.sep.join(keys[:-1]) + last_id = ID_SEP_KEY.join(keys[:-1]) conf_ = self[last_id] indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1]) conf_[indexing] = config @@ -192,7 +194,7 @@ def parse(self, reset: bool = True): """ if reset: self.ref_resolver.reset() - self.resolve_macro() + self.resolve_macro_and_relative_ids() self._do_parse(config=self.get()) def get_parsed_content(self, id: str = "", **kwargs): @@ -247,28 +249,37 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): content.update(self.load_config_files(f, **kwargs)) self.set(config=content) - def _do_resolve(self, config: Any): + def _do_resolve(self, config: Any, id: str = ""): """ - Recursively resolve the config content to replace the macro tokens with target content. + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, The macro tokens start with "%", can be from another structured file, like: - ``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net"}``. + ``"%default_net"``, ``"%/data/config.json#net"``. Args: config: input config file to resolve. + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. """ if isinstance(config, (dict, list)): for k, v in enumerate(config) if isinstance(config, list) else config.items(): - config[k] = self._do_resolve(v) - if isinstance(config, str) and config.startswith(MACRO_KEY): - path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) - parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) - return self._do_resolve(config=deepcopy(parser[ids])) + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k + config[k] = self._do_resolve(v, sub_id) + if isinstance(config, str): + config = self.resolve_relative_ids(id, config) + if config.startswith(MACRO_KEY): + path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) + parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) + return self._do_resolve(config=deepcopy(parser[ids])) return config - def resolve_macro(self): + def resolve_macro_and_relative_ids(self): """ - Recursively resolve `self.config` to replace the macro tokens with target content. + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, The macro tokens are marked as starting with "%", can be from another structured file, like: ``"%default_net"``, ``"%/data/config.json#net"``. @@ -288,9 +299,8 @@ def _do_parse(self, config, id: str = ""): """ if isinstance(config, (dict, list)): - subs = enumerate(config) if isinstance(config, list) else config.items() - for k, v in subs: - sub_id = f"{id}{self.ref_resolver.sep}{k}" if id != "" else k + for k, v in enumerate(config) if isinstance(config, list) else config.items(): + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k self._do_parse(config=v, id=sub_id) # copy every config item to make them independent and add them to the resolver @@ -376,3 +386,41 @@ def split_path_id(cls, src: str) -> Tuple[str, str]: path_name = result[0][0] # at most one path_name _, ids = src.rsplit(path_name, 1) return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else "" + + @classmethod + def resolve_relative_ids(cls, id: str, value: str) -> str: + """ + To simplify the reference or macro tokens ID in the nested config content, it's available to use + relative ID name which starts with the `ID_SEP_KEY`, for example, "@#A" means `A` in the same level, + `@##A` means `A` in the upper level. + It resolves the relative ids to absolute ids. For example, if the input data is: + + .. code-block:: python + + { + "A": 1, + "B": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1"]}, + } + + It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "%B#value1", "value3": [3, 4, "@B#value3#1"]}`. + + Args: + id: id name for current config item to compute relative id. + value: input value to resolve relative ids. + + """ + # get the prefixes like: "@####", "%###", "@#" + prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True) + current_id = id.split(ID_SEP_KEY) + + for p in prefixes: + sym = ID_REF_KEY if ID_REF_KEY in p else MACRO_KEY + length = p[len(sym) :].count(ID_SEP_KEY) + if length > len(current_id): + raise ValueError(f"the relative id in `{value}` is out of the range of config content.") + if length == len(current_id): + new = "" # root id is `""` + else: + new = ID_SEP_KEY.join(current_id[:-length]) + ID_SEP_KEY + value = value.replace(p, sym + new) + return value diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index ce98be1214..8b1076b1f7 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -28,15 +28,17 @@ "_target_": "Compose", "transforms": [ {"_target_": "LoadImaged", "keys": "image"}, - {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, + # test relative id in `keys` + {"_target_": "RandTorchVisiond", "keys": "@##0#keys", "name": "ColorJitter", "brightness": 0.25}, ], }, "dataset": {"_target_": "Dataset", "data": [1, 2], "transform": "@transform"}, "dataloader": { "_target_": "DataLoader", - "dataset": "@dataset", + # test relative id in `dataset` + "dataset": "@##dataset", "batch_size": 2, - "collate_fn": "monai.data.list_data_collate", + "collate_fn": "$monai.data.list_data_collate", }, }, ["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"], @@ -73,7 +75,17 @@ def __call__(self, a, b): ] -class TestConfigComponent(unittest.TestCase): +TEST_CASE_3 = [ + { + "A": 1, + "B": "@A", + "C": "@#A", + "D": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]}, + } +] + + +class TestConfigParser(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -120,6 +132,16 @@ def test_function(self, config): continue self.assertEqual(func(1, 2), 3) + @parameterized.expand([TEST_CASE_3]) + def test_relative_id(self, config): + parser = ConfigParser(config=config) + for id in config: + item = parser.get_parsed_content(id=id) + if isinstance(item, int): + self.assertEqual(item, 1) + if isinstance(item, dict): + self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]})) + if __name__ == "__main__": unittest.main()