diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 6535ad82b7..4b1cdc6f43 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -44,6 +44,9 @@ Model Manifest .. autoclass:: ConfigItem :members: +.. autoclass:: ReferenceResolver + :members: + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index df085bddea..0f233bc3ef 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ReferenceResolver from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index d8919a5249..79c4376d5c 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -10,3 +10,4 @@ # limitations under the License. from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .reference_resolver import ReferenceResolver diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index 1f0c06c8fc..075d00b961 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -23,29 +23,9 @@ class Instantiable(ABC): """ - Base class for instantiable object with module name and arguments. - - .. code-block:: python - - if not is_disabled(): - instantiate(module_name=resolve_module_name(), args=resolve_args()) - + Base class for an instantiable object. """ - @abstractmethod - def resolve_module_name(self, *args: Any, **kwargs: Any): - """ - Resolve the target module name, it should return an object class (or function) to be instantiated. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def resolve_args(self, *args: Any, **kwargs: Any): - """ - Resolve the arguments, it should return arguments to be passed to the object when instantiating. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - @abstractmethod def is_disabled(self, *args: Any, **kwargs: Any) -> bool: """ @@ -54,9 +34,9 @@ def is_disabled(self, *args: Any, **kwargs: Any) -> bool: raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def instantiate(self, *args: Any, **kwargs: Any): + def instantiate(self, *args: Any, **kwargs: Any) -> object: """ - Instantiate the target component. + Instantiate the target component and return the instance. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @@ -140,11 +120,11 @@ class ConfigItem: Args: config: content of a config item, can be objects of any types, a configuration resolver may interpret the content to generate a configuration object. - id: optional name of the current config item, defaults to `None`. + id: name of the current config item, defaults to empty string. """ - def __init__(self, config: Any, id: Optional[str] = None) -> None: + def __init__(self, config: Any, id: str = "") -> None: self.config = config self.id = id @@ -203,7 +183,7 @@ class ConfigComponent(ConfigItem, Instantiable): Args: config: content of a config item. - id: optional name of the current config item, defaults to `None`. + id: name of the current config item, defaults to empty string. locator: a ``ComponentLocator`` to convert a module name string into the actual python module. if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. @@ -214,7 +194,7 @@ class ConfigComponent(ConfigItem, Instantiable): def __init__( self, config: Any, - id: Optional[str] = None, + id: str = "", locator: Optional[ComponentLocator] = None, excludes: Optional[Union[Sequence[str], str]] = None, ) -> None: @@ -319,18 +299,18 @@ class ConfigExpression(ConfigItem): Args: config: content of a config item. - id: optional name of current config item, defaults to `None`. + id: name of current config item, defaults to empty string. globals: additional global context to evaluate the string. """ - def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: super().__init__(config=config, id=id) self.globals = globals def evaluate(self, locals: Optional[Dict] = None): """ - Excute current config content and return the result if it is expression, based on python `eval()`. + Execute the current config content and return the result if it is expression, based on Python `eval()`. For more details: https://docs.python.org/3/library/functions.html#eval. Args: @@ -346,7 +326,7 @@ def evaluate(self, locals: Optional[Dict] = None): def is_expression(config: Union[Dict, List, str]) -> bool: """ Check whether the config is an executable expression string. - Currently A string starts with ``"$"`` character is interpreted as an expression. + Currently, a string starts with ``"$"`` character is interpreted as an expression. Args: config: input config content to check. diff --git a/monai/apps/manifest/reference_resolver.py b/monai/apps/manifest/reference_resolver.py new file mode 100644 index 0000000000..32d089370a --- /dev/null +++ b/monai/apps/manifest/reference_resolver.py @@ -0,0 +1,236 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings +from typing import Any, Dict, Optional, Sequence, Set + +from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem + + +class ReferenceResolver: + """ + Utility class to manage a set of ``ConfigItem`` and resolve the references between them. + + This class maintains a set of ``ConfigItem`` objects and their associated IDs. + The IDs must be unique within this set. A string in ``ConfigItem`` + starting with ``@`` will be treated as a reference to other ``ConfigItem`` objects by ID. + Since ``ConfigItem`` may have a nested dictionary or list structure, + the reference string may also contain a ``#`` character to refer to a substructure by + key indexing for a dictionary or integer indexing for a list. + + In this class, resolving references is essentially substitution of the reference strings with the + corresponding python objects. A typical workflow of resolving references is as follows: + + - Add multiple ``ConfigItem`` objects to the ``ReferenceResolver`` by ``add_item()``. + - Call ``get_resolved_content()`` to automatically resolve the references. This is done (recursively) by: + - Convert the items to objects, for those do not have references to other items. + - If it is instantiable, instantiate it and cache the class instance in ``resolved_content``. + - If it is an expression, evaluate it and save the value in ``resolved_content``. + - Substitute the reference strings with the corresponding objects. + + Args: + items: ``ConfigItem``s to resolve, this could be added later with ``add_item()``. + + """ + + def __init__(self, items: Optional[Sequence[ConfigItem]] = None): + # save the items in a dictionary with the `ConfigItem.id` as key + self.items = {} if items is None else {i.get_id(): i for i in items} + self.resolved_content: Dict[str, Any] = {} + + def add_item(self, item: ConfigItem): + """ + Add a ``ConfigItem`` to the resolver. + + Args: + item: a ``ConfigItem``. + + """ + id = item.get_id() + if id == "": + raise ValueError("id should not be empty when resolving reference.") + if id in self.items: + warnings.warn(f"id '{id}' is already added.") + return + self.items[id] = item + + def get_item(self, id: str, resolve: bool = False): + """ + Get the ``ConfigItem`` by id. + + If ``resolve=True``, the returned item will be resolved, that is, + all the reference strings are substituted by the corresponding ``ConfigItem`` objects. + + Args: + id: id of the expected config item. + resolve: whether to resolve the item if it is not resolved, default to False. + + """ + if resolve and id not in self.resolved_content: + self._resolve_one_item(id=id) + return self.items.get(id) + + def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None): + """ + Resolve one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. + If it has unresolved references, recursively resolve the referring items first. + + Args: + id: id name of ``ConfigItem`` to be resolved. + waiting_list: set of ids pending to be resolved. + It's used to detect circular references such as: + `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. + + """ + item = self.items[id] # if invalid id name, raise KeyError + item_config = item.get_config() + + if waiting_list is None: + waiting_list = set() + waiting_list.add(id) + + ref_ids = self.find_refs_in_config(config=item_config, id=id) + + # if current item has reference already in the waiting list, that's circular references + for d in ref_ids: + if d in waiting_list: + raise ValueError(f"detected circular references for id='{d}' in the config content.") + + # # check whether the component has any unresolved references + for d in ref_ids: + if d not in self.resolved_content: + # this referring item is not resolved + if d not in self.items: + raise ValueError(f"the referring item `{d}` is not defined in config.") + # recursively resolve the reference first + self._resolve_one_item(id=d, waiting_list=waiting_list) + + # all references are resolved, then try to resolve current config item + new_config = self.update_config_with_refs(config=item_config, id=id, refs=self.resolved_content) + item.update_config(config=new_config) + # save the resolved result into `resolved_content` to recursively resolve others + if isinstance(item, ConfigComponent): + self.resolved_content[id] = item.instantiate() + elif isinstance(item, ConfigExpression): + self.resolved_content[id] = item.evaluate(locals={"refs": self.resolved_content}) + else: + self.resolved_content[id] = new_config + + def get_resolved_content(self, id: str): + """ + Get the resolved ``ConfigItem`` by id. If there are unresolved references, try to resolve them first. + + Args: + id: id name of the expected item. + + """ + if id not in self.resolved_content: + self._resolve_one_item(id=id) + return self.resolved_content[id] + + @staticmethod + def match_refs_pattern(value: str) -> Set[str]: + """ + Match regular expression for the input string to find the references. + The reference string starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. + + Args: + value: input value to match regular expression. + + """ + refs: Set[str] = set() + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + if ConfigExpression.is_expression(value) or value == item: + # only check when string starts with "$" or the whole content is "@XXX" + refs.add(item[1:]) + return refs + + @staticmethod + def update_refs_pattern(value: str, refs: Dict) -> str: + """ + Match regular expression for the input string to update content with the references. + The reference part starts with ``"@"``, like: ``"@XXX#YYY#ZZZ"``. + References dictionary must contain the referring IDs as keys. + + Args: + value: input value to match regular expression. + refs: all the referring components with ids as keys, default to `None`. + + """ + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + ref_id = item[1:] + if ref_id not in refs: + raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + if ConfigExpression.is_expression(value): + # replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}` + value = value.replace(item, f"refs['{ref_id}']") + elif value == item: + # the whole content is "@XXX", it will avoid the case that regular string contains "@" + value = refs[ref_id] + return value + + @staticmethod + def find_refs_in_config(config, id: str, refs: Optional[Set[str]] = None) -> Set[str]: + """ + Recursively search all the content of input config item to get the ids of references. + References mean: the IDs of other config items (``"@XXX"`` in this config item), or the + sub-item in the config is `instantiable`, or the sub-item in the config is `expression`. + For `dict` and `list`, recursively check the sub-items. + + Args: + config: input config content to search. + id: ID name for the input config item. + refs: list of the ID name of found references, default to `None`. + + """ + refs_: Set[str] = refs or set() + if isinstance(config, str): + return refs_.union(ReferenceResolver.match_refs_pattern(value=config)) + if not isinstance(config, (list, dict)): + return refs_ + for k, v in config.items() if isinstance(config, dict) else enumerate(config): + sub_id = f"{id}#{k}" if id != "" else f"{k}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + refs_.add(sub_id) + refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) + return refs_ + + @staticmethod + def update_config_with_refs(config, id: str, refs: Optional[Dict] = None): + """ + With all the references in ``refs``, update the input config content with references + and return the new config. + + Args: + config: input config content to update. + id: ID name for the input config. + refs: all the referring content with ids, default to `None`. + + """ + refs_: Dict = refs or {} + if isinstance(config, str): + return ReferenceResolver.update_refs_pattern(config, refs_) + if not isinstance(config, (list, dict)): + return config + ret = type(config)() + for idx, v in config.items() if isinstance(config, dict) else enumerate(config): + sub_id = f"{id}#{idx}" if id != "" else f"{idx}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + updated = ReferenceResolver.update_config_with_refs(v, sub_id, refs_) + else: + updated = ReferenceResolver.update_config_with_refs(v, sub_id, refs_) + ret.update({idx: updated}) if isinstance(ret, dict) else ret.append(updated) + return ret diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py new file mode 100644 index 0000000000..a62d6befd9 --- /dev/null +++ b/tests/test_reference_resolver.py @@ -0,0 +1,120 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +import monai +from monai.apps import ConfigComponent, ReferenceResolver +from monai.apps.manifest.config_item import ComponentLocator, ConfigExpression, ConfigItem +from monai.data import DataLoader +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import optional_import + +_, has_tv = optional_import("torchvision") + +# test instance with no dependencies +TEST_CASE_1 = [ + { + # all the recursively parsed config items + "transform#1": {"": "LoadImaged", "": {"keys": ["image"]}}, + "transform#1#": "LoadImaged", + "transform#1#": {"keys": ["image"]}, + "transform#1##keys": ["image"], + "transform#1##keys#0": "image", + }, + "transform#1", + LoadImaged, +] +# test depends on other component and executable code +TEST_CASE_2 = [ + { + # some the recursively parsed config items + "dataloader": { + "": "DataLoader", + "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, + }, + "dataset": {"": "Dataset", "": {"data": [1, 2]}}, + "dataloader#": "DataLoader", + "dataloader#": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, + "dataloader##dataset": "@dataset", + "dataloader##collate_fn": "$monai.data.list_data_collate", + "dataset#": "Dataset", + "dataset#": {"data": [1, 2]}, + "dataset##data": [1, 2], + "dataset##data#0": 1, + "dataset##data#1": 2, + }, + "dataloader", + DataLoader, +] +# test config has key `name` +TEST_CASE_3 = [ + { + # all the recursively parsed config items + "transform#1": { + "": "RandTorchVisiond", + "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, + }, + "transform#1#": "RandTorchVisiond", + "transform#1#": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, + "transform#1##keys": "image", + "transform#1##name": "ColorJitter", + "transform#1##brightness": 0.25, + }, + "transform#1", + RandTorchVisiond, +] + + +class TestReferenceResolver(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) + def test_resolve(self, configs, expected_id, output_type): + locator = ComponentLocator() + resolver = ReferenceResolver() + # add items to resolver + for k, v in configs.items(): + if ConfigComponent.is_instantiable(v): + resolver.add_item(ConfigComponent(config=v, id=k, locator=locator)) + elif ConfigExpression.is_expression(v): + resolver.add_item(ConfigExpression(config=v, id=k, globals={"monai": monai, "torch": torch})) + else: + resolver.add_item(ConfigItem(config=v, id=k)) + + result = resolver.get_resolved_content(expected_id) # the root id is `expected_id` here + self.assertTrue(isinstance(result, output_type)) + + # test lazy instantiation + item = resolver.get_item(expected_id, resolve=True) + config = item.get_config() + config[""] = False + item.update_config(config=config) + if isinstance(item, ConfigComponent): + result = item.instantiate() + else: + result = item.get_config() + self.assertTrue(isinstance(result, output_type)) + + def test_circular_references(self): + locator = ComponentLocator() + resolver = ReferenceResolver() + configs = {"A": "@B", "B": "@C", "C": "@A"} + for k, v in configs.items(): + resolver.add_item(ConfigComponent(config=v, id=k, locator=locator)) + for k in ["A", "B", "C"]: + with self.assertRaises(ValueError): + resolver.get_resolved_content(k) + + +if __name__ == "__main__": + unittest.main()