From ea5160d528dac2b7da802aafc9a5f9b3df75e4f0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 12:56:02 +0800 Subject: [PATCH 1/5] [DLMED] support relative ids Signed-off-by: Nic Ma --- monai/bundle/reference_resolver.py | 37 ++++++++++++++++++++++++++++++ tests/test_config_parser.py | 28 +++++++++++++++++++--- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index c1599c2124..ee281cb088 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -50,6 +50,8 @@ class ReferenceResolver: ref = ID_REF_KEY # reference prefix # match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key" id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*") + # match relative id names, e.g. "@#data", "@##transform#1" + relative_id_prefix = re.compile(rf"{ref}{sep}+") def __init__(self, items: Optional[Sequence[ConfigItem]] = None): # save the items in a dictionary with the `ConfigItem.id` as key @@ -165,6 +167,39 @@ def get_resolved_content(self, id: str, **kwargs): self._resolve_one_item(id=id, **kwargs) return self.resolved_content[id] + @classmethod + def resolve_relative_ids(cls, id: str, value: str) -> str: + """ + Resolve the relative reference ids to absolute ids. For example: + + .. code-block:: python + + { + "A": 1, + "B": "@#A", + "C": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1"]}, + } + + Args: + id: id name for current config item to compute relative id. + value: input value to resolve relative ids. + + """ + # get the prefixes like: "@####", "@###", "@#" + prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True) + current_id = id.split(cls.sep) + + for p in prefixes: + length = len(p[len(cls.ref) :]) // len(cls.sep) + if length > len(current_id): + raise ValueError(f"the relative id in `{value}` is out of the range of config content.") + if length == len(current_id): + new = "" # root id is `""` + else: + new = cls.sep.join(current_id[:-length]) + cls.sep + value = value.replace(p, cls.ref + new) + return value + @classmethod def match_refs_pattern(cls, value: str) -> Set[str]: """ @@ -228,6 +263,7 @@ def find_refs_in_config(cls, config, id: str, refs: Optional[Set[str]] = None) - """ refs_: Set[str] = refs or set() if isinstance(config, str): + config = cls.resolve_relative_ids(id=id, value=config) return refs_.union(cls.match_refs_pattern(value=config)) if not isinstance(config, (list, dict)): return refs_ @@ -252,6 +288,7 @@ def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): """ refs_: Dict = refs or {} if isinstance(config, str): + config = cls.resolve_relative_ids(id=id, value=config) return cls.update_refs_pattern(config, refs_) if not isinstance(config, (list, dict)): return config diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index ce98be1214..c384100d1b 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -28,15 +28,17 @@ "_target_": "Compose", "transforms": [ {"_target_": "LoadImaged", "keys": "image"}, - {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, + # test relative id in `keys` + {"_target_": "RandTorchVisiond", "keys": "@##0#keys", "name": "ColorJitter", "brightness": 0.25}, ], }, "dataset": {"_target_": "Dataset", "data": [1, 2], "transform": "@transform"}, "dataloader": { "_target_": "DataLoader", - "dataset": "@dataset", + # test relative id in `dataset` + "dataset": "@##dataset", "batch_size": 2, - "collate_fn": "monai.data.list_data_collate", + "collate_fn": "$monai.data.list_data_collate", }, }, ["transform", "transform#transforms#0", "transform#transforms#1", "dataset", "dataloader"], @@ -73,6 +75,16 @@ def __call__(self, a, b): ] +TEST_CASE_3 = [ + { + "A": 1, + "B": "@A", + "C": "@#A", + "D": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]}, + } +] + + class TestConfigComponent(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} @@ -120,6 +132,16 @@ def test_function(self, config): continue self.assertEqual(func(1, 2), 3) + @parameterized.expand([TEST_CASE_3]) + def test_relative_id(self, config): + parser = ConfigParser(config=config) + for id in config: + item = parser.get_parsed_content(id=id) + if isinstance(item, int): + self.assertEqual(item, 1) + if isinstance(item, dict): + self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]})) + if __name__ == "__main__": unittest.main() From f60c96e6a51d39238daaec14bbf23b41d1fb0b4b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 18:33:47 +0800 Subject: [PATCH 2/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/bundle/reference_resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index e99b9969e3..b044c8c6f1 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -197,7 +197,7 @@ def resolve_relative_ids(cls, id: str, value: str) -> str: current_id = id.split(cls.sep) for p in prefixes: - length = len(p[len(cls.ref) :]) // len(cls.sep) + length = p[len(cls.ref) :].count(cls.sep) if length > len(current_id): raise ValueError(f"the relative id in `{value}` is out of the range of config content.") if length == len(current_id): From 466e819fbc0cff02b9f0163a42db962f691570af Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 19:02:42 +0800 Subject: [PATCH 3/5] [DLMED] fix typo in test Signed-off-by: Nic Ma --- tests/test_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index c384100d1b..2643e3c9e6 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -85,7 +85,7 @@ def __call__(self, a, b): ] -class TestConfigComponent(unittest.TestCase): +class TestConfigParser(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) From 23d3a629c5e15d549fc977da8affc22f2b4f30a7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 23 Mar 2022 07:38:22 +0800 Subject: [PATCH 4/5] [DLMED] add description Signed-off-by: Nic Ma --- monai/bundle/reference_resolver.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index b044c8c6f1..610022bba3 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -177,16 +177,19 @@ def get_resolved_content(self, id: str, **kwargs): @classmethod def resolve_relative_ids(cls, id: str, value: str) -> str: """ - Resolve the relative reference ids to absolute ids. For example: + To simplify the reference ID in the nested config content, it's available to use relative ID name which starts + with the `sep` symbol, for example, "@#A" means `A` in the same level, `@##A` means `A` in the upper level. + It resolves the relative reference ids to absolute ids. For example, if the input data is: .. code-block:: python { "A": 1, - "B": "@#A", - "C": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1"]}, + "B": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1"]}, } + It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "@B#value1", "value3": [3, 4, "@B#value3#1"]}`. + Args: id: id name for current config item to compute relative id. value: input value to resolve relative ids. From ccbe941b08905acac12544c8a5c40a4ff573d208 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 23 Mar 2022 08:29:56 +0800 Subject: [PATCH 5/5] [DLMED] optimize the logic Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 84 +++++++++++++++++++++++------- monai/bundle/reference_resolver.py | 40 -------------- tests/test_config_parser.py | 2 +- 3 files changed, 67 insertions(+), 59 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 23d4ac7c55..800e18ade0 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -17,7 +17,7 @@ from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver -from monai.bundle.utils import ID_SEP_KEY, MACRO_KEY +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import @@ -87,6 +87,8 @@ class ConfigParser: suffixes = ("json", "yaml", "yml") suffix_match = rf".*\.({'|'.join(suffixes)})" path_match = rf"({suffix_match}$)" + # match relative id names, e.g. "@#data", "@##transform#1" + relative_id_prefix = re.compile(rf"(?:{ID_REF_KEY}|{MACRO_KEY}){ID_SEP_KEY}+") meta_key = "_meta_" # field key to save metadata def __init__( @@ -127,7 +129,7 @@ def __getitem__(self, id: Union[str, int]): if id == "": return self.config config = self.config - for k in str(id).split(self.ref_resolver.sep): + for k in str(id).split(ID_SEP_KEY): if not isinstance(config, (dict, list)): raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.") indexing = k if isinstance(config, dict) else int(k) @@ -151,9 +153,9 @@ def __setitem__(self, id: Union[str, int], config: Any): self.config = config self.ref_resolver.reset() return - keys = str(id).split(self.ref_resolver.sep) + keys = str(id).split(ID_SEP_KEY) # get the last parent level config item and replace it - last_id = self.ref_resolver.sep.join(keys[:-1]) + last_id = ID_SEP_KEY.join(keys[:-1]) conf_ = self[last_id] indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1]) conf_[indexing] = config @@ -192,7 +194,7 @@ def parse(self, reset: bool = True): """ if reset: self.ref_resolver.reset() - self.resolve_macro() + self.resolve_macro_and_relative_ids() self._do_parse(config=self.get()) def get_parsed_content(self, id: str = "", **kwargs): @@ -247,28 +249,37 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs): content.update(self.load_config_files(f, **kwargs)) self.set(config=content) - def _do_resolve(self, config: Any): + def _do_resolve(self, config: Any, id: str = ""): """ - Recursively resolve the config content to replace the macro tokens with target content. + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, The macro tokens start with "%", can be from another structured file, like: - ``{"net": "%default_net"}``, ``{"net": "%/data/config.json#net"}``. + ``"%default_net"``, ``"%/data/config.json#net"``. Args: config: input config file to resolve. + id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to + go one level further into the nested structures. + Use digits indexing from "0" for list or other strings for dict. + For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``. """ if isinstance(config, (dict, list)): for k, v in enumerate(config) if isinstance(config, list) else config.items(): - config[k] = self._do_resolve(v) - if isinstance(config, str) and config.startswith(MACRO_KEY): - path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) - parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) - return self._do_resolve(config=deepcopy(parser[ids])) + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k + config[k] = self._do_resolve(v, sub_id) + if isinstance(config, str): + config = self.resolve_relative_ids(id, config) + if config.startswith(MACRO_KEY): + path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :]) + parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path)) + return self._do_resolve(config=deepcopy(parser[ids])) return config - def resolve_macro(self): + def resolve_macro_and_relative_ids(self): """ - Recursively resolve `self.config` to replace the macro tokens with target content. + Recursively resolve `self.config` to replace the relative ids with absolute ids, for example, + `@##A` means `A` in the upper level. and replace the macro tokens with target content, The macro tokens are marked as starting with "%", can be from another structured file, like: ``"%default_net"``, ``"%/data/config.json#net"``. @@ -288,9 +299,8 @@ def _do_parse(self, config, id: str = ""): """ if isinstance(config, (dict, list)): - subs = enumerate(config) if isinstance(config, list) else config.items() - for k, v in subs: - sub_id = f"{id}{self.ref_resolver.sep}{k}" if id != "" else k + for k, v in enumerate(config) if isinstance(config, list) else config.items(): + sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k self._do_parse(config=v, id=sub_id) # copy every config item to make them independent and add them to the resolver @@ -376,3 +386,41 @@ def split_path_id(cls, src: str) -> Tuple[str, str]: path_name = result[0][0] # at most one path_name _, ids = src.rsplit(path_name, 1) return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else "" + + @classmethod + def resolve_relative_ids(cls, id: str, value: str) -> str: + """ + To simplify the reference or macro tokens ID in the nested config content, it's available to use + relative ID name which starts with the `ID_SEP_KEY`, for example, "@#A" means `A` in the same level, + `@##A` means `A` in the upper level. + It resolves the relative ids to absolute ids. For example, if the input data is: + + .. code-block:: python + + { + "A": 1, + "B": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1"]}, + } + + It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "%B#value1", "value3": [3, 4, "@B#value3#1"]}`. + + Args: + id: id name for current config item to compute relative id. + value: input value to resolve relative ids. + + """ + # get the prefixes like: "@####", "%###", "@#" + prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True) + current_id = id.split(ID_SEP_KEY) + + for p in prefixes: + sym = ID_REF_KEY if ID_REF_KEY in p else MACRO_KEY + length = p[len(sym) :].count(ID_SEP_KEY) + if length > len(current_id): + raise ValueError(f"the relative id in `{value}` is out of the range of config content.") + if length == len(current_id): + new = "" # root id is `""` + else: + new = ID_SEP_KEY.join(current_id[:-length]) + ID_SEP_KEY + value = value.replace(p, sym + new) + return value diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index 610022bba3..f9f73c9c71 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -50,8 +50,6 @@ class ReferenceResolver: ref = ID_REF_KEY # reference prefix # match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key" id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*") - # match relative id names, e.g. "@#data", "@##transform#1" - relative_id_prefix = re.compile(rf"{ref}{sep}+") def __init__(self, items: Optional[Sequence[ConfigItem]] = None): # save the items in a dictionary with the `ConfigItem.id` as key @@ -174,42 +172,6 @@ def get_resolved_content(self, id: str, **kwargs): """ return self._resolve_one_item(id=id, **kwargs) - @classmethod - def resolve_relative_ids(cls, id: str, value: str) -> str: - """ - To simplify the reference ID in the nested config content, it's available to use relative ID name which starts - with the `sep` symbol, for example, "@#A" means `A` in the same level, `@##A` means `A` in the upper level. - It resolves the relative reference ids to absolute ids. For example, if the input data is: - - .. code-block:: python - - { - "A": 1, - "B": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1"]}, - } - - It will resolve `B` to `{"key": "@A", "value1": 2, "value2": "@B#value1", "value3": [3, 4, "@B#value3#1"]}`. - - Args: - id: id name for current config item to compute relative id. - value: input value to resolve relative ids. - - """ - # get the prefixes like: "@####", "@###", "@#" - prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True) - current_id = id.split(cls.sep) - - for p in prefixes: - length = p[len(cls.ref) :].count(cls.sep) - if length > len(current_id): - raise ValueError(f"the relative id in `{value}` is out of the range of config content.") - if length == len(current_id): - new = "" # root id is `""` - else: - new = cls.sep.join(current_id[:-length]) + cls.sep - value = value.replace(p, cls.ref + new) - return value - @classmethod def match_refs_pattern(cls, value: str) -> Set[str]: """ @@ -273,7 +235,6 @@ def find_refs_in_config(cls, config, id: str, refs: Optional[Set[str]] = None) - """ refs_: Set[str] = refs or set() if isinstance(config, str): - config = cls.resolve_relative_ids(id=id, value=config) return refs_.union(cls.match_refs_pattern(value=config)) if not isinstance(config, (list, dict)): return refs_ @@ -298,7 +259,6 @@ def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None): """ refs_: Dict = refs or {} if isinstance(config, str): - config = cls.resolve_relative_ids(id=id, value=config) return cls.update_refs_pattern(config, refs_) if not isinstance(config, (list, dict)): return config diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 2643e3c9e6..8b1076b1f7 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -80,7 +80,7 @@ def __call__(self, a, b): "A": 1, "B": "@A", "C": "@#A", - "D": {"key": "@##A", "value1": 2, "value2": "@#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]}, + "D": {"key": "@##A", "value1": 2, "value2": "%#value1", "value3": [3, 4, "@#1", "$100 + @#0 + @##value1"]}, } ]