From 5dc08f44bd1946e0c4fee7549b9431da153d6d9c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 26 Jan 2022 23:00:08 +0800 Subject: [PATCH 01/25] [DLMED] add ConfigComponent Signed-off-by: Nic Ma --- docs/source/apps.rst | 7 ++ monai/apps/__init__.py | 11 +- monai/apps/mmars/__init__.py | 2 + monai/apps/mmars/config_resolver.py | 166 ++++++++++++++++++++++++++++ monai/apps/mmars/utils.py | 88 +++++++++++++++ monai/utils/__init__.py | 3 + monai/utils/module.py | 104 ++++++++++++++++- tests/test_config_component.py | 142 ++++++++++++++++++++++++ 8 files changed, 521 insertions(+), 2 deletions(-) create mode 100644 monai/apps/mmars/config_resolver.py create mode 100644 monai/apps/mmars/utils.py create mode 100644 tests/test_config_component.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..8393566254 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,6 +29,13 @@ Clara MMARs :annotation: +Model Package +------------- + +.. autoclass:: ConfigComponent + :members: + + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 893f7877d2..46203fb09b 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,14 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar +from .mmars import ( + MODEL_DESC, + ConfigComponent, + RemoteMMARKeys, + download_mmar, + get_model_spec, + load_from_mmar, + search_configs_with_deps, + update_configs_with_deps, +) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 8f1448bb06..f565e4c6d9 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .config_resolver import ConfigComponent from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys +from .utils import search_configs_with_deps, update_configs_with_deps diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py new file mode 100644 index 0000000000..eb12ae78d1 --- /dev/null +++ b/monai/apps/mmars/config_resolver.py @@ -0,0 +1,166 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Any, Dict, List, Optional + +from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps +from monai.utils.module import ClassScanner, instantiate_class + + +class ConfigComponent: + """ + Utility class to manage every component in the config with a unique `id` name. + When recursively parsing a complicated config dictioanry, every item should be treated as a `ConfigComponent`. + For example: + - `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` + - `{"": "LoadImage", "": {"keys": "image"}}` + - `"": "LoadImage"` + - `"keys": "image"` + + It can search the config content and find out all the dependencies, then build the config to instance + when all the dependencies are resolved. + + Here we predefined several special marks to parse the config content: + - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. + now we have 4 keys: ``, ``, ``, ``. + - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. + - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. + - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". + + Args: + id: id name of current config component, for nested config items, use `#` to join ids. + for list component, use index from `0` as id. + for example: `transform`, `transform#5`, `transform#5##keys`, etc. + config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. + class_scanner: ClassScanner to help get the class name or path in the config and build instance. + globals: to support executable string in the config, sometimes we need to provide the global variables + which are referred in the executable string. for example: `globals={"monai": monai} will be useful + for config `"collate_fn": "$monai.data.list_data_collate"`. + + """ + + def __init__(self, id: str, config: Any, class_scanner: ClassScanner, globals: Optional[Dict] = None) -> None: + self.id = id + self.config = config + self.class_scanner = class_scanner + self.globals = globals + + def get_id(self) -> str: + """ + Get the id name of current config component. + + """ + return self.id + + def get_config(self): + """ + Get the raw config content of current config component. + + """ + return self.config + + def get_dependent_ids(self) -> List[str]: + """ + Recursively search all the content of current config compoent to get the ids of dependencies. + It's used to build all the dependencies before build current config component. + For `dict` and `list`, treat every item as a dependency. + For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: + `["", "", "#dataset", "dataset"]`. + + """ + return search_configs_with_deps(config=self.config, id=self.id) + + def get_updated_config(self, deps: dict): + """ + If all the dependencies are ready in `deps`, update the config content with them and return new config. + It can be used for lazy instantiation. + + Args: + deps: all the dependent components with ids. + + """ + return update_configs_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) + + def _check_dependency(self, config): + """ + Check whether current config still has unresolved dependencies or executable string code. + + Args: + config: config content to check. + + """ + if isinstance(config, list): + for i in config: + if self._check_dependency(i): + return True + if isinstance(config, dict): + for v in config.values(): + if self._check_dependency(v): + return True + if isinstance(config, str): + if config.startswith("&") or "@" in config: + return True + return False + + def build(self, config: Optional[Dict] = None) -> object: + """ + Build component instance based on the provided dictonary config. + Supported special keys for the config: + - '' - class name in the modules of packages. + - '' - directly specify the class path, based on PYTHONPATH, ignore '' if specified. + - '' - arguments to initialize the component instance. + - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. + + Args: + config: dictionary config that defines a component. + + Raises: + ValueError: must provide `` or `` of class to build component. + ValueError: can not find component class. + + """ + config = self.config if config is None else config + if self._check_dependency(config=config): + warnings.warn("config content has other dependencies or executable string, skip `build`.") + return config + + if ( + not isinstance(config, dict) + or ("" not in config and "" not in config) + or config.get("") is True + ): + # if marked as `disabled`, skip parsing + return config + + class_args = config.get("", {}) + class_path = self._get_class_path(config) + return instantiate_class(class_path, **class_args) + + def _get_class_path(self, config): + """ + Get the path of class specified in the config content. + + Args: + config: dictionary config that defines a component. + + """ + class_path = config.get("", None) + if class_path is None: + class_name = config.get("", None) + if class_name is None: + raise ValueError("must provide `` or `` of class to build component.") + module_name = self.class_scanner.get_class_module_name(class_name) + if module_name is None: + raise ValueError(f"can not find component class '{class_name}'.") + class_path = f"{module_name}.{class_name}" + + return class_path diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py new file mode 100644 index 0000000000..f532ff8cc1 --- /dev/null +++ b/monai/apps/mmars/utils.py @@ -0,0 +1,88 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, List, Optional, Union + + +def search_configs_with_deps(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: + """ + Recursively search all the content of input config compoent to get the ids of dependencies. + It's used to build all the dependencies before build current config component. + For `dict` and `list`, treat every item as a dependency. + For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: + `["", "", "#dataset", "dataset"]`. + + Args: + config: input config content to search. + id: id name for the input config. + deps: list of the id name of existing dependencies, default to None. + + """ + deps_: List[str] = [] if deps is None else deps + pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + if isinstance(config, list): + for i, v in enumerate(config): + sub_id = f"{id}#{i}" + # all the items in the list should be marked as dependent reference + deps_.append(sub_id) + deps_ = search_configs_with_deps(v, sub_id, deps_) + if isinstance(config, dict): + for k, v in config.items(): + sub_id = f"{id}#{k}" + # all the items in the dict should be marked as dependent reference + deps_.append(sub_id) + deps_ = search_configs_with_deps(v, sub_id, deps_) + if isinstance(config, str): + result = pattern.findall(config) + for item in result: + if config.startswith("$") or config == item: + ref_obj_id = item[1:] + if ref_obj_id not in deps_: + deps_.append(ref_obj_id) + return deps_ + + +def update_configs_with_deps(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): + """ + With all the dependencies in `deps`, update the config content with them and return new config. + It can be used for lazy instantiation. + + Args: + config: input config content to update. + deps: all the dependent components with ids. + id: id name for the input config. + globals: predefined global variables to execute code string with `eval()`. + + """ + pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + if isinstance(config, list): + # all the items in the list should be replaced with the reference + return [deps[f"{id}#{i}"] for i in range(len(config))] + if isinstance(config, dict): + # all the items in the dict should be replaced with the reference + return {k: deps[f"{id}#{k}"] for k, _ in config.items()} + if isinstance(config, str): + result = pattern.findall(config) + config_: str = config + for item in result: + ref_obj_id = item[1:] + if config_.startswith("$"): + # replace with local code and execute soon + config_ = config_.replace(item, f"deps['{ref_obj_id}']") + elif config_ == item: + config_ = deps[ref_obj_id] + + if isinstance(config_, str): + if config_.startswith("$"): + config_ = eval(config_[1:], globals, {"deps": deps}) + return config_ + return config diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index dec9bdea65..a1e9390c30 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -61,14 +61,17 @@ zip_with, ) from .module import ( + ClassScanner, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, exact_version, export, + get_class, get_full_type_name, get_package_version, get_torch_version_tuple, + instantiate_class, load_submodules, look_up_option, min_version, diff --git a/monai/utils/module.py b/monai/utils/module.py index c0fc10a7c0..2994b80421 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,6 +10,7 @@ # limitations under the License. import enum +import inspect import os import re import sys @@ -19,7 +20,7 @@ from pkgutil import walk_packages from re import match from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Sequence, Tuple, cast import torch @@ -36,6 +37,9 @@ "optional_import", "require_pkg", "load_submodules", + "ClassScanner", + "get_class", + "instantiate_class", "get_full_type_name", "get_package_version", "get_torch_version_tuple", @@ -193,7 +197,105 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod +class ClassScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan modules and parse class names in the config. + modules: the expected modules in the packages to scan for all the classes. + for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = import_module(pkg) + + for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): + # if no modules specified, load all modules in the package + if len(self.modules) == 0 or any(name in modname for name in self.modules): + try: + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_class_module_name(self, class_name): + """ + Get the module name of the class with specified class name. + + Args: + class_name: name of the expected class. + + """ + return self._class_table.get(class_name, None) + + +def get_class(class_path: str): + """ + Get the class from specified class path. + + Args: + class_path (str): full path of the class. + + Raises: + ValueError: invalid class_path, missing the module name. + ValueError: class does not exist. + ValueError: module does not exist. + + """ + if len(class_path.split(".")) < 2: + raise ValueError(f"invalid class_path: {class_path}, missing the module name.") + module_name, class_name = class_path.rsplit(".", 1) + + try: + module_ = import_module(module_name) + + try: + class_ = getattr(module_, class_name) + except AttributeError as e: + raise ValueError(f"class {class_name} does not exist.") from e + + except AttributeError as e: + raise ValueError(f"module {module_name} does not exist.") from e + + return class_ + + +def instantiate_class(class_path: str, **kwargs): + """ + Method for creating an instance for the specified class. + + Args: + class_path: full path of the class. + kwargs: arguments to initialize the class instance. + + Raises: + ValueError: class has paramenters error. + """ + + try: + return get_class(class_path)(**kwargs) + except TypeError as e: + raise ValueError(f"class {class_path} has parameters error.") from e + + def get_full_type_name(typeobj): + """ + Utility to get the full path name of a class or object type. + + """ module = typeobj.__module__ if module is None or module == str.__class__.__module__: return typeobj.__name__ # Avoid reporting __builtin__ diff --git a/tests/test_config_component.py b/tests/test_config_component.py new file mode 100644 index 0000000000..d48a4e274f --- /dev/null +++ b/tests/test_config_component.py @@ -0,0 +1,142 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Callable, Iterator + +import torch +from parameterized import parameterized + +import monai +from monai.apps import ConfigComponent +from monai.data import DataLoader, Dataset +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import ClassScanner, optional_import + +_, has_tv = optional_import("torchvision") + +TEST_CASE_1 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": {"keys": ["image"]}}, + LoadImaged, +] +# test python `` +TEST_CASE_2 = [ + dict(pkgs=[], modules=[]), + {"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, + LoadImaged, +] +# test `` +TEST_CASE_3 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": True, "": {"keys": ["image"]}}, + dict, +] +# test unresolved dependency +TEST_CASE_4 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": {"keys": ["@key_name"]}}, + dict, +] +# test non-monai modules +TEST_CASE_5 = [ + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +# test args contains "name" field +TEST_CASE_6 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + RandTorchVisiond, +] +# test dependencies of dict config +TEST_CASE_7 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] +# test dependencies of list config +TEST_CASE_8 = [ + {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, + ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], +] +# test dependencies of execute code +TEST_CASE_9 = [ + {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, + ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], +] + +# test dependencies of lambda function +TEST_CASE_10 = [ + {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, + ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], +] +# test instance with no dependencies +TEST_CASE_11 = [ + "transform#1", + {"": "LoadImaged", "": {"keys": ["image"]}}, + {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, + LoadImaged, +] +# test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` +TEST_CASE_12 = [ + "dataloader", + {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, + {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, + DataLoader, +] +# test dependencies in code execution +TEST_CASE_13 = [ + "optimizer", + {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, + {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +# test replace dependencies with code execution result +TEST_CASE_14 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] +# test execute some function in args, test pre-imported global packages `monai` +TEST_CASE_15 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] +# test lambda function, should not execute the lambda function, just change the string with dependent objects +TEST_CASE_16 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] + + +class TestConfigComponent(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] + ([TEST_CASE_6] if has_tv else []) + ) + def test_build(self, input_param, test_input, output_type): + scanner = ClassScanner(**input_param) + configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) + ret = configer.build() + self.assertTrue(isinstance(ret, output_type)) + if isinstance(ret, LoadImaged): + self.assertEqual(ret.keys[0], "image") + if isinstance(ret, dict): + # test `` works fine + self.assertDictEqual(ret, test_input) + + @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) + def test_dependent_ids(self, test_input, ref_ids): + scanner = ClassScanner(pkgs=[], modules=[]) + configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) + ret = configer.get_dependent_ids() + self.assertListEqual(ret, ref_ids) + + @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) + def test_update_dependencies(self, id, test_input, deps, output_type): + scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + configer = ConfigComponent( + id=id, config=test_input, class_scanner=scanner, globals={"monai": monai, "torch": torch} + ) + config = configer.get_updated_config(deps) + ret = configer.build(config) + self.assertTrue(isinstance(ret, output_type)) + + +if __name__ == "__main__": + unittest.main() From 7bf23095e351bc1f8b7a00553109ad6ad6701348 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Feb 2022 17:40:37 +0800 Subject: [PATCH 02/25] [DLMED] totally update according to comments Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 68 ++++++++------- monai/utils/__init__.py | 6 +- monai/utils/module.py | 130 ++++++++++++++++------------ tests/test_config_component.py | 56 ++++++------ 4 files changed, 146 insertions(+), 114 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index eb12ae78d1..60ec3c6a7a 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps -from monai.utils.module import ClassScanner, instantiate_class +from monai.utils.module import ComponentScanner, instantiate class ConfigComponent: @@ -29,29 +29,30 @@ class ConfigComponent: It can search the config content and find out all the dependencies, then build the config to instance when all the dependencies are resolved. - Here we predefined several special marks to parse the config content: - - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. - now we have 4 keys: ``, ``, ``, ``. + Here we predefined 4 kinds special marks (`<>`, `#`, `@`, `$`) to parse the config content: + - "": like "" is the name of a target component, to distinguish it with regular key "name" + in the config content. now we have 4 keys: ``, ``, ``, ``. - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. - - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. + - "@XXX": use an component as config item, like `"input_data": "@dataset"` uses `dataset` instance as parameter. - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". Args: id: id name of current config component, for nested config items, use `#` to join ids. for list component, use index from `0` as id. for example: `transform`, `transform#5`, `transform#5##keys`, etc. + the id can be useful to quickly get the expected item in a complicated and nested config content. config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - class_scanner: ClassScanner to help get the class name or path in the config and build instance. + scanner: ComponentScanner to help get the `` or `` in the config and build instance. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful for config `"collate_fn": "$monai.data.list_data_collate"`. """ - def __init__(self, id: str, config: Any, class_scanner: ClassScanner, globals: Optional[Dict] = None) -> None: + def __init__(self, id: str, config: Any, scanner: ComponentScanner, globals: Optional[Dict] = None) -> None: self.id = id self.config = config - self.class_scanner = class_scanner + self.scanner = scanner self.globals = globals def get_id(self) -> str: @@ -111,21 +112,23 @@ def _check_dependency(self, config): return True return False - def build(self, config: Optional[Dict] = None) -> object: + def build(self, config: Optional[Dict] = None, **kwargs) -> object: """ Build component instance based on the provided dictonary config. + The target component must be a class and a function. Supported special keys for the config: - - '' - class name in the modules of packages. - - '' - directly specify the class path, based on PYTHONPATH, ignore '' if specified. + - '' - class / function name in the modules of packages. + - '' - directly specify the path, based on PYTHONPATH, ignore '' if specified. - '' - arguments to initialize the component instance. - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. Args: config: dictionary config that defines a component. + kwargs: args to override / add the config args when building. Raises: - ValueError: must provide `` or `` of class to build component. - ValueError: can not find component class. + ValueError: must provide `` or `` of class / function to build component. + ValueError: can not find component class or function. """ config = self.config if config is None else config @@ -141,26 +144,33 @@ def build(self, config: Optional[Dict] = None) -> object: # if marked as `disabled`, skip parsing return config - class_args = config.get("", {}) - class_path = self._get_class_path(config) - return instantiate_class(class_path, **class_args) + args = config.get("", {}) + args.update(kwargs) + path = self._get_path(config) + return instantiate(path, **args) - def _get_class_path(self, config): + def _get_path(self, config): """ - Get the path of class specified in the config content. + Get the path of class / function specified in the config content. Args: config: dictionary config that defines a component. """ - class_path = config.get("", None) - if class_path is None: - class_name = config.get("", None) - if class_name is None: - raise ValueError("must provide `` or `` of class to build component.") - module_name = self.class_scanner.get_class_module_name(class_name) - if module_name is None: - raise ValueError(f"can not find component class '{class_name}'.") - class_path = f"{module_name}.{class_name}" - - return class_path + path = config.get("", None) + if path is None: + name = config.get("", None) + if name is None: + raise ValueError("must provide `` or `` of target component to build.") + module = self.scanner.get_component_module_name(name) + if module is None: + raise ValueError(f"can not find component '{name}'.") + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set the full python path in `` directly." + ) + module = module[0] + path = f"{module}.{name}" + + return path diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 8ae8a31fb1..87228f18dd 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -62,18 +62,18 @@ zip_with, ) from .module import ( - ClassScanner, + ComponentScanner, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, exact_version, export, - get_class, get_full_type_name, get_package_version, get_torch_version_tuple, - instantiate_class, + instantiate, load_submodules, + locate, look_up_option, min_version, optional_import, diff --git a/monai/utils/module.py b/monai/utils/module.py index 1901a58430..09eda4b16b 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -15,12 +15,12 @@ import re import sys import warnings -from functools import wraps +from functools import partial, wraps from importlib import import_module from pkgutil import walk_packages from re import match -from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Sequence, Tuple, Union, cast +from types import FunctionType, ModuleType +from typing import Any, Callable, Collection, Dict, Hashable, Iterable, List, Mapping, Sequence, Tuple, Union, cast import torch @@ -37,9 +37,9 @@ "optional_import", "require_pkg", "load_submodules", - "ClassScanner", - "get_class", - "instantiate_class", + "ComponentScanner", + "locate", + "instantiate", "get_full_type_name", "get_package_version", "get_torch_version_tuple", @@ -197,14 +197,14 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod -class ClassScanner: +class ComponentScanner: """ - Scan all the available classes in the specified packages and modules. - Map the all the class names and the module names in a table. + Scan all the available classes and functions in the specified packages and modules. + Map the all the names and the module names in a table. Args: - pkgs: the expected packages to scan modules and parse class names in the config. - modules: the expected modules in the packages to scan for all the classes. + pkgs: the expected packages to scan modules and parse component names in the config. + modules: the expected modules in the packages to scan for all the components. for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. """ @@ -212,10 +212,10 @@ class ClassScanner: def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): self.pkgs = pkgs self.modules = modules - self._class_table = self._create_classes_table() + self._components_table = self._create_table() - def _create_classes_table(self): - class_table = {} + def _create_table(self): + table: Dict[str, List] = {} for pkg in self.pkgs: package = import_module(pkg) @@ -225,70 +225,86 @@ def _create_classes_table(self): try: module = import_module(modname) for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and obj.__module__ == modname: - class_table[name] = modname + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) except ModuleNotFoundError: pass - return class_table + return table - def get_class_module_name(self, class_name): + def get_component_module_name(self, name): """ - Get the module name of the class with specified class name. + Get the full module name of the class / function with specified name. + If target component name exists in multiple packages or modules, return all the paths.z Args: - class_name: name of the expected class. + name: name of the expected class or function. """ - return self._class_table.get(class_name, None) + mods = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods -def get_class(class_path: str): +def locate(path: str): """ - Get the class from specified class path. + Locate an object by name or dotted path, importing as necessary. + Refer to Hydra: https://github.com/facebookresearch/hydra/blob/v1.1.1/hydra/_internal/utils.py#L554. Args: - class_path (str): full path of the class. - - Raises: - ValueError: invalid class_path, missing the module name. - ValueError: class does not exist. - ValueError: module does not exist. + path: full path of the expected component to locate and import. """ - if len(class_path.split(".")) < 2: - raise ValueError(f"invalid class_path: {class_path}, missing the module name.") - module_name, class_name = class_path.rsplit(".", 1) - - try: - module_ = import_module(module_name) - + parts = [part for part in path.split(".") if part] + for n in reversed(range(1, len(parts) + 1)): try: - class_ = getattr(module_, class_name) - except AttributeError as e: - raise ValueError(f"class {class_name} does not exist.") from e - - except AttributeError as e: - raise ValueError(f"module {module_name} does not exist.") from e - - return class_ - - -def instantiate_class(class_path: str, **kwargs): - """ - Method for creating an instance for the specified class. + obj = import_module(".".join(parts[:n])) + except Exception as exc_import: + if n == 1: + raise ImportError(f"can not import module '{path}'.") from exc_import + continue + break + + for m in range(n, len(parts)): + part = parts[m] + try: + obj = getattr(obj, part) + except AttributeError as exc_attr: + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + import_module(mod) + except ModuleNotFoundError: + pass + except Exception as exc_import: + raise ImportError(f"can not import module '{path}'.") from exc_import + raise ImportError(f"AttributeError while loading '{path}': {exc_attr}") from exc_attr + return obj + + +def instantiate(path: str, **kwargs): + """ + Method for creating an instance for the specified class / function path. + kwargs will be class args or default args for `partial` function. + The target component must be a class or a function. Args: - class_path: full path of the class. - kwargs: arguments to initialize the class instance. + path: full path of the target class or function component. + kwargs: arguments to initialize the class instance or set default args + for `partial` function. - Raises: - ValueError: class has paramenters error. """ - try: - return get_class(class_path)(**kwargs) - except TypeError as e: - raise ValueError(f"class {class_path} has parameters error.") from e + component = locate(path) + if inspect.isclass(component): + return component(**kwargs) + if inspect.isfunction(component): + return partial(component, **kwargs) + + warnings.warn(f"target component must be a valid class or function, but got {path}.") + return component def get_full_type_name(typeobj): diff --git a/tests/test_config_component.py b/tests/test_config_component.py index d48a4e274f..89acfdaebb 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -10,16 +10,18 @@ # limitations under the License. import unittest +from functools import partial from typing import Callable, Iterator import torch from parameterized import parameterized +from torch.optim._multi_tensor import Adam import monai from monai.apps import ConfigComponent from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import ClassScanner, optional_import +from monai.utils import ComponentScanner, optional_import _, has_tv = optional_import("torchvision") @@ -50,68 +52,74 @@ TEST_CASE_5 = [ dict(pkgs=["torch.optim", "monai"], modules=["adam"]), {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, + Adam, ] -# test args contains "name" field TEST_CASE_6 = [ + dict(pkgs=["monai"], modules=["data"]), + {"": "decollate_batch", "": {"detach": True, "pad": True}}, + partial, +] +# test args contains "name" field +TEST_CASE_7 = [ dict(pkgs=["monai"], modules=["transforms"]), {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] # test dependencies of dict config -TEST_CASE_7 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] +TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] # test dependencies of list config -TEST_CASE_8 = [ +TEST_CASE_9 = [ {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], ] # test dependencies of execute code -TEST_CASE_9 = [ +TEST_CASE_10 = [ {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], ] # test dependencies of lambda function -TEST_CASE_10 = [ +TEST_CASE_11 = [ {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], ] # test instance with no dependencies -TEST_CASE_11 = [ +TEST_CASE_12 = [ "transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, LoadImaged, ] # test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` -TEST_CASE_12 = [ +TEST_CASE_13 = [ "dataloader", {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, DataLoader, ] # test dependencies in code execution -TEST_CASE_13 = [ +TEST_CASE_14 = [ "optimizer", {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, + Adam, ] # test replace dependencies with code execution result -TEST_CASE_14 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] +TEST_CASE_15 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_15 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] +TEST_CASE_16 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] # test lambda function, should not execute the lambda function, just change the string with dependent objects -TEST_CASE_16 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] +TEST_CASE_17 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] class TestConfigComponent(unittest.TestCase): @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] + ([TEST_CASE_6] if has_tv else []) + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + + ([TEST_CASE_7] if has_tv else []) ) def test_build(self, input_param, test_input, output_type): - scanner = ClassScanner(**input_param) - configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) + scanner = ComponentScanner(**input_param) + configer = ConfigComponent(id="test", config=test_input, scanner=scanner) ret = configer.build() self.assertTrue(isinstance(ret, output_type)) if isinstance(ret, LoadImaged): @@ -120,19 +128,17 @@ def test_build(self, input_param, test_input, output_type): # test `` works fine self.assertDictEqual(ret, test_input) - @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) + @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) def test_dependent_ids(self, test_input, ref_ids): - scanner = ClassScanner(pkgs=[], modules=[]) - configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) + scanner = ComponentScanner(pkgs=[], modules=[]) + configer = ConfigComponent(id="test", config=test_input, scanner=scanner) ret = configer.get_dependent_ids() self.assertListEqual(ret, ref_ids) - @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) + @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) - configer = ConfigComponent( - id=id, config=test_input, class_scanner=scanner, globals={"monai": monai, "torch": torch} - ) + scanner = ComponentScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + configer = ConfigComponent(id=id, config=test_input, scanner=scanner, globals={"monai": monai, "torch": torch}) config = configer.get_updated_config(deps) ret = configer.build(config) self.assertTrue(isinstance(ret, output_type)) From 598585f1bafb3c8cb983c90d61aec979820cbdf3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Feb 2022 18:51:06 +0800 Subject: [PATCH 03/25] [DLMED] add excludes Signed-off-by: Nic Ma --- monai/utils/module.py | 28 ++++++++++++++++++++++++---- tests/test_config_component.py | 13 +++++++------ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 09eda4b16b..383f24d189 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -20,7 +20,21 @@ from pkgutil import walk_packages from re import match from types import FunctionType, ModuleType -from typing import Any, Callable, Collection, Dict, Hashable, Iterable, List, Mapping, Sequence, Tuple, Union, cast +from typing import ( + Any, + Callable, + Collection, + Dict, + Hashable, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) import torch @@ -206,12 +220,16 @@ class ComponentScanner: pkgs: the expected packages to scan modules and parse component names in the config. modules: the expected modules in the packages to scan for all the components. for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + excludes: if any string of the `excludes` exists in the full module name, don't import this module. """ - def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + def __init__( + self, pkgs: Sequence[str], modules: Optional[Sequence[str]] = None, excludes: Optional[Sequence[str]] = None + ): self.pkgs = pkgs - self.modules = modules + self.modules = [] if modules is None else modules + self.excludes = [] if excludes is None else excludes self._components_table = self._create_table() def _create_table(self): @@ -221,7 +239,9 @@ def _create_table(self): for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): # if no modules specified, load all modules in the package - if len(self.modules) == 0 or any(name in modname for name in self.modules): + if all(s not in modname for s in self.excludes) and ( + len(self.modules) == 0 or any(name in modname for name in self.modules) + ): try: module = import_module(modname) for name, obj in inspect.getmembers(module): diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 89acfdaebb..3f801bf98d 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -15,7 +15,6 @@ import torch from parameterized import parameterized -from torch.optim._multi_tensor import Adam import monai from monai.apps import ConfigComponent @@ -48,11 +47,11 @@ {"": "LoadImaged", "": {"keys": ["@key_name"]}}, dict, ] -# test non-monai modules +# test non-monai modules and excludes TEST_CASE_5 = [ - dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + dict(pkgs=["torch.optim", "monai"], modules=["adam"], excludes=["_multi_tensor"]), {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - Adam, + torch.optim.Adam, ] TEST_CASE_6 = [ dict(pkgs=["monai"], modules=["data"]), @@ -102,7 +101,7 @@ "optimizer", {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - Adam, + torch.optim.Adam, ] # test replace dependencies with code execution result TEST_CASE_15 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] @@ -137,7 +136,9 @@ def test_dependent_ids(self, test_input, ref_ids): @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ComponentScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + scanner = ComponentScanner( + pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"], excludes=["_multi_tensor"] + ) configer = ConfigComponent(id=id, config=test_input, scanner=scanner, globals={"monai": monai, "torch": torch}) config = configer.get_updated_config(deps) ret = configer.build(config) From 7d13b474e94d33e37cef48f142ca2101a60a41cb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Feb 2022 21:18:19 +0800 Subject: [PATCH 04/25] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/utils/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 383f24d189..8bd80bfe38 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -277,15 +277,15 @@ def locate(path: str): path: full path of the expected component to locate and import. """ - parts = [part for part in path.split(".") if part] - for n in reversed(range(1, len(parts) + 1)): + parts = [p for p in path.split(".") if p] + for n in range(len(parts), 0, -1): try: obj = import_module(".".join(parts[:n])) + break except Exception as exc_import: if n == 1: raise ImportError(f"can not import module '{path}'.") from exc_import continue - break for m in range(n, len(parts)): part = parts[m] From cce6ff64719d2a52c67f4bd705c0dd7d1672f4a7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Feb 2022 23:34:24 +0800 Subject: [PATCH 05/25] [DLMED] update ComponentScanner Signed-off-by: Nic Ma --- monai/apps/__init__.py | 1 + monai/apps/mmars/__init__.py | 2 +- monai/apps/mmars/config_resolver.py | 69 +++++++++++++++++++++++++- monai/utils/__init__.py | 1 - monai/utils/module.py | 76 +---------------------------- tests/test_config_component.py | 16 +++--- 6 files changed, 78 insertions(+), 87 deletions(-) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 46203fb09b..8348a4b7f7 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -12,6 +12,7 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import ( MODEL_DESC, + ComponentScanner, ConfigComponent, RemoteMMARKeys, download_mmar, diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index f565e4c6d9..e26f93ddda 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_resolver import ConfigComponent +from .config_resolver import ComponentScanner, ConfigComponent from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys from .utils import search_configs_with_deps, update_configs_with_deps diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 60ec3c6a7a..b35017e78d 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,11 +9,76 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import warnings -from typing import Any, Dict, List, Optional +from importlib import import_module +from pkgutil import walk_packages +from typing import Any, Dict, List, Optional, Sequence from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps -from monai.utils.module import ComponentScanner, instantiate +from monai.utils.module import instantiate + +__all__ = ["ComponentScanner", "ConfigComponent"] + + +class ComponentScanner: + """ + Scan all the available classes and functions in the specified packages and modules. + Map the all the names and the module names in a table. + + Args: + pkgs: the expected packages to scan modules and parse component names in the config. + modules: the expected modules in the packages to scan for all the components. + for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + + """ + + def __init__( + self, pkgs: Sequence[str], modules: Optional[Sequence[str]] = None, excludes: Optional[Sequence[str]] = None + ): + for p in pkgs: + if not p.startswith("monai"): + raise ValueError("only support to scan MONAI package so far.") + self.pkgs = pkgs + self.modules = [] if modules is None else modules + self.excludes = [] if excludes is None else excludes + self._components_table = self._create_table() + + def _create_table(self): + table: Dict[str, List] = {} + for pkg in self.pkgs: + package = import_module(pkg) + + for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): + # if no modules specified, load all modules in the package + if all(s not in modname for s in self.excludes) and ( + len(self.modules) == 0 or any(name in modname for name in self.modules) + ): + try: + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass + return table + + def get_component_module_name(self, name): + """ + Get the full module name of the class / function with specified name. + If target component name exists in multiple packages or modules, return all the paths.z + + Args: + name: name of the expected class or function. + + """ + mods = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods class ConfigComponent: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 87228f18dd..e5a9a34bb4 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -62,7 +62,6 @@ zip_with, ) from .module import ( - ComponentScanner, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, diff --git a/monai/utils/module.py b/monai/utils/module.py index 8bd80bfe38..adb24cb91b 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -20,21 +20,7 @@ from pkgutil import walk_packages from re import match from types import FunctionType, ModuleType -from typing import ( - Any, - Callable, - Collection, - Dict, - Hashable, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast import torch @@ -51,7 +37,6 @@ "optional_import", "require_pkg", "load_submodules", - "ComponentScanner", "locate", "instantiate", "get_full_type_name", @@ -211,63 +196,6 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod -class ComponentScanner: - """ - Scan all the available classes and functions in the specified packages and modules. - Map the all the names and the module names in a table. - - Args: - pkgs: the expected packages to scan modules and parse component names in the config. - modules: the expected modules in the packages to scan for all the components. - for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. - excludes: if any string of the `excludes` exists in the full module name, don't import this module. - - """ - - def __init__( - self, pkgs: Sequence[str], modules: Optional[Sequence[str]] = None, excludes: Optional[Sequence[str]] = None - ): - self.pkgs = pkgs - self.modules = [] if modules is None else modules - self.excludes = [] if excludes is None else excludes - self._components_table = self._create_table() - - def _create_table(self): - table: Dict[str, List] = {} - for pkg in self.pkgs: - package = import_module(pkg) - - for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): - # if no modules specified, load all modules in the package - if all(s not in modname for s in self.excludes) and ( - len(self.modules) == 0 or any(name in modname for name in self.modules) - ): - try: - module = import_module(modname) - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: - if name not in table: - table[name] = [] - table[name].append(modname) - except ModuleNotFoundError: - pass - return table - - def get_component_module_name(self, name): - """ - Get the full module name of the class / function with specified name. - If target component name exists in multiple packages or modules, return all the paths.z - - Args: - name: name of the expected class or function. - - """ - mods = self._components_table.get(name, None) - if isinstance(mods, list) and len(mods) == 1: - mods = mods[0] - return mods - - def locate(path: str): """ Locate an object by name or dotted path, importing as necessary. @@ -308,7 +236,7 @@ def instantiate(path: str, **kwargs): """ Method for creating an instance for the specified class / function path. kwargs will be class args or default args for `partial` function. - The target component must be a class or a function. + The target component must be a class or a function, if not, return the component directly. Args: path: full path of the target class or function component. diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 3f801bf98d..e628509849 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -17,10 +17,10 @@ from parameterized import parameterized import monai -from monai.apps import ConfigComponent +from monai.apps import ComponentScanner, ConfigComponent from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import ComponentScanner, optional_import +from monai.utils import optional_import _, has_tv = optional_import("torchvision") @@ -49,8 +49,8 @@ ] # test non-monai modules and excludes TEST_CASE_5 = [ - dict(pkgs=["torch.optim", "monai"], modules=["adam"], excludes=["_multi_tensor"]), - {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + dict(pkgs=["monai"], modules=["data"], excludes=["utils"]), + {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] TEST_CASE_6 = [ @@ -99,8 +99,8 @@ # test dependencies in code execution TEST_CASE_14 = [ "optimizer", - {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, - {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + {"": "torch.optim.Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, + {"optimizer#": "torch.optim.Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] # test replace dependencies with code execution result @@ -136,9 +136,7 @@ def test_dependent_ids(self, test_input, ref_ids): @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ComponentScanner( - pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"], excludes=["_multi_tensor"] - ) + scanner = ComponentScanner(pkgs=["monai"], modules=["data", "transforms", "data"], excludes=["utils"]) configer = ConfigComponent(id=id, config=test_input, scanner=scanner, globals={"monai": monai, "torch": torch}) config = configer.get_updated_config(deps) ret = configer.build(config) From 2659b2ac1b11cc20d1c82605fb1be538222b55c6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 12 Feb 2022 00:42:48 +0800 Subject: [PATCH 06/25] [DLMED] enhance doc Signed-off-by: Nic Ma --- docs/source/apps.rst | 3 +++ monai/apps/mmars/config_resolver.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 8393566254..fe3be6a7fc 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -35,6 +35,9 @@ Model Package .. autoclass:: ConfigComponent :members: +.. autoclass:: ComponentScanner + :members: + `Utilities` ----------- diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index b35017e78d..463c9aa45f 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -148,7 +148,7 @@ def get_dependent_ids(self) -> List[str]: def get_updated_config(self, deps: dict): """ If all the dependencies are ready in `deps`, update the config content with them and return new config. - It can be used for lazy instantiation. + It can be used for lazy instantiation, the returned config has no dependencies, can be built immediately. Args: deps: all the dependent components with ids. @@ -198,7 +198,7 @@ def build(self, config: Optional[Dict] = None, **kwargs) -> object: """ config = self.config if config is None else config if self._check_dependency(config=config): - warnings.warn("config content has other dependencies or executable string, skip `build`.") + warnings.warn("config content still has other dependencies or executable strings to run, skip `build`.") return config if ( From 6eb45ffcf30cdd80fbca661460fc02c18372d562 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 12 Feb 2022 07:59:54 +0800 Subject: [PATCH 07/25] [DLMED] use load_submodules Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 46 ++++++++++------------------- tests/test_config_component.py | 40 ++++++------------------- 2 files changed, 25 insertions(+), 61 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 463c9aa45f..059a65c66b 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -10,9 +10,9 @@ # limitations under the License. import inspect +import sys import warnings from importlib import import_module -from pkgutil import walk_packages from typing import Any, Dict, List, Optional, Sequence from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps @@ -23,47 +23,33 @@ class ComponentScanner: """ - Scan all the available classes and functions in the specified packages and modules. + Scan all the available classes and functions in the MONAI package. Map the all the names and the module names in a table. Args: - pkgs: the expected packages to scan modules and parse component names in the config. - modules: the expected modules in the packages to scan for all the components. - for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. excludes: if any string of the `excludes` exists in the full module name, don't import this module. """ - def __init__( - self, pkgs: Sequence[str], modules: Optional[Sequence[str]] = None, excludes: Optional[Sequence[str]] = None - ): - for p in pkgs: - if not p.startswith("monai"): - raise ValueError("only support to scan MONAI package so far.") - self.pkgs = pkgs - self.modules = [] if modules is None else modules + def __init__(self, excludes: Optional[Sequence[str]] = None): self.excludes = [] if excludes is None else excludes self._components_table = self._create_table() def _create_table(self): table: Dict[str, List] = {} - for pkg in self.pkgs: - package = import_module(pkg) - - for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): - # if no modules specified, load all modules in the package - if all(s not in modname for s in self.excludes) and ( - len(self.modules) == 0 or any(name in modname for name in self.modules) - ): - try: - module = import_module(modname) - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: - if name not in table: - table[name] = [] - table[name].append(modname) - except ModuleNotFoundError: - pass + # all the MONAI modules are already loaded by `load_submodules` + modnames = [m for m in sys.modules.keys() if m.startswith("monai") and all(s not in m for s in self.excludes)] + for modname in modnames: + try: + # scan all the classes and functions in the module + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass return table def get_component_module_name(self, name): diff --git a/tests/test_config_component.py b/tests/test_config_component.py index e628509849..a3445e78d3 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -24,43 +24,21 @@ _, has_tv = optional_import("torchvision") -TEST_CASE_1 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": {"keys": ["image"]}}, - LoadImaged, -] +TEST_CASE_1 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test python `` -TEST_CASE_2 = [ - dict(pkgs=[], modules=[]), - {"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, - LoadImaged, -] +TEST_CASE_2 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test `` -TEST_CASE_3 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": True, "": {"keys": ["image"]}}, - dict, -] +TEST_CASE_3 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] # test unresolved dependency -TEST_CASE_4 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": {"keys": ["@key_name"]}}, - dict, -] +TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}, dict] # test non-monai modules and excludes TEST_CASE_5 = [ - dict(pkgs=["monai"], modules=["data"], excludes=["utils"]), {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] -TEST_CASE_6 = [ - dict(pkgs=["monai"], modules=["data"]), - {"": "decollate_batch", "": {"detach": True, "pad": True}}, - partial, -] +TEST_CASE_6 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] # test args contains "name" field TEST_CASE_7 = [ - dict(pkgs=["monai"], modules=["transforms"]), {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] @@ -116,8 +94,8 @@ class TestConfigComponent(unittest.TestCase): [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) ) - def test_build(self, input_param, test_input, output_type): - scanner = ComponentScanner(**input_param) + def test_build(self, test_input, output_type): + scanner = ComponentScanner(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, scanner=scanner) ret = configer.build() self.assertTrue(isinstance(ret, output_type)) @@ -129,14 +107,14 @@ def test_build(self, input_param, test_input, output_type): @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) def test_dependent_ids(self, test_input, ref_ids): - scanner = ComponentScanner(pkgs=[], modules=[]) + scanner = ComponentScanner() configer = ConfigComponent(id="test", config=test_input, scanner=scanner) ret = configer.get_dependent_ids() self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ComponentScanner(pkgs=["monai"], modules=["data", "transforms", "data"], excludes=["utils"]) + scanner = ComponentScanner(excludes=["utils"]) configer = ConfigComponent(id=id, config=test_input, scanner=scanner, globals={"monai": monai, "torch": torch}) config = configer.get_updated_config(deps) ret = configer.build(config) From 96617a81d1f0cf7cc2c40345daac8d8ecd650bbe Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 12 Feb 2022 17:56:18 +0800 Subject: [PATCH 08/25] [DLMED] remove locate Signed-off-by: Nic Ma --- monai/utils/__init__.py | 1 - monai/utils/module.py | 40 ++-------------------------------------- 2 files changed, 2 insertions(+), 39 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index e5a9a34bb4..636ea15c8d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -72,7 +72,6 @@ get_torch_version_tuple, instantiate, load_submodules, - locate, look_up_option, min_version, optional_import, diff --git a/monai/utils/module.py b/monai/utils/module.py index adb24cb91b..1dcbe6849f 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -18,8 +18,9 @@ from functools import partial, wraps from importlib import import_module from pkgutil import walk_packages +from pydoc import locate from re import match -from types import FunctionType, ModuleType +from types import FunctionType from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast import torch @@ -37,7 +38,6 @@ "optional_import", "require_pkg", "load_submodules", - "locate", "instantiate", "get_full_type_name", "get_package_version", @@ -196,42 +196,6 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod -def locate(path: str): - """ - Locate an object by name or dotted path, importing as necessary. - Refer to Hydra: https://github.com/facebookresearch/hydra/blob/v1.1.1/hydra/_internal/utils.py#L554. - - Args: - path: full path of the expected component to locate and import. - - """ - parts = [p for p in path.split(".") if p] - for n in range(len(parts), 0, -1): - try: - obj = import_module(".".join(parts[:n])) - break - except Exception as exc_import: - if n == 1: - raise ImportError(f"can not import module '{path}'.") from exc_import - continue - - for m in range(n, len(parts)): - part = parts[m] - try: - obj = getattr(obj, part) - except AttributeError as exc_attr: - if isinstance(obj, ModuleType): - mod = ".".join(parts[: m + 1]) - try: - import_module(mod) - except ModuleNotFoundError: - pass - except Exception as exc_import: - raise ImportError(f"can not import module '{path}'.") from exc_import - raise ImportError(f"AttributeError while loading '{path}': {exc_attr}") from exc_attr - return obj - - def instantiate(path: str, **kwargs): """ Method for creating an instance for the specified class / function path. From 8f976250f109ca43af211effa55fd6ab07436847 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 12 Feb 2022 23:04:15 +0800 Subject: [PATCH 09/25] [DLMED] add test to ensure all components support `locate` Signed-off-by: Nic Ma --- tests/test_component_scanner.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_component_scanner.py diff --git a/tests/test_component_scanner.py b/tests/test_component_scanner.py new file mode 100644 index 0000000000..6c1eda6056 --- /dev/null +++ b/tests/test_component_scanner.py @@ -0,0 +1,30 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from pydoc import locate + +from monai.apps.mmars import ComponentScanner + + +class TestComponentScanner(unittest.TestCase): + def test_locate(self): + scanner = ComponentScanner(excludes=None) + self.assertGreater(len(scanner._components_table), 0) + for _, mods in scanner._components_table.items(): + for i in mods: + self.assertGreater(len(mods), 0) + # ensure we can locate all the items by `name` + self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") + + +if __name__ == "__main__": + unittest.main() From 17e2493f963d56c60c751158e32187fbb38fe1d0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 12 Feb 2022 23:20:16 +0800 Subject: [PATCH 10/25] [DLMED] fix min_tests Signed-off-by: Nic Ma --- tests/test_component_scanner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_component_scanner.py b/tests/test_component_scanner.py index 6c1eda6056..16f13f3ce5 100644 --- a/tests/test_component_scanner.py +++ b/tests/test_component_scanner.py @@ -13,11 +13,14 @@ from pydoc import locate from monai.apps.mmars import ComponentScanner +from monai.utils import optional_import + +_, has_ignite = optional_import("ignite") class TestComponentScanner(unittest.TestCase): def test_locate(self): - scanner = ComponentScanner(excludes=None) + scanner = ComponentScanner(excludes=None if has_ignite else ["monai.handlers"]) self.assertGreater(len(scanner._components_table), 0) for _, mods in scanner._components_table.items(): for i in mods: From 24cd6c3cfef5b69d989181d9d6da55a455427553 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Feb 2022 12:35:34 +0800 Subject: [PATCH 11/25] [DLMED] update according to comments Signed-off-by: Nic Ma --- docs/source/apps.rst | 4 +- monai/apps/__init__.py | 2 +- monai/apps/mmars/__init__.py | 2 +- monai/apps/mmars/config_resolver.py | 48 +++++++++++-------- ...t_scanner.py => test_component_locator.py} | 12 +++-- 5 files changed, 39 insertions(+), 29 deletions(-) rename tests/{test_component_scanner.py => test_component_locator.py} (69%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index fe3be6a7fc..d54e7a2f8f 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -32,10 +32,10 @@ Clara MMARs Model Package ------------- -.. autoclass:: ConfigComponent +.. autoclass:: ComponentLocator :members: -.. autoclass:: ComponentScanner +.. autoclass:: ConfigComponent :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 8348a4b7f7..d0f3fbeb92 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -12,7 +12,7 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import ( MODEL_DESC, - ComponentScanner, + ComponentLocator, ConfigComponent, RemoteMMARKeys, download_mmar, diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index e26f93ddda..78671c361a 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_resolver import ComponentScanner, ConfigComponent +from .config_resolver import ComponentLocator, ConfigComponent from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys from .utils import search_configs_with_deps, update_configs_with_deps diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 059a65c66b..5ca917590c 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,33 +13,37 @@ import sys import warnings from importlib import import_module -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Union from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps -from monai.utils.module import instantiate +from monai.utils import ensure_tuple, instantiate -__all__ = ["ComponentScanner", "ConfigComponent"] +__all__ = ["ComponentLocator", "ConfigComponent"] -class ComponentScanner: +class ComponentLocator: """ - Scan all the available classes and functions in the MONAI package. - Map the all the names and the module names in a table. + Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. + It's used to locate the module path for provided component name. Args: excludes: if any string of the `excludes` exists in the full module name, don't import this module. """ - def __init__(self, excludes: Optional[Sequence[str]] = None): - self.excludes = [] if excludes is None else excludes - self._components_table = self._create_table() + MOD_START = "monai" - def _create_table(self): + def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + self.excludes = [] if excludes is None else ensure_tuple(excludes) + self._components_table = None + + def _find_module_names(self) -> List[str]: + return [m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes)] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]): table: Dict[str, List] = {} # all the MONAI modules are already loaded by `load_submodules` - modnames = [m for m in sys.modules.keys() if m.startswith("monai") and all(s not in m for s in self.excludes)] - for modname in modnames: + for modname in ensure_tuple(modnames): try: # scan all the classes and functions in the module module = import_module(modname) @@ -52,15 +56,19 @@ def _create_table(self): pass return table - def get_component_module_name(self, name): + def get_component_module_name(self, name) -> Union[List[str], str]: """ Get the full module name of the class / function with specified name. - If target component name exists in multiple packages or modules, return all the paths.z + If target component name exists in multiple packages or modules, return a list of full module names. Args: name: name of the expected class or function. """ + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + mods = self._components_table.get(name, None) if isinstance(mods, list) and len(mods) == 1: mods = mods[0] @@ -70,7 +78,7 @@ def get_component_module_name(self, name): class ConfigComponent: """ Utility class to manage every component in the config with a unique `id` name. - When recursively parsing a complicated config dictioanry, every item should be treated as a `ConfigComponent`. + When recursively parsing a complicated config dictionary, every item should be treated as a `ConfigComponent`. For example: - `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` - `{"": "LoadImage", "": {"keys": "image"}}` @@ -93,17 +101,17 @@ class ConfigComponent: for example: `transform`, `transform#5`, `transform#5##keys`, etc. the id can be useful to quickly get the expected item in a complicated and nested config content. config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - scanner: ComponentScanner to help get the `` or `` in the config and build instance. + locator: ComponentLocator to help locate the module path of `` in the config and build instance. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful for config `"collate_fn": "$monai.data.list_data_collate"`. """ - def __init__(self, id: str, config: Any, scanner: ComponentScanner, globals: Optional[Dict] = None) -> None: + def __init__(self, id: str, config: Any, locator: ComponentLocator, globals: Optional[Dict] = None) -> None: self.id = id self.config = config - self.scanner = scanner + self.locator = locator self.globals = globals def get_id(self) -> str: @@ -213,7 +221,7 @@ def _get_path(self, config): name = config.get("", None) if name is None: raise ValueError("must provide `` or `` of target component to build.") - module = self.scanner.get_component_module_name(name) + module = self.locator.get_component_module_name(name) if module is None: raise ValueError(f"can not find component '{name}'.") if isinstance(module, list): diff --git a/tests/test_component_scanner.py b/tests/test_component_locator.py similarity index 69% rename from tests/test_component_scanner.py rename to tests/test_component_locator.py index 16f13f3ce5..adff25457a 100644 --- a/tests/test_component_scanner.py +++ b/tests/test_component_locator.py @@ -12,17 +12,19 @@ import unittest from pydoc import locate -from monai.apps.mmars import ComponentScanner +from monai.apps.mmars import ComponentLocator from monai.utils import optional_import _, has_ignite = optional_import("ignite") -class TestComponentScanner(unittest.TestCase): +class TestComponentLocator(unittest.TestCase): def test_locate(self): - scanner = ComponentScanner(excludes=None if has_ignite else ["monai.handlers"]) - self.assertGreater(len(scanner._components_table), 0) - for _, mods in scanner._components_table.items(): + locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) + # test init mapping table and get the module path of component + self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") + self.assertGreater(len(locator._components_table), 0) + for _, mods in locator._components_table.items(): for i in mods: self.assertGreater(len(mods), 0) # ensure we can locate all the items by `name` From a4ce07b802cb413022491bbd3c0a7b67d72fb3c3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Feb 2022 21:20:55 +0800 Subject: [PATCH 12/25] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/__init__.py | 5 +- monai/apps/mmars/__init__.py | 2 +- monai/apps/mmars/config_resolver.py | 181 ++++++++++++++++------------ monai/apps/mmars/utils.py | 83 +++++++++---- tests/test_config_component.py | 65 +++++----- 5 files changed, 193 insertions(+), 143 deletions(-) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index d0f3fbeb92..13257765fd 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -17,8 +17,9 @@ RemoteMMARKeys, download_mmar, get_model_spec, + is_to_build, load_from_mmar, - search_configs_with_deps, - update_configs_with_deps, + resolve_config_with_deps, + search_config_with_deps, ) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 78671c361a..0896814bbb 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -12,4 +12,4 @@ from .config_resolver import ComponentLocator, ConfigComponent from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import search_configs_with_deps, update_configs_with_deps +from .utils import is_to_build, resolve_config_with_deps, search_config_with_deps diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 5ca917590c..251409227e 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from asyncio import FastChildWatcher import inspect import sys import warnings from importlib import import_module from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps +from monai.apps.mmars.utils import is_to_build, resolve_config_with_deps, search_config_with_deps from monai.utils import ensure_tuple, instantiate __all__ = ["ComponentLocator", "ConfigComponent"] @@ -95,40 +96,78 @@ class ConfigComponent: - "@XXX": use an component as config item, like `"input_data": "@dataset"` uses `dataset` instance as parameter. - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". + The typical usage of the APIs: + - Initialize with config content. + - If no dependencies, `build` the component if having "" or "" keywords and return the instance. + - If having dependencies, get the IDs of its dependent components. + - When all the dependent components are built, update the config content with them, execute expressions in + the config and `build` instance. + Args: - id: id name of current config component, for nested config items, use `#` to join ids. - for list component, use index from `0` as id. - for example: `transform`, `transform#5`, `transform#5##keys`, etc. - the id can be useful to quickly get the expected item in a complicated and nested config content. config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. + no_deps: flag to mark whether the config has dependent components, default to `False`. if `True`, + no need to resolve dependencies before building. + id: ID name of current config component, useful to construct dependent components. + for example, component A may have ID "transforms#A" and component B depends on A + and uses the built instance of A as a dependent arg `"XXX": "@transforms#A"`. + for nested config items, use `#` to join ids, for list component, use index from `0` as id. + for example: `transform`, `transform#5`, `transform#5##keys`, etc. + the ID can be useful to quickly get the expected item in a complicated and nested config content. + ID defaults to `None`, if some component depends on current component, ID must be a `string`. locator: ComponentLocator to help locate the module path of `` in the config and build instance. + if `None`, will create a new ComponentLocator with specified `excludes`. + excludes: if `locator` is None, create a new ComponentLocator with `excludes`. any string of the `excludes` + exists in the full module name, don't import this module. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful for config `"collate_fn": "$monai.data.list_data_collate"`. """ - def __init__(self, id: str, config: Any, locator: ComponentLocator, globals: Optional[Dict] = None) -> None: + def __init__( + self, + config: Any, + no_deps: bool = False, + id: Optional[str] = None, + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict] = None, + ) -> None: + self.config = None + self.resolved_config = None + self.is_resolved = False self.id = id - self.config = config - self.locator = locator + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator self.globals = globals + self.set_config(config=config, no_deps=no_deps) def get_id(self) -> str: """ - Get the id name of current config component. + Get the unique ID of current component, useful to construct dependent components. + For example, component A may have ID "transforms#A" and component B depends on A + and uses the built instance of A as a dependent arg `"XXX": "@transforms#A"`. + ID defaults to `None`, if some component depends on current component, ID must be a string. """ return self.id + def set_config(self, config: Any, no_deps: bool = False): + self.config = config + self.resolved_config = None + self.is_resolved = False + if no_deps: + # if no dependencies, can resolve the config immediately + self.resolve_config(deps=None) + def get_config(self): """ - Get the raw config content of current config component. + Get the init config content of current config component, usually set at the constructor. + It can be useful for lazy instantiation to dynamically update the config content before resolving """ return self.config - def get_dependent_ids(self) -> List[str]: + def get_id_of_deps(self) -> List[str]: """ Recursively search all the content of current config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. @@ -137,9 +176,9 @@ def get_dependent_ids(self) -> List[str]: `["", "", "#dataset", "dataset"]`. """ - return search_configs_with_deps(config=self.config, id=self.id) + return search_config_with_deps(config=self.config, id=self.id) - def get_updated_config(self, deps: dict): + def resolve_config(self, deps: dict): """ If all the dependencies are ready in `deps`, update the config content with them and return new config. It can be used for lazy instantiation, the returned config has no dependencies, can be built immediately. @@ -148,33 +187,45 @@ def get_updated_config(self, deps: dict): deps: all the dependent components with ids. """ - return update_configs_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) - - def _check_dependency(self, config): - """ - Check whether current config still has unresolved dependencies or executable string code. + self.resolved_config = resolve_config_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) + self.is_resolved = True - Args: - config: config content to check. + def get_resolved_config(self): + return self.resolved_config + def _resolve_module_name(self): + config = self.get_resolved_config() + path = config.get("", None) + if path is not None: + if "" in config: + warnings.warn(f"should not set both '' and '', default to use '': {path}.") + return path + + name = config.get("", None) + if name is None: + raise ValueError("must provide `` or `` of target component to build.") + + module = self.locator.get_component_module_name(name) + if module is None: + raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set the full python path in `` directly." + ) + module = module[0] + return f"{module}.{name}" + + def _resolve_args(self): + return self.get_resolved_config().get("", {}) + + def _is_disabled(self): + return self.get_resolved_config().get("", False) + + def build(self, **kwargs) -> object: """ - if isinstance(config, list): - for i in config: - if self._check_dependency(i): - return True - if isinstance(config, dict): - for v in config.values(): - if self._check_dependency(v): - return True - if isinstance(config, str): - if config.startswith("&") or "@" in config: - return True - return False - - def build(self, config: Optional[Dict] = None, **kwargs) -> object: - """ - Build component instance based on the provided dictonary config. - The target component must be a class and a function. + Build component instance based on the resolved config content. + The target component must be a `class` or a `function`. Supported special keys for the config: - '' - class / function name in the modules of packages. - '' - directly specify the path, based on PYTHONPATH, ignore '' if specified. @@ -182,7 +233,6 @@ def build(self, config: Optional[Dict] = None, **kwargs) -> object: - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. Args: - config: dictionary config that defines a component. kwargs: args to override / add the config args when building. Raises: @@ -190,46 +240,17 @@ def build(self, config: Optional[Dict] = None, **kwargs) -> object: ValueError: can not find component class or function. """ - config = self.config if config is None else config - if self._check_dependency(config=config): - warnings.warn("config content still has other dependencies or executable strings to run, skip `build`.") - return config - - if ( - not isinstance(config, dict) - or ("" not in config and "" not in config) - or config.get("") is True - ): - # if marked as `disabled`, skip parsing - return config - - args = config.get("", {}) + if not self.is_resolved: + warnings.warn( + "the config content of current component has not been resolved," + " please try to resolve the dependencies first." + ) + config = self.get_resolved_config() + if not is_to_build(config) or self._is_disabled(): + # if not a class or function, or marked as `disabled`, skip parsing and return `None` + return None + + modname = self._resolve_module_name() + args = self._resolve_args() args.update(kwargs) - path = self._get_path(config) - return instantiate(path, **args) - - def _get_path(self, config): - """ - Get the path of class / function specified in the config content. - - Args: - config: dictionary config that defines a component. - - """ - path = config.get("", None) - if path is None: - name = config.get("", None) - if name is None: - raise ValueError("must provide `` or `` of target component to build.") - module = self.locator.get_component_module_name(name) - if module is None: - raise ValueError(f"can not find component '{name}'.") - if isinstance(module, list): - warnings.warn( - f"there are more than 1 component name `{name}`: {module}, use the first one `{module[0]}." - f" if want to use others, please set the full python path in `` directly." - ) - module = module[0] - path = f"{module}.{name}" - - return path + return instantiate(modname, **args) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index f532ff8cc1..2ec59c902a 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -13,7 +13,23 @@ from typing import Dict, List, Optional, Union -def search_configs_with_deps(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: +def is_to_build(config: Union[Dict, List, str]) -> bool: + """ + Check whether the target component of the config is a `class` or `function` to build + with specified "" or "". + + Args: + config: input config content to check. + + """ + return isinstance(config, dict) and ("" in config or "" in config) + + +def search_config_with_deps( + config: Union[Dict, List, str], + id: Optional[str] = None, + deps: Optional[List[str]] = None, +) -> List[str]: """ Recursively search all the content of input config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. @@ -23,66 +39,81 @@ def search_configs_with_deps(config: Union[Dict, List, str], id: str, deps: Opti Args: config: input config content to search. - id: id name for the input config. - deps: list of the id name of existing dependencies, default to None. + id: ID name for the input config, default to `None`. + deps: list of the ID name of existing dependencies, default to `None`. """ deps_: List[str] = [] if deps is None else deps pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): for i, v in enumerate(config): - sub_id = f"{id}#{i}" - # all the items in the list should be marked as dependent reference - deps_.append(sub_id) - deps_ = search_configs_with_deps(v, sub_id, deps_) + sub_id = f"{id}#{i}" if id is not None else f"{i}" + if is_to_build(v): + # sub-item is component need to build, mark as dependency + deps_.append(sub_id) + deps_ = search_config_with_deps(v, sub_id, deps_) if isinstance(config, dict): for k, v in config.items(): - sub_id = f"{id}#{k}" - # all the items in the dict should be marked as dependent reference - deps_.append(sub_id) - deps_ = search_configs_with_deps(v, sub_id, deps_) + sub_id = f"{id}#{k}" if id is not None else f"{k}" + if is_to_build(v): + # sub-item is component need to build, mark as dependency + deps_.append(sub_id) + deps_ = search_config_with_deps(v, sub_id, deps_) if isinstance(config, str): result = pattern.findall(config) for item in result: if config.startswith("$") or config == item: + # only check when string starts with "$" or the whole content is "@XXX" ref_obj_id = item[1:] if ref_obj_id not in deps_: deps_.append(ref_obj_id) return deps_ -def update_configs_with_deps(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): +def resolve_config_with_deps( + config: Union[Dict, List, str], + deps: Optional[Dict] = None, + id: Optional[str] = None, + globals: Optional[Dict] = None, +): """ - With all the dependencies in `deps`, update the config content with them and return new config. + With all the dependencies in `deps`, resolve the config content with them and return new config. It can be used for lazy instantiation. Args: - config: input config content to update. - deps: all the dependent components with ids. - id: id name for the input config. + config: input config content to resolve. + deps: all the dependent components with ids, default to `None`. + id: id name for the input config, default to `None`. globals: predefined global variables to execute code string with `eval()`. """ + deps_: Dict = {} if deps is None else deps pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): # all the items in the list should be replaced with the reference - return [deps[f"{id}#{i}"] for i in range(len(config))] + ret: List = [] + for i, v in enumerate(config): + sub_id = f"{id}#{i}" if id is not None else f"{i}" + ret.append(deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) + return ret if isinstance(config, dict): # all the items in the dict should be replaced with the reference - return {k: deps[f"{id}#{k}"] for k, _ in config.items()} + ret: Dict = {} + for k, v in config.items(): + sub_id = f"{id}#{k}" if id is not None else f"{k}" + ret[k] = deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) + return ret if isinstance(config, str): result = pattern.findall(config) - config_: str = config + config_: str = config # to avoid mypy CI errors for item in result: ref_obj_id = item[1:] if config_.startswith("$"): - # replace with local code and execute soon - config_ = config_.replace(item, f"deps['{ref_obj_id}']") + # replace with local code and execute later + config_ = config_.replace(item, f"deps_['{ref_obj_id}']") elif config_ == item: - config_ = deps[ref_obj_id] - - if isinstance(config_, str): - if config_.startswith("$"): - config_ = eval(config_[1:], globals, {"deps": deps}) + config_ = deps_[ref_obj_id] + if isinstance(config_, str) and config_.startswith("$"): + config_ = eval(config_[1:], globals, {"deps_": deps_}) return config_ return config diff --git a/tests/test_config_component.py b/tests/test_config_component.py index a3445e78d3..ab5a91dff0 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -17,7 +17,7 @@ from parameterized import parameterized import monai -from monai.apps import ComponentScanner, ConfigComponent +from monai.apps import ComponentLocator, ConfigComponent, is_to_build from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import @@ -30,7 +30,7 @@ # test `` TEST_CASE_3 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] # test unresolved dependency -TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}, dict] +TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}] # test non-monai modules and excludes TEST_CASE_5 = [ {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, @@ -43,42 +43,32 @@ RandTorchVisiond, ] # test dependencies of dict config -TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] +TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["dataset"]] # test dependencies of list config -TEST_CASE_9 = [ - {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, - ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], -] +TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] # test dependencies of execute code TEST_CASE_10 = [ - {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, - ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], + {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["dataset", "trans"] ] - # test dependencies of lambda function TEST_CASE_11 = [ {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, - ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], + ["num_epochs", "init_lr"], ] # test instance with no dependencies -TEST_CASE_12 = [ - "transform#1", - {"": "LoadImaged", "": {"keys": ["image"]}}, - {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, - LoadImaged, -] +TEST_CASE_12 = ["transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {}, LoadImaged] # test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` TEST_CASE_13 = [ "dataloader", {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, - {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, + {"dataset": Dataset(data=[1, 2])}, DataLoader, ] # test dependencies in code execution TEST_CASE_14 = [ "optimizer", {"": "torch.optim.Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, - {"optimizer#": "torch.optim.Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + {"model": torch.nn.PReLU(), "learning_rate": 1e-4}, torch.optim.Adam, ] # test replace dependencies with code execution result @@ -91,33 +81,40 @@ class TestConfigComponent(unittest.TestCase): @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] - + ([TEST_CASE_7] if has_tv else []) + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) ) def test_build(self, test_input, output_type): - scanner = ComponentScanner(excludes=["metrics"]) - configer = ConfigComponent(id="test", config=test_input, scanner=scanner) + locator = ComponentLocator(excludes=["metrics"]) + configer = ConfigComponent(id="test", config=test_input, locator=locator, no_deps=True) ret = configer.build() + if test_input.get("", False): + # test `` works fine + self.assertEqual(ret, None) + return self.assertTrue(isinstance(ret, output_type)) if isinstance(ret, LoadImaged): self.assertEqual(ret.keys[0], "image") - if isinstance(ret, dict): - # test `` works fine - self.assertDictEqual(ret, test_input) + + @parameterized.expand([TEST_CASE_4]) + def test_raise_error(self, test_input): + with self.assertRaises(KeyError): # has unresolved keys + ConfigComponent(id="test", config=test_input, no_deps=True) @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) def test_dependent_ids(self, test_input, ref_ids): - scanner = ComponentScanner() - configer = ConfigComponent(id="test", config=test_input, scanner=scanner) - ret = configer.get_dependent_ids() + configer = ConfigComponent(id="test", config=test_input) # also test default locator + ret = configer.get_id_of_deps() self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) - def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ComponentScanner(excludes=["utils"]) - configer = ConfigComponent(id=id, config=test_input, scanner=scanner, globals={"monai": monai, "torch": torch}) - config = configer.get_updated_config(deps) - ret = configer.build(config) + def test_resolve_dependencies(self, id, test_input, deps, output_type): + configer = ConfigComponent( + id=id, config=test_input, locator=None, excludes=["utils"], globals={"monai": monai, "torch": torch} + ) + configer.resolve_config(deps=deps) + ret = configer.get_resolved_config() + if is_to_build(ret): + ret = configer.build(**{}) # also test kwargs self.assertTrue(isinstance(ret, output_type)) From a158d30439175fcf3585a54da7b7b8c868c881c8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Feb 2022 23:31:48 +0800 Subject: [PATCH 13/25] [DLMED] add more doc-strings Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 88 ++++++++++++++++++++++------- monai/apps/mmars/utils.py | 24 ++++---- tests/test_config_component.py | 17 +++++- 3 files changed, 94 insertions(+), 35 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 251409227e..20b6103d58 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asyncio import FastChildWatcher import inspect import sys import warnings @@ -39,9 +38,22 @@ def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): self._components_table = None def _find_module_names(self) -> List[str]: - return [m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes)] + """ + Find all the modules start with MOD_START and don't contain any of `excludes`. + + """ + return [ + m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) + ] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + """ + Find all the classes and functions in the modules with specified `modnames`. + + Args: + modnames: names of the target modules to find all the classes and functions. - def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]): + """ table: Dict[str, List] = {} # all the MONAI modules are already loaded by `load_submodules` for modname in ensure_tuple(modnames): @@ -103,6 +115,16 @@ class ConfigComponent: - When all the dependent components are built, update the config content with them, execute expressions in the config and `build` instance. + .. code-block:: python + + locator = ComponentLocator(excludes=[""]) + config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + + configer = ConfigComponent(config, id="test_config", locator=locator) + configer.resolve_config(deps={"dataset": Dataset(data=[1, 2])}) + configer.get_resolved_config() + instance = configer.build() + Args: config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. no_deps: flag to mark whether the config has dependent components, default to `False`. if `True`, @@ -141,7 +163,7 @@ def __init__( self.globals = globals self.set_config(config=config, no_deps=no_deps) - def get_id(self) -> str: + def get_id(self) -> Optional[str]: """ Get the unique ID of current component, useful to construct dependent components. For example, component A may have ID "transforms#A" and component B depends on A @@ -152,6 +174,17 @@ def get_id(self) -> str: return self.id def set_config(self, config: Any, no_deps: bool = False): + """ + Set the initial config content at runtime. + If having dependencies, must resolve the config again. + A typical usage is to modify the initial config content at runtime and do lazy instantiation. + + Args: + config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. + no_deps: flag to mark whether the config has dependent components, default to `False`. if `True`, + no need to resolve dependencies before building. + + """ self.config = config self.resolved_config = None self.is_resolved = False @@ -161,39 +194,49 @@ def set_config(self, config: Any, no_deps: bool = False): def get_config(self): """ - Get the init config content of current config component, usually set at the constructor. - It can be useful for lazy instantiation to dynamically update the config content before resolving + Get the initial config content of current config component, usually set at the constructor. + It can be useful for `lazy instantiation` to dynamically update the config content before resolving. """ return self.config def get_id_of_deps(self) -> List[str]: """ - Recursively search all the content of current config compoent to get the ids of dependencies. - It's used to build all the dependencies before build current config component. - For `dict` and `list`, treat every item as a dependency. - For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: - `["", "", "#dataset", "dataset"]`. + Recursively search all the content of current config compoent to get the IDs of dependencies. + It's used to detect and build all the dependencies before building current config component. + For `dict` and `list`, recursively check the sub-items. + For example: `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency IDs: `["dataset"]`. """ return search_config_with_deps(config=self.config, id=self.id) - def resolve_config(self, deps: dict): + def resolve_config(self, deps: Dict): """ - If all the dependencies are ready in `deps`, update the config content with them and return new config. - It can be used for lazy instantiation, the returned config has no dependencies, can be built immediately. + If all the dependencies are ready in `deps`, resolve the config content with them to construct `resolved_config`. Args: - deps: all the dependent components with ids. + deps: all the dependent components with ID as keys. """ self.resolved_config = resolve_config_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) self.is_resolved = True def get_resolved_config(self): + """ + Get the resolved config content, constructed in `resolve_config`. The returned config has no dependencies. + Not all the config content is to build instance, some config just need to resolve the dependencies and use + it in the program, for example: input config `{"intervals": "@epoch / 10"}` and dependencies `{"epoch": 100}`, + the resolved config will be `{"intervals": 10}`. + + """ return self.resolved_config def _resolve_module_name(self): + """ + Utility function used in `build()` to resolve the module name from provided config content + if having `` or ``. + + """ config = self.get_resolved_config() path = config.get("", None) if path is not None: @@ -217,15 +260,24 @@ def _resolve_module_name(self): return f"{module}.{name}" def _resolve_args(self): + """ + Utility function used in `build()` to resolve the arguments of target component to build. + + """ return self.get_resolved_config().get("", {}) def _is_disabled(self): + """ + Utility function used in `build()` to check whether the target component disabled building. + + """ return self.get_resolved_config().get("", False) def build(self, **kwargs) -> object: """ Build component instance based on the resolved config content. - The target component must be a `class` or a `function`. + The target component must be a `class` or a `function`, otherwise, return `None`. + Supported special keys for the config: - '' - class / function name in the modules of packages. - '' - directly specify the path, based on PYTHONPATH, ignore '' if specified. @@ -235,10 +287,6 @@ def build(self, **kwargs) -> object: Args: kwargs: args to override / add the config args when building. - Raises: - ValueError: must provide `` or `` of class / function to build component. - ValueError: can not find component class or function. - """ if not self.is_resolved: warnings.warn( diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 2ec59c902a..9b704f553f 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,16 +26,13 @@ def is_to_build(config: Union[Dict, List, str]) -> bool: def search_config_with_deps( - config: Union[Dict, List, str], - id: Optional[str] = None, - deps: Optional[List[str]] = None, + config: Union[Dict, List, str], id: Optional[str] = None, deps: Optional[List[str]] = None ) -> List[str]: """ Recursively search all the content of input config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. - For `dict` and `list`, treat every item as a dependency. - For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: - `["", "", "#dataset", "dataset"]`. + For `dict` and `list`, recursively check the sub-items. + For example: `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency IDs: `["dataset"]`. Args: config: input config content to search. @@ -78,7 +75,6 @@ def resolve_config_with_deps( ): """ With all the dependencies in `deps`, resolve the config content with them and return new config. - It can be used for lazy instantiation. Args: config: input config content to resolve. @@ -91,18 +87,18 @@ def resolve_config_with_deps( pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): # all the items in the list should be replaced with the reference - ret: List = [] + ret_list: List = [] for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - ret.append(deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) - return ret + ret_list.append(deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) + return ret_list if isinstance(config, dict): # all the items in the dict should be replaced with the reference - ret: Dict = {} + ret_dict: Dict = {} for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - ret[k] = deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) - return ret + ret_dict[k] = deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) + return ret_dict if isinstance(config, str): result = pattern.findall(config) config_: str = config # to avoid mypy CI errors diff --git a/tests/test_config_component.py b/tests/test_config_component.py index ab5a91dff0..6d726f4e60 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -48,7 +48,8 @@ TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] # test dependencies of execute code TEST_CASE_10 = [ - {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["dataset", "trans"] + {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, + ["dataset", "trans"], ] # test dependencies of lambda function TEST_CASE_11 = [ @@ -117,6 +118,20 @@ def test_resolve_dependencies(self, id, test_input, deps, output_type): ret = configer.build(**{}) # also test kwargs self.assertTrue(isinstance(ret, output_type)) + def test_lazy_instantiation(self): + config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + deps = {"dataset": Dataset(data=[1, 2])} + configer = ConfigComponent(config=config, locator=None) + init_config = configer.get_config() + # modify config content at runtime + init_config[""]["batch_size"] = 4 + configer.set_config(config=init_config) + + configer.resolve_config(deps=deps) + ret = configer.build() + self.assertTrue(isinstance(ret, DataLoader)) + self.assertEqual(ret.batch_size, 4) + if __name__ == "__main__": unittest.main() From 596a7216d241209b0f2603ce67bc15706db3dd4e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Feb 2022 23:41:56 +0800 Subject: [PATCH 14/25] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 20b6103d58..c10f6e9609 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -35,7 +35,7 @@ class ComponentLocator: def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): self.excludes = [] if excludes is None else ensure_tuple(excludes) - self._components_table = None + self._components_table: Optional[Dict[str, List]] = None def _find_module_names(self) -> List[str]: """ @@ -69,7 +69,7 @@ def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dic pass return table - def get_component_module_name(self, name) -> Union[List[str], str]: + def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: """ Get the full module name of the class / function with specified name. If target component name exists in multiple packages or modules, return a list of full module names. @@ -82,7 +82,7 @@ def get_component_module_name(self, name) -> Union[List[str], str]: # init component and module mapping table self._components_table = self._find_classes_or_functions(self._find_module_names()) - mods = self._components_table.get(name, None) + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) if isinstance(mods, list) and len(mods) == 1: mods = mods[0] return mods @@ -155,7 +155,7 @@ def __init__( excludes: Optional[Union[Sequence[str], str]] = None, globals: Optional[Dict] = None, ) -> None: - self.config = None + self.config = config self.resolved_config = None self.is_resolved = False self.id = id @@ -210,12 +210,12 @@ def get_id_of_deps(self) -> List[str]: """ return search_config_with_deps(config=self.config, id=self.id) - def resolve_config(self, deps: Dict): + def resolve_config(self, deps: Optional[Dict] = None): """ If all the dependencies are ready in `deps`, resolve the config content with them to construct `resolved_config`. Args: - deps: all the dependent components with ID as keys. + deps: all the dependent components with ID as keys, default to `None`. """ self.resolved_config = resolve_config_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) From bf97a8b86ce83a68de6be553b2d1da2af2c6d2bb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 15 Feb 2022 12:50:16 +0800 Subject: [PATCH 15/25] [DLMED] extract ConfigItem base class Signed-off-by: Nic Ma --- docs/source/apps.rst | 3 + monai/apps/__init__.py | 3 +- monai/apps/mmars/__init__.py | 4 +- monai/apps/mmars/config_resolver.py | 208 ++++++++++-------- monai/apps/mmars/utils.py | 12 +- ...onfig_component.py => test_config_item.py} | 13 +- 6 files changed, 140 insertions(+), 103 deletions(-) rename tests/{test_config_component.py => test_config_item.py} (94%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index d54e7a2f8f..89b22d1d4a 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -38,6 +38,9 @@ Model Package .. autoclass:: ConfigComponent :members: +.. autoclass:: ConfigItem + :members: + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 13257765fd..11434314be 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -14,10 +14,11 @@ MODEL_DESC, ComponentLocator, ConfigComponent, + ConfigItem, RemoteMMARKeys, + able_to_build, download_mmar, get_model_spec, - is_to_build, load_from_mmar, resolve_config_with_deps, search_config_with_deps, diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 0896814bbb..dfcefc3e8e 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_resolver import ComponentLocator, ConfigComponent +from .config_resolver import ComponentLocator, ConfigComponent, ConfigItem from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import is_to_build, resolve_config_with_deps, search_config_with_deps +from .utils import able_to_build, resolve_config_with_deps, search_config_with_deps diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index c10f6e9609..38a7c28ad4 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -15,10 +15,10 @@ from importlib import import_module from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.utils import is_to_build, resolve_config_with_deps, search_config_with_deps +from monai.apps.mmars.utils import able_to_build, resolve_config_with_deps, search_config_with_deps from monai.utils import ensure_tuple, instantiate -__all__ = ["ComponentLocator", "ConfigComponent"] +__all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] class ComponentLocator: @@ -88,134 +88,115 @@ def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: return mods -class ConfigComponent: +class ConfigItem: """ - Utility class to manage every component in the config with a unique `id` name. - When recursively parsing a complicated config dictionary, every item should be treated as a `ConfigComponent`. - For example: - - `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` - - `{"": "LoadImage", "": {"keys": "image"}}` - - `"": "LoadImage"` - - `"keys": "image"` - - It can search the config content and find out all the dependencies, then build the config to instance + Utility class to manage every item of the whole config content. + When recursively parsing a complicated config content, every item (like: dict, list, string, int, float, etc.) + can be treated as an "config item" then construct a `ConfigItem`. + For example, below are 5 config items when recursively parsing: + - a dict: `{"preprocessing": ["@transform1", "@transform2", "$lambda x: x"]}` + - a list: `["@transform1", "@transform2", "$lambda x: x"]` + - a string: `"transform1"` + - a string: `"transform2"` + - a string: `"$lambda x: x"` + + `ConfigItem` can set optional unique ID name, then another config item may depdend on it, for example: + config item with ID="A" is a list `[1, 2, 3]`, another config item can be `"args": {"input_list": "@A"}`. + It can search the config content and find out all the dependencies, and resolve the config content when all the dependencies are resolved. - Here we predefined 4 kinds special marks (`<>`, `#`, `@`, `$`) to parse the config content: - - "": like "" is the name of a target component, to distinguish it with regular key "name" - in the config content. now we have 4 keys: ``, ``, ``, ``. - - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. - - "@XXX": use an component as config item, like `"input_data": "@dataset"` uses `dataset` instance as parameter. + Here we predefined 3 kinds special marks (`#`, `@`, `$`) when parsing the whole config content: + - "XXX#YYY": join nested config IDs, like "transforms#5" is ID name of the 6th transform in a list ID="transforms". + - "@XXX": current config item depends on another config item XXX, like `{"args": {"data": "@dataset"}}` uses + resolved config content of `dataset` as the parameter "data". - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". The typical usage of the APIs: - Initialize with config content. - - If no dependencies, `build` the component if having "" or "" keywords and return the instance. - If having dependencies, get the IDs of its dependent components. - - When all the dependent components are built, update the config content with them, execute expressions in - the config and `build` instance. + - When all the dependent components are resolved, resolve the config content with them, + and execute expressions in the config. .. code-block:: python - locator = ComponentLocator(excludes=[""]) - config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + config = {"lr": "$@epoch / 1000"} - configer = ConfigComponent(config, id="test_config", locator=locator) - configer.resolve_config(deps={"dataset": Dataset(data=[1, 2])}) - configer.get_resolved_config() - instance = configer.build() + configer = ConfigComponent(config, id="test") + dep_ids = configer.get_id_of_deps() + configer.resolve_config(deps={"epoch": 10}) + lr = configer.get_resolved_config() Args: - config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - no_deps: flag to mark whether the config has dependent components, default to `False`. if `True`, - no need to resolve dependencies before building. - id: ID name of current config component, useful to construct dependent components. - for example, component A may have ID "transforms#A" and component B depends on A - and uses the built instance of A as a dependent arg `"XXX": "@transforms#A"`. - for nested config items, use `#` to join ids, for list component, use index from `0` as id. - for example: `transform`, `transform#5`, `transform#5##keys`, etc. - the ID can be useful to quickly get the expected item in a complicated and nested config content. - ID defaults to `None`, if some component depends on current component, ID must be a `string`. - locator: ComponentLocator to help locate the module path of `` in the config and build instance. - if `None`, will create a new ComponentLocator with specified `excludes`. - excludes: if `locator` is None, create a new ComponentLocator with `excludes`. any string of the `excludes` - exists in the full module name, don't import this module. + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + id: ID name of current config item, useful to construct dependent config items. + for example, config item A may have ID "transforms#A" and config item B depends on A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `"collate_fn": "$monai.data.list_data_collate"`. + for config `{"collate_fn": "$monai.data.list_data_collate"}`. """ - def __init__( - self, - config: Any, - no_deps: bool = False, - id: Optional[str] = None, - locator: Optional[ComponentLocator] = None, - excludes: Optional[Union[Sequence[str], str]] = None, - globals: Optional[Dict] = None, - ) -> None: + def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: self.config = config self.resolved_config = None self.is_resolved = False self.id = id - self.locator = ComponentLocator(excludes=excludes) if locator is None else locator self.globals = globals - self.set_config(config=config, no_deps=no_deps) + self.set_config(config=config) def get_id(self) -> Optional[str]: """ - Get the unique ID of current component, useful to construct dependent components. - For example, component A may have ID "transforms#A" and component B depends on A - and uses the built instance of A as a dependent arg `"XXX": "@transforms#A"`. - ID defaults to `None`, if some component depends on current component, ID must be a string. + ID name of current config item, useful to construct dependent config items. + for example, config item A may have ID "transforms#A" and config item B depends on A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. """ return self.id - def set_config(self, config: Any, no_deps: bool = False): + def set_config(self, config: Any): """ - Set the initial config content at runtime. - If having dependencies, must resolve the config again. - A typical usage is to modify the initial config content at runtime and do lazy instantiation. + Set the config content for a config item at runtime. + If having dependencies, need resolve the config later. + A typical usage is to modify the initial config content at runtime and set back. Args: - config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - no_deps: flag to mark whether the config has dependent components, default to `False`. if `True`, - no need to resolve dependencies before building. + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. """ self.config = config self.resolved_config = None self.is_resolved = False - if no_deps: + if not self.get_id_of_deps(): # if no dependencies, can resolve the config immediately self.resolve_config(deps=None) def get_config(self): """ - Get the initial config content of current config component, usually set at the constructor. - It can be useful for `lazy instantiation` to dynamically update the config content before resolving. + Get the initial config content of current config item, usually set at the constructor. + It can be useful to dynamically update the config content before resolving. """ return self.config def get_id_of_deps(self) -> List[str]: """ - Recursively search all the content of current config compoent to get the IDs of dependencies. - It's used to detect and build all the dependencies before building current config component. + Recursively search all the content of current config item to get the IDs of dependencies. + It's used to detect and resolve all the dependencies before resolving current config item. For `dict` and `list`, recursively check the sub-items. - For example: `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency IDs: `["dataset"]`. + For example: `{"args": {"lr": "$@epoch / 1000"}}`, the dependency IDs: `["epoch"]`. """ return search_config_with_deps(config=self.config, id=self.id) def resolve_config(self, deps: Optional[Dict] = None): """ - If all the dependencies are ready in `deps`, resolve the config content with them to construct `resolved_config`. + If all the dependencies are resolved in `deps`, resolve the config content with them to construct `resolved_config`. Args: - deps: all the dependent components with ID as keys, default to `None`. + deps: all the resolved dependent items with ID as keys, default to `None`. """ self.resolved_config = resolve_config_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) @@ -223,18 +204,74 @@ def resolve_config(self, deps: Optional[Dict] = None): def get_resolved_config(self): """ - Get the resolved config content, constructed in `resolve_config`. The returned config has no dependencies. - Not all the config content is to build instance, some config just need to resolve the dependencies and use - it in the program, for example: input config `{"intervals": "@epoch / 10"}` and dependencies `{"epoch": 100}`, - the resolved config will be `{"intervals": 10}`. + Get the resolved config content, constructed in `resolve_config()`. The returned config has no dependencies, + then use it in the program, for example: initial config item `{"intervals": "@epoch / 10"}` and dependencies + `{"epoch": 100}`, the resolved config will be `{"intervals": 10}`. """ return self.resolved_config + +class ConfigComponent(ConfigItem): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of class + or function, and support to build the instance. Example of config item: + `{"": "LoadImage", "": {"keys": "image"}}` + + Here we predefined 4 keys: ``, ``, ``, `` for component config: + - '' - class / function name in the modules of packages. + - '' - directly specify the module path, based on PYTHONPATH, ignore '' if specified. + - '' - arguments to initialize the component instance. + - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. + + The typical usage of the APIs: + - Initialize with config content. + - If no dependencies, `build` the component if having "" or "" keywords and return the instance. + - If having dependencies, get the IDs of its dependent components. + - When all the dependent components are resolved, resolve the config content with them, execute expressions in + the config and `build` instance. + + .. code-block:: python + + locator = ComponentLocator(excludes=[""]) + config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + + configer = ConfigComponent(config, id="test_config", locator=locator) + configer.resolve_config(deps={"dataset": Dataset(data=[1, 2])}) + configer.get_resolved_config() + dataloader: DataLoader = configer.build() + + Args: + config: content of a component config item, should be a dict with `` or `` key. + id: ID name of current config item, useful to construct dependent config items. + for example, config item A may have ID "transforms#A" and config item B depends on A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. + locator: `ComponentLocator` to help locate the module path of `` in the config and build instance. + if `None`, will create a new `ComponentLocator` with specified `excludes`. + excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. any string of the `excludes` + exists in the module name, don't import this module. + globals: to support executable string in the config, sometimes we need to provide the global variables + which are referred in the executable string. for example: `globals={"monai": monai} will be useful + for config `"collate_fn": "$monai.data.list_data_collate"`. + + """ + + def __init__( + self, + config: Any, + id: Optional[str] = None, + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict] = None, + ) -> None: + super().__init__(config=config, id=id, globals=globals) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + def _resolve_module_name(self): """ - Utility function used in `build()` to resolve the module name from provided config content - if having `` or ``. + Utility function used in `build()` to resolve the target module name from provided config content. + The config content must have `` or ``. """ config = self.get_resolved_config() @@ -253,22 +290,22 @@ def _resolve_module_name(self): raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") if isinstance(module, list): warnings.warn( - f"there are more than 1 component name `{name}`: {module}, use the first one `{module[0]}." - f" if want to use others, please set the full python path in `` directly." + f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set its module path in `` directly." ) module = module[0] return f"{module}.{name}" def _resolve_args(self): """ - Utility function used in `build()` to resolve the arguments of target component to build. + Utility function used in `build()` to resolve the arguments from config content of target component to build. """ return self.get_resolved_config().get("", {}) def _is_disabled(self): """ - Utility function used in `build()` to check whether the target component disabled building. + Utility function used in `build()` to check whether the target component is disabled building. """ return self.get_resolved_config().get("", False) @@ -278,12 +315,6 @@ def build(self, **kwargs) -> object: Build component instance based on the resolved config content. The target component must be a `class` or a `function`, otherwise, return `None`. - Supported special keys for the config: - - '' - class / function name in the modules of packages. - - '' - directly specify the path, based on PYTHONPATH, ignore '' if specified. - - '' - arguments to initialize the component instance. - - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. - Args: kwargs: args to override / add the config args when building. @@ -293,8 +324,9 @@ def build(self, **kwargs) -> object: "the config content of current component has not been resolved," " please try to resolve the dependencies first." ) + return None config = self.get_resolved_config() - if not is_to_build(config) or self._is_disabled(): + if not able_to_build(config) or self._is_disabled(): # if not a class or function, or marked as `disabled`, skip parsing and return `None` return None diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 9b704f553f..ead320dacc 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -13,9 +13,9 @@ from typing import Dict, List, Optional, Union -def is_to_build(config: Union[Dict, List, str]) -> bool: +def able_to_build(config: Union[Dict, List, str]) -> bool: """ - Check whether the target component of the config is a `class` or `function` to build + Check whether the content of the config represents a `class` or `function` to build with specified "" or "". Args: @@ -45,14 +45,14 @@ def search_config_with_deps( if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - if is_to_build(v): + if able_to_build(v): # sub-item is component need to build, mark as dependency deps_.append(sub_id) deps_ = search_config_with_deps(v, sub_id, deps_) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - if is_to_build(v): + if able_to_build(v): # sub-item is component need to build, mark as dependency deps_.append(sub_id) deps_ = search_config_with_deps(v, sub_id, deps_) @@ -90,14 +90,14 @@ def resolve_config_with_deps( ret_list: List = [] for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - ret_list.append(deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) + ret_list.append(deps_[sub_id] if able_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) return ret_list if isinstance(config, dict): # all the items in the dict should be replaced with the reference ret_dict: Dict = {} for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - ret_dict[k] = deps_[sub_id] if is_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) + ret_dict[k] = deps_[sub_id] if able_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) return ret_dict if isinstance(config, str): result = pattern.findall(config) diff --git a/tests/test_config_component.py b/tests/test_config_item.py similarity index 94% rename from tests/test_config_component.py rename to tests/test_config_item.py index 6d726f4e60..8fa183aac4 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_item.py @@ -17,7 +17,7 @@ from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, is_to_build +from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, able_to_build from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import @@ -80,13 +80,13 @@ TEST_CASE_17 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] -class TestConfigComponent(unittest.TestCase): +class TestConfigItem(unittest.TestCase): @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) ) def test_build(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) - configer = ConfigComponent(id="test", config=test_input, locator=locator, no_deps=True) + configer = ConfigComponent(id="test", config=test_input, locator=locator) ret = configer.build() if test_input.get("", False): # test `` works fine @@ -99,11 +99,12 @@ def test_build(self, test_input, output_type): @parameterized.expand([TEST_CASE_4]) def test_raise_error(self, test_input): with self.assertRaises(KeyError): # has unresolved keys - ConfigComponent(id="test", config=test_input, no_deps=True) + configer = ConfigItem(id="test", config=test_input) + configer.resolve_config() @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) def test_dependent_ids(self, test_input, ref_ids): - configer = ConfigComponent(id="test", config=test_input) # also test default locator + configer = ConfigItem(id="test", config=test_input) # also test default locator ret = configer.get_id_of_deps() self.assertListEqual(ret, ref_ids) @@ -114,7 +115,7 @@ def test_resolve_dependencies(self, id, test_input, deps, output_type): ) configer.resolve_config(deps=deps) ret = configer.get_resolved_config() - if is_to_build(ret): + if able_to_build(ret): ret = configer.build(**{}) # also test kwargs self.assertTrue(isinstance(ret, output_type)) From bca6851a716cd3ca646b40ad6de4cc2623a32584 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Feb 2022 00:45:00 +0800 Subject: [PATCH 16/25] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/__init__.py | 9 +- monai/apps/mmars/__init__.py | 11 +- .../{config_resolver.py => config_item.py} | 170 +++++++++++------- monai/apps/mmars/utils.py | 164 +++++++++++------ tests/test_config_item.py | 48 ++--- 5 files changed, 257 insertions(+), 145 deletions(-) rename monai/apps/mmars/{config_resolver.py => config_item.py} (69%) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 11434314be..dcd6674d3f 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -16,11 +16,14 @@ ConfigComponent, ConfigItem, RemoteMMARKeys, - able_to_build, download_mmar, + find_refs_in_config, get_model_spec, + instantiable, + is_expression, load_from_mmar, - resolve_config_with_deps, - search_config_with_deps, + match_refs_pattern, + resolve_config_with_refs, + resolve_refs_pattern, ) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index dfcefc3e8e..da012715b0 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,7 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_resolver import ComponentLocator, ConfigComponent, ConfigItem +from .config_item import ComponentLocator, ConfigComponent, ConfigItem from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import able_to_build, resolve_config_with_deps, search_config_with_deps +from .utils import ( + find_refs_in_config, + instantiable, + is_expression, + match_refs_pattern, + resolve_config_with_refs, + resolve_refs_pattern, +) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_item.py similarity index 69% rename from monai/apps/mmars/config_resolver.py rename to monai/apps/mmars/config_item.py index 38a7c28ad4..b30c522406 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_item.py @@ -12,10 +12,11 @@ import inspect import sys import warnings +from abc import ABC, abstractmethod from importlib import import_module from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.utils import able_to_build, resolve_config_with_deps, search_config_with_deps +from monai.apps.mmars.utils import find_refs_in_config, instantiable, resolve_config_with_refs from monai.utils import ensure_tuple, instantiate __all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] @@ -100,21 +101,24 @@ class ConfigItem: - a string: `"transform2"` - a string: `"$lambda x: x"` - `ConfigItem` can set optional unique ID name, then another config item may depdend on it, for example: + `ConfigItem` can set optional unique ID name, then another config item may refer to it. + The references mean the IDs of other config items used as "@XXX" in the current config item, for example: config item with ID="A" is a list `[1, 2, 3]`, another config item can be `"args": {"input_list": "@A"}`. - It can search the config content and find out all the dependencies, and resolve the config content - when all the dependencies are resolved. + If sub-item in the config is instantiable, also treat it as reference because must instantiate it before + resolving current config. + It can search the config content and find out all the references, and resolve the config content + when all the references are resolved. Here we predefined 3 kinds special marks (`#`, `@`, `$`) when parsing the whole config content: - "XXX#YYY": join nested config IDs, like "transforms#5" is ID name of the 6th transform in a list ID="transforms". - - "@XXX": current config item depends on another config item XXX, like `{"args": {"data": "@dataset"}}` uses + - "@XXX": current config item refers to another config item XXX, like `{"args": {"data": "@dataset"}}` uses resolved config content of `dataset` as the parameter "data". - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". The typical usage of the APIs: - Initialize with config content. - - If having dependencies, get the IDs of its dependent components. - - When all the dependent components are resolved, resolve the config content with them, + - If having references, get the IDs of its referring components. + - When all the referring components are resolved, resolve the config content with them, and execute expressions in the config. .. code-block:: python @@ -122,16 +126,16 @@ class ConfigItem: config = {"lr": "$@epoch / 1000"} configer = ConfigComponent(config, id="test") - dep_ids = configer.get_id_of_deps() - configer.resolve_config(deps={"epoch": 10}) + dep_ids = configer.get_id_of_refs() + configer.resolve_config(refs={"epoch": 10}) lr = configer.get_resolved_config() Args: config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: ID name of current config item, useful to construct dependent config items. - for example, config item A may have ID "transforms#A" and config item B depends on A + id: ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful for config `{"collate_fn": "$monai.data.list_data_collate"}`. @@ -144,22 +148,22 @@ def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict self.is_resolved = False self.id = id self.globals = globals - self.set_config(config=config) + self.update_config(config=config) def get_id(self) -> Optional[str]: """ - ID name of current config item, useful to construct dependent config items. - for example, config item A may have ID "transforms#A" and config item B depends on A + ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. """ return self.id - def set_config(self, config: Any): + def update_config(self, config: Any): """ - Set the config content for a config item at runtime. - If having dependencies, need resolve the config later. + Update the config content for a config item at runtime. + If having references, need resolve the config later. A typical usage is to modify the initial config content at runtime and set back. Args: @@ -169,9 +173,9 @@ def set_config(self, config: Any): self.config = config self.resolved_config = None self.is_resolved = False - if not self.get_id_of_deps(): - # if no dependencies, can resolve the config immediately - self.resolve_config(deps=None) + if not self.get_id_of_refs(): + # if no references, can resolve the config immediately + self.resolve(refs=None) def get_config(self): """ @@ -181,41 +185,90 @@ def get_config(self): """ return self.config - def get_id_of_deps(self) -> List[str]: + def get_id_of_refs(self) -> List[str]: """ - Recursively search all the content of current config item to get the IDs of dependencies. - It's used to detect and resolve all the dependencies before resolving current config item. + Recursively search all the content of current config item to get the IDs of references. + It's used to detect and resolve all the references before resolving current config item. For `dict` and `list`, recursively check the sub-items. - For example: `{"args": {"lr": "$@epoch / 1000"}}`, the dependency IDs: `["epoch"]`. + For example: `{"args": {"lr": "$@epoch / 1000"}}`, the reference IDs: `["epoch"]`. """ - return search_config_with_deps(config=self.config, id=self.id) + return find_refs_in_config(self.config, id=self.id) - def resolve_config(self, deps: Optional[Dict] = None): + def resolve(self, refs: Optional[Dict] = None): """ - If all the dependencies are resolved in `deps`, resolve the config content with them to construct `resolved_config`. + If all the references are resolved in `refs`, resolve the config content with them to construct `resolved_config`. Args: - deps: all the resolved dependent items with ID as keys, default to `None`. + refs: all the resolved referring items with ID as keys, default to `None`. """ - self.resolved_config = resolve_config_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) + self.resolved_config = resolve_config_with_refs(self.config, id=self.id, refs=refs, globals=self.globals) self.is_resolved = True def get_resolved_config(self): """ - Get the resolved config content, constructed in `resolve_config()`. The returned config has no dependencies, - then use it in the program, for example: initial config item `{"intervals": "@epoch / 10"}` and dependencies + Get the resolved config content, constructed in `resolve_config()`. The returned config has no references, + then use it in the program, for example: initial config item `{"intervals": "@epoch / 10"}` and references `{"epoch": 100}`, the resolved config will be `{"intervals": 10}`. """ return self.resolved_config -class ConfigComponent(ConfigItem): +class Instantiable(ABC): + """ + Base class for instantiable object and provide the `instantiate` API. + + """ + + @abstractmethod + def _resolve_module_name(self): + """ + Utility function used in `instantiate()` to resolve the target module name. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def _resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def _is_disabled(self): + """ + Utility function used in `instantiate()` to check whether the target component is disabled. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + def instantiate(self, **kwargs) -> object: + """ + Instantiate the target component. + + Args: + kwargs: args to override / add the kwargs when instantiation. + + """ + + if self._is_disabled(): + # if marked as `disabled`, skip parsing and return `None` + return None + + modname = self._resolve_module_name() + args = self._resolve_args() + args.update(kwargs) + return instantiate(modname, **args) + + +class ConfigComponent(ConfigItem, Instantiable): """ Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of class - or function, and support to build the instance. Example of config item: + or function, and support to instantiate the component. Example of config item: `{"": "LoadImage", "": {"keys": "image"}}` Here we predefined 4 keys: ``, ``, ``, `` for component config: @@ -226,10 +279,10 @@ class ConfigComponent(ConfigItem): The typical usage of the APIs: - Initialize with config content. - - If no dependencies, `build` the component if having "" or "" keywords and return the instance. - - If having dependencies, get the IDs of its dependent components. - - When all the dependent components are resolved, resolve the config content with them, execute expressions in - the config and `build` instance. + - If no references, `instantiate` the component if having "" or "" keywords and return the instance. + - If having references, get the IDs of its referring components. + - When all the referring components are resolved, resolve the config content with them, execute expressions in + the config and `instantiate`. .. code-block:: python @@ -237,17 +290,17 @@ class ConfigComponent(ConfigItem): config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} configer = ConfigComponent(config, id="test_config", locator=locator) - configer.resolve_config(deps={"dataset": Dataset(data=[1, 2])}) + configer.resolve_config(refs={"dataset": Dataset(data=[1, 2])}) configer.get_resolved_config() - dataloader: DataLoader = configer.build() + dataloader: DataLoader = configer.instantiate() Args: config: content of a component config item, should be a dict with `` or `` key. - id: ID name of current config item, useful to construct dependent config items. - for example, config item A may have ID "transforms#A" and config item B depends on A + id: ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component depends on current component, `id` must be a `string`. - locator: `ComponentLocator` to help locate the module path of `` in the config and build instance. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. + locator: `ComponentLocator` to help locate the module path of `` in the config and instantiate. if `None`, will create a new `ComponentLocator` with specified `excludes`. excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. any string of the `excludes` exists in the module name, don't import this module. @@ -270,7 +323,7 @@ def __init__( def _resolve_module_name(self): """ - Utility function used in `build()` to resolve the target module name from provided config content. + Utility function used in `instantiate()` to resolve the target module name from provided config content. The config content must have `` or ``. """ @@ -283,7 +336,7 @@ def _resolve_module_name(self): name = config.get("", None) if name is None: - raise ValueError("must provide `` or `` of target component to build.") + raise ValueError("must provide `` or `` of target component to instantiate.") module = self.locator.get_component_module_name(name) if module is None: @@ -298,39 +351,34 @@ def _resolve_module_name(self): def _resolve_args(self): """ - Utility function used in `build()` to resolve the arguments from config content of target component to build. + Utility function used in `instantiate()` to resolve the arguments from config content of target component. """ return self.get_resolved_config().get("", {}) def _is_disabled(self): """ - Utility function used in `build()` to check whether the target component is disabled building. + Utility function used in `instantiate()` to check whether the target component is disabled. """ return self.get_resolved_config().get("", False) - def build(self, **kwargs) -> object: + def instantiate(self, **kwargs) -> object: """ - Build component instance based on the resolved config content. + Instantiate component based on the resolved config content. The target component must be a `class` or a `function`, otherwise, return `None`. Args: - kwargs: args to override / add the config args when building. + kwargs: args to override / add the config args when instantiation. """ if not self.is_resolved: warnings.warn( "the config content of current component has not been resolved," - " please try to resolve the dependencies first." + " please try to resolve the references first." ) return None - config = self.get_resolved_config() - if not able_to_build(config) or self._is_disabled(): - # if not a class or function, or marked as `disabled`, skip parsing and return `None` + if not instantiable(self.get_resolved_config()): + # if not a class or function, skip parsing and return `None` return None - - modname = self._resolve_module_name() - args = self._resolve_args() - args.update(kwargs) - return instantiate(modname, **args) + return super().instantiate(**kwargs) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index ead320dacc..9503178e09 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -10,106 +10,160 @@ # limitations under the License. import re -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union -def able_to_build(config: Union[Dict, List, str]) -> bool: +def match_refs_pattern(value: str, pattern: str = r"@\w*[\#\w]*") -> List[str]: """ - Check whether the content of the config represents a `class` or `function` to build - with specified "" or "". + Match regular expression for the input string with specified `pattern` to find the references. + Default to find the ID of referring item starting with "@", like: "@XXX#YYY#ZZZ". Args: - config: input config content to check. + value: input value to match regular expression. + pattern: regular expression pattern, default to match "@XXX" or "@XXX#YYY". """ - return isinstance(config, dict) and ("" in config or "" in config) + refs: List[str] = [] + result = re.compile(pattern).findall(value) + for item in result: + if value.startswith("$") or value == item: + # only check when string starts with "$" or the whole content is "@XXX" + ref_obj_id = item[1:] + if ref_obj_id not in refs: + refs.append(ref_obj_id) + return refs + + +def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None, pattern: str = r"@\w*[\#\w]*") -> str: + """ + Match regular expression for the input string with specified `pattern` to update content with + the references. Default to find the ID of referring item starting with "@", like: "@XXX#YYY#ZZZ". + References dictionary must contain the referring IDs as keys. + + Args: + value: input value to match regular expression. + refs: all the referring components with ids as keys, default to `None`. + globals: predefined global variables to execute code string with `eval()`. + pattern: regular expression pattern, default to match "@XXX" or "@XXX#YYY". + + """ + result = re.compile(pattern).findall(value) + for item in result: + ref_id = item[1:] + if is_expression(value): + # replace with local code and execute later + value = value.replace(item, f"refs['{ref_id}']") + elif value == item: + if ref_id not in refs: + raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + value = refs[ref_id] + if is_expression(value): + # execute the code string with python `eval()` + value = eval(value[1:], globals, {"refs": refs}) + return value -def search_config_with_deps( - config: Union[Dict, List, str], id: Optional[str] = None, deps: Optional[List[str]] = None +def find_refs_in_config( + config: Union[Dict, List, str], + id: Optional[str] = None, + refs: Optional[List[str]] = None, + match_fn: Callable = match_refs_pattern, ) -> List[str]: """ - Recursively search all the content of input config compoent to get the ids of dependencies. - It's used to build all the dependencies before build current config component. + Recursively search all the content of input config item to get the ids of references. + References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: + `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config + is instantiable, treat it as reference because must instantiate it before resolving current config. For `dict` and `list`, recursively check the sub-items. - For example: `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency IDs: `["dataset"]`. Args: config: input config content to search. id: ID name for the input config, default to `None`. - deps: list of the ID name of existing dependencies, default to `None`. + refs: list of the ID name of existing references, default to `None`. + match_fn: callable function to match config item for references, take `config` as parameter. """ - deps_: List[str] = [] if deps is None else deps - pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + refs_: List[str] = [] if refs is None else refs + if isinstance(config, str): + refs_ += match_fn(value=config) + if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - if able_to_build(v): - # sub-item is component need to build, mark as dependency - deps_.append(sub_id) - deps_ = search_config_with_deps(v, sub_id, deps_) + if instantiable(v): + refs_.append(sub_id) + refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - if able_to_build(v): - # sub-item is component need to build, mark as dependency - deps_.append(sub_id) - deps_ = search_config_with_deps(v, sub_id, deps_) - if isinstance(config, str): - result = pattern.findall(config) - for item in result: - if config.startswith("$") or config == item: - # only check when string starts with "$" or the whole content is "@XXX" - ref_obj_id = item[1:] - if ref_obj_id not in deps_: - deps_.append(ref_obj_id) - return deps_ + if instantiable(v): + refs_.append(sub_id) + refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) + return refs_ -def resolve_config_with_deps( +def resolve_config_with_refs( config: Union[Dict, List, str], - deps: Optional[Dict] = None, id: Optional[str] = None, + refs: Optional[Dict] = None, globals: Optional[Dict] = None, + match_fn: Callable = resolve_refs_pattern, ): """ - With all the dependencies in `deps`, resolve the config content with them and return new config. + With all the references in `refs`, resolve the config content with them and return new config. Args: config: input config content to resolve. - deps: all the dependent components with ids, default to `None`. - id: id name for the input config, default to `None`. + id: ID name for the input config, default to `None`. + refs: all the referring components with ids, default to `None`. globals: predefined global variables to execute code string with `eval()`. + match_fn: callable function to match config item for references, take `config`, + `refs` and `globals` as parameter. """ - deps_: Dict = {} if deps is None else deps - pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + refs_: Dict = {} if refs is None else refs + if isinstance(config, str): + config = match_fn(config, refs, globals) if isinstance(config, list): - # all the items in the list should be replaced with the reference + # all the items in the list should be replaced with the references ret_list: List = [] for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - ret_list.append(deps_[sub_id] if able_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals)) + ret_list.append( + refs_[sub_id] if instantiable(v) else resolve_config_with_refs(v, sub_id, refs_, globals, match_fn) + ) return ret_list if isinstance(config, dict): - # all the items in the dict should be replaced with the reference + # all the items in the dict should be replaced with the references ret_dict: Dict = {} for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - ret_dict[k] = deps_[sub_id] if able_to_build(v) else resolve_config_with_deps(v, deps_, sub_id, globals) + ret_dict.update( + {k: refs_[sub_id] if instantiable(v) else resolve_config_with_refs(v, sub_id, refs_, globals, match_fn)} + ) return ret_dict - if isinstance(config, str): - result = pattern.findall(config) - config_: str = config # to avoid mypy CI errors - for item in result: - ref_obj_id = item[1:] - if config_.startswith("$"): - # replace with local code and execute later - config_ = config_.replace(item, f"deps_['{ref_obj_id}']") - elif config_ == item: - config_ = deps_[ref_obj_id] - if isinstance(config_, str) and config_.startswith("$"): - config_ = eval(config_[1:], globals, {"deps_": deps_}) - return config_ return config + + +def instantiable(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config represents a `class` or `function` to instantiate + with specified "" or "". + + Args: + config: input config content to check. + + """ + return isinstance(config, dict) and ("" in config or "" in config) + + +def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config is executable expression string. + If True, the string should start with "$" mark. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$") diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 8fa183aac4..7ec03c3d22 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -17,7 +17,7 @@ from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, able_to_build +from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, instantiable from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import @@ -29,7 +29,7 @@ TEST_CASE_2 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test `` TEST_CASE_3 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test unresolved dependency +# test unresolved reference TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}] # test non-monai modules and excludes TEST_CASE_5 = [ @@ -42,41 +42,41 @@ {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] -# test dependencies of dict config +# test references of dict config TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["dataset"]] -# test dependencies of list config +# test references of list config TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] -# test dependencies of execute code +# test references of execute code TEST_CASE_10 = [ {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["dataset", "trans"], ] -# test dependencies of lambda function +# test references of lambda function TEST_CASE_11 = [ {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, ["num_epochs", "init_lr"], ] -# test instance with no dependencies +# test instance with no references TEST_CASE_12 = ["transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {}, LoadImaged] -# test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` +# test dataloader refers to `@dataset`, here we don't test recursive references, test that in `ConfigResolver` TEST_CASE_13 = [ "dataloader", {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, {"dataset": Dataset(data=[1, 2])}, DataLoader, ] -# test dependencies in code execution +# test references in code execution TEST_CASE_14 = [ "optimizer", {"": "torch.optim.Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, {"model": torch.nn.PReLU(), "learning_rate": 1e-4}, torch.optim.Adam, ] -# test replace dependencies with code execution result +# test replace references with code execution result TEST_CASE_15 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] # test execute some function in args, test pre-imported global packages `monai` TEST_CASE_16 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] -# test lambda function, should not execute the lambda function, just change the string with dependent objects +# test lambda function, should not execute the lambda function, just change the string with referring objects TEST_CASE_17 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] @@ -84,10 +84,10 @@ class TestConfigItem(unittest.TestCase): @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) ) - def test_build(self, test_input, output_type): + def test_instantiate(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, locator=locator) - ret = configer.build() + ret = configer.instantiate() if test_input.get("", False): # test `` works fine self.assertEqual(ret, None) @@ -100,36 +100,36 @@ def test_build(self, test_input, output_type): def test_raise_error(self, test_input): with self.assertRaises(KeyError): # has unresolved keys configer = ConfigItem(id="test", config=test_input) - configer.resolve_config() + configer.resolve() @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) - def test_dependent_ids(self, test_input, ref_ids): + def test_referring_ids(self, test_input, ref_ids): configer = ConfigItem(id="test", config=test_input) # also test default locator - ret = configer.get_id_of_deps() + ret = configer.get_id_of_refs() self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) - def test_resolve_dependencies(self, id, test_input, deps, output_type): + def test_resolve_references(self, id, test_input, refs, output_type): configer = ConfigComponent( id=id, config=test_input, locator=None, excludes=["utils"], globals={"monai": monai, "torch": torch} ) - configer.resolve_config(deps=deps) + configer.resolve(refs=refs) ret = configer.get_resolved_config() - if able_to_build(ret): - ret = configer.build(**{}) # also test kwargs + if instantiable(ret): + ret = configer.instantiate(**{}) # also test kwargs self.assertTrue(isinstance(ret, output_type)) def test_lazy_instantiation(self): config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} - deps = {"dataset": Dataset(data=[1, 2])} + refs = {"dataset": Dataset(data=[1, 2])} configer = ConfigComponent(config=config, locator=None) init_config = configer.get_config() # modify config content at runtime init_config[""]["batch_size"] = 4 - configer.set_config(config=init_config) + configer.update_config(config=init_config) - configer.resolve_config(deps=deps) - ret = configer.build() + configer.resolve(refs=refs) + ret = configer.instantiate() self.assertTrue(isinstance(ret, DataLoader)) self.assertEqual(ret.batch_size, 4) From 9f5a52305757a3a64ea245f92f20e55e31f85209 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Feb 2022 16:59:16 +0800 Subject: [PATCH 17/25] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/apps/mmars/utils.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 9503178e09..c1b5665efe 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -13,20 +13,20 @@ from typing import Callable, Dict, List, Optional, Union -def match_refs_pattern(value: str, pattern: str = r"@\w*[\#\w]*") -> List[str]: +def match_refs_pattern(value: str) -> List[str]: """ - Match regular expression for the input string with specified `pattern` to find the references. - Default to find the ID of referring item starting with "@", like: "@XXX#YYY#ZZZ". + Match regular expression for the input string to find the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". Args: value: input value to match regular expression. - pattern: regular expression pattern, default to match "@XXX" or "@XXX#YYY". """ refs: List[str] = [] - result = re.compile(pattern).findall(value) + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) for item in result: - if value.startswith("$") or value == item: + if is_expression(value) or value == item: # only check when string starts with "$" or the whole content is "@XXX" ref_obj_id = item[1:] if ref_obj_id not in refs: @@ -34,20 +34,20 @@ def match_refs_pattern(value: str, pattern: str = r"@\w*[\#\w]*") -> List[str]: return refs -def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None, pattern: str = r"@\w*[\#\w]*") -> str: +def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None) -> str: """ - Match regular expression for the input string with specified `pattern` to update content with - the references. Default to find the ID of referring item starting with "@", like: "@XXX#YYY#ZZZ". + Match regular expression for the input string to update content with the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". References dictionary must contain the referring IDs as keys. Args: value: input value to match regular expression. refs: all the referring components with ids as keys, default to `None`. globals: predefined global variables to execute code string with `eval()`. - pattern: regular expression pattern, default to match "@XXX" or "@XXX#YYY". """ - result = re.compile(pattern).findall(value) + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) for item in result: ref_id = item[1:] if is_expression(value): @@ -90,13 +90,13 @@ def find_refs_in_config( if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" - if instantiable(v): + if is_instantiable(v): refs_.append(sub_id) refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" - if instantiable(v): + if is_instantiable(v): refs_.append(sub_id) refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) return refs_ @@ -128,24 +128,24 @@ def resolve_config_with_refs( # all the items in the list should be replaced with the references ret_list: List = [] for i, v in enumerate(config): - sub_id = f"{id}#{i}" if id is not None else f"{i}" + sub = f"{id}#{i}" if id is not None else f"{i}" ret_list.append( - refs_[sub_id] if instantiable(v) else resolve_config_with_refs(v, sub_id, refs_, globals, match_fn) + refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn) ) return ret_list if isinstance(config, dict): # all the items in the dict should be replaced with the references ret_dict: Dict = {} for k, v in config.items(): - sub_id = f"{id}#{k}" if id is not None else f"{k}" + sub = f"{id}#{k}" if id is not None else f"{k}" ret_dict.update( - {k: refs_[sub_id] if instantiable(v) else resolve_config_with_refs(v, sub_id, refs_, globals, match_fn)} + {k: refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn)} ) return ret_dict return config -def instantiable(config: Union[Dict, List, str]) -> bool: +def is_instantiable(config: Union[Dict, List, str]) -> bool: """ Check whether the content of the config represents a `class` or `function` to instantiate with specified "" or "". From b8fe7f3ee002e89ea10bbadc53d094f58e84fb17 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Feb 2022 17:26:01 +0800 Subject: [PATCH 18/25] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/__init__.py | 2 +- monai/apps/mmars/__init__.py | 2 +- monai/apps/mmars/config_item.py | 50 +++++++++++++++------------------ monai/apps/mmars/utils.py | 2 +- tests/test_config_item.py | 4 +-- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index dcd6674d3f..3e6f7440a1 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -19,8 +19,8 @@ download_mmar, find_refs_in_config, get_model_spec, - instantiable, is_expression, + is_instantiable, load_from_mmar, match_refs_pattern, resolve_config_with_refs, diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index da012715b0..e54cb7ac9d 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -14,8 +14,8 @@ from .model_desc import MODEL_DESC, RemoteMMARKeys from .utils import ( find_refs_in_config, - instantiable, is_expression, + is_instantiable, match_refs_pattern, resolve_config_with_refs, resolve_refs_pattern, diff --git a/monai/apps/mmars/config_item.py b/monai/apps/mmars/config_item.py index b30c522406..536743c5e9 100644 --- a/monai/apps/mmars/config_item.py +++ b/monai/apps/mmars/config_item.py @@ -16,7 +16,7 @@ from importlib import import_module from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.utils import find_refs_in_config, instantiable, resolve_config_with_refs +from monai.apps.mmars.utils import find_refs_in_config, is_instantiable, resolve_config_with_refs from monai.utils import ensure_tuple, instantiate __all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] @@ -218,51 +218,41 @@ def get_resolved_config(self): class Instantiable(ABC): """ - Base class for instantiable object and provide the `instantiate` API. + Base class for instantiable object with module name and arguments. """ @abstractmethod - def _resolve_module_name(self): + def resolve_module_name(self, *args: Any, **kwargs: Any): """ - Utility function used in `instantiate()` to resolve the target module name. + Utility function to resolve the target module name. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def _resolve_args(self): + def resolve_args(self, *args: Any, **kwargs: Any): """ - Utility function used in `instantiate()` to resolve the arguments. + Utility function to resolve the arguments. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def _is_disabled(self): + def is_disabled(self, *args: Any, **kwargs: Any): """ - Utility function used in `instantiate()` to check whether the target component is disabled. + Utility function to check whether the target component is disabled. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - def instantiate(self, **kwargs) -> object: + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any): """ Instantiate the target component. - Args: - kwargs: args to override / add the kwargs when instantiation. - """ - - if self._is_disabled(): - # if marked as `disabled`, skip parsing and return `None` - return None - - modname = self._resolve_module_name() - args = self._resolve_args() - args.update(kwargs) - return instantiate(modname, **args) + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") class ConfigComponent(ConfigItem, Instantiable): @@ -321,7 +311,7 @@ def __init__( super().__init__(config=config, id=id, globals=globals) self.locator = ComponentLocator(excludes=excludes) if locator is None else locator - def _resolve_module_name(self): + def resolve_module_name(self): """ Utility function used in `instantiate()` to resolve the target module name from provided config content. The config content must have `` or ``. @@ -349,21 +339,21 @@ def _resolve_module_name(self): module = module[0] return f"{module}.{name}" - def _resolve_args(self): + def resolve_args(self): """ Utility function used in `instantiate()` to resolve the arguments from config content of target component. """ return self.get_resolved_config().get("", {}) - def _is_disabled(self): + def is_disabled(self): """ Utility function used in `instantiate()` to check whether the target component is disabled. """ return self.get_resolved_config().get("", False) - def instantiate(self, **kwargs) -> object: + def instantiate(self, **kwargs) -> object: # type: ignore """ Instantiate component based on the resolved config content. The target component must be a `class` or a `function`, otherwise, return `None`. @@ -378,7 +368,11 @@ def instantiate(self, **kwargs) -> object: " please try to resolve the references first." ) return None - if not instantiable(self.get_resolved_config()): - # if not a class or function, skip parsing and return `None` + if not is_instantiable(self.get_resolved_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` return None - return super().instantiate(**kwargs) + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index c1b5665efe..bc7b516886 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -118,7 +118,7 @@ def resolve_config_with_refs( refs: all the referring components with ids, default to `None`. globals: predefined global variables to execute code string with `eval()`. match_fn: callable function to match config item for references, take `config`, - `refs` and `globals` as parameter. + `refs` and `globals` as parameters. """ refs_: Dict = {} if refs is None else refs diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 7ec03c3d22..af6a02802b 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -17,7 +17,7 @@ from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, instantiable +from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, is_instantiable from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import @@ -115,7 +115,7 @@ def test_resolve_references(self, id, test_input, refs, output_type): ) configer.resolve(refs=refs) ret = configer.get_resolved_config() - if instantiable(ret): + if is_instantiable(ret): ret = configer.instantiate(**{}) # also test kwargs self.assertTrue(isinstance(ret, output_type)) From d3646062163a61230526da97f0f6e2e3cbb7118a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 16 Feb 2022 10:57:57 +0000 Subject: [PATCH 19/25] update instantiate util Signed-off-by: Wenqi Li --- monai/utils/module.py | 8 +++++--- tests/test_config_item.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 1dcbe6849f..8b7745c3ee 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -198,8 +198,8 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ def instantiate(path: str, **kwargs): """ - Method for creating an instance for the specified class / function path. - kwargs will be class args or default args for `partial` function. + Create an object instance or partial function from a class or function represented by string. + `kwargs` will be part of the input arguments to the class constructor or function. The target component must be a class or a function, if not, return the component directly. Args: @@ -210,12 +210,14 @@ def instantiate(path: str, **kwargs): """ component = locate(path) + if component is None: + raise ModuleNotFoundError(f"Cannot locate '{path}'.") if inspect.isclass(component): return component(**kwargs) if inspect.isfunction(component): return partial(component, **kwargs) - warnings.warn(f"target component must be a valid class or function, but got {path}.") + warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") return component diff --git a/tests/test_config_item.py b/tests/test_config_item.py index af6a02802b..32c121dc9b 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From aa5f95d181bc9abfa1b5fa0c8b15dee0fa6425d2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 01:02:40 +0800 Subject: [PATCH 20/25] [DLMED] optimize design Signed-off-by: Nic Ma --- docs/source/apps.rst | 3 + monai/apps/__init__.py | 18 +- monai/apps/manifest/__init__.py | 13 + monai/apps/{mmars => manifest}/config_item.py | 288 ++++++++---------- monai/apps/manifest/utils.py | 36 +++ monai/apps/mmars/__init__.py | 9 - monai/apps/mmars/utils.py | 169 ---------- tests/test_component_locator.py | 2 +- tests/test_config_item.py | 96 ++---- 9 files changed, 209 insertions(+), 425 deletions(-) create mode 100644 monai/apps/manifest/__init__.py rename monai/apps/{mmars => manifest}/config_item.py (55%) create mode 100644 monai/apps/manifest/utils.py delete mode 100644 monai/apps/mmars/utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 89b22d1d4a..e02246af56 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -38,6 +38,9 @@ Model Package .. autoclass:: ConfigComponent :members: +.. autoclass:: ConfigExpression + :members: + .. autoclass:: ConfigItem :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 3e6f7440a1..9ed6a90609 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,20 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import ( - MODEL_DESC, - ComponentLocator, - ConfigComponent, - ConfigItem, - RemoteMMARKeys, - download_mmar, - find_refs_in_config, - get_model_spec, - is_expression, - is_instantiable, - load_from_mmar, - match_refs_pattern, - resolve_config_with_refs, - resolve_refs_pattern, -) +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, is_expression, is_instantiable +from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py new file mode 100644 index 0000000000..6f68ed1226 --- /dev/null +++ b/monai/apps/manifest/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .utils import is_expression, is_instantiable diff --git a/monai/apps/mmars/config_item.py b/monai/apps/manifest/config_item.py similarity index 55% rename from monai/apps/mmars/config_item.py rename to monai/apps/manifest/config_item.py index 536743c5e9..121c28335e 100644 --- a/monai/apps/mmars/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -16,12 +16,51 @@ from importlib import import_module from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.utils import find_refs_in_config, is_instantiable, resolve_config_with_refs +from monai.apps.manifest.utils import is_expression, is_instantiable from monai.utils import ensure_tuple, instantiate __all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] +class Instantiable(ABC): + """ + Base class for instantiable object with module name and arguments. + + """ + + @abstractmethod + def resolve_module_name(self, *args: Any, **kwargs: Any): + """ + Utility function to resolve the target module name. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def resolve_args(self, *args: Any, **kwargs: Any): + """ + Utility function to resolve the arguments. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def is_disabled(self, *args: Any, **kwargs: Any): + """ + Utility function to check whether the target component is disabled. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any): + """ + Instantiate the target component. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + class ComponentLocator: """ Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. @@ -91,71 +130,44 @@ def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: class ConfigItem: """ - Utility class to manage every item of the whole config content. + Base class of every item of the whole config content. When recursively parsing a complicated config content, every item (like: dict, list, string, int, float, etc.) can be treated as an "config item" then construct a `ConfigItem`. - For example, below are 5 config items when recursively parsing: - - a dict: `{"preprocessing": ["@transform1", "@transform2", "$lambda x: x"]}` - - a list: `["@transform1", "@transform2", "$lambda x: x"]` - - a string: `"transform1"` - - a string: `"transform2"` - - a string: `"$lambda x: x"` - - `ConfigItem` can set optional unique ID name, then another config item may refer to it. - The references mean the IDs of other config items used as "@XXX" in the current config item, for example: - config item with ID="A" is a list `[1, 2, 3]`, another config item can be `"args": {"input_list": "@A"}`. - If sub-item in the config is instantiable, also treat it as reference because must instantiate it before - resolving current config. - It can search the config content and find out all the references, and resolve the config content - when all the references are resolved. - - Here we predefined 3 kinds special marks (`#`, `@`, `$`) when parsing the whole config content: - - "XXX#YYY": join nested config IDs, like "transforms#5" is ID name of the 6th transform in a list ID="transforms". - - "@XXX": current config item refers to another config item XXX, like `{"args": {"data": "@dataset"}}` uses - resolved config content of `dataset` as the parameter "data". - - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". + For example, below are 4 config items when recursively parsing: + - a dict: `{"transform_keys": ["image", "label"]}` + - a list: `["image", "label"]` + - a string: `"image"` + - a string: `"label"` - The typical usage of the APIs: - - Initialize with config content. - - If having references, get the IDs of its referring components. - - When all the referring components are resolved, resolve the config content with them, - and execute expressions in the config. + `ConfigItem` can set optional unique ID name to identify itself. + + A typical usage of the APIs: + - Initialize / update with config content. + - Get the config content at runtime and modify something in it. + - Update the config content with new one. .. code-block:: python - config = {"lr": "$@epoch / 1000"} + config = {"lr": 0.001} - configer = ConfigComponent(config, id="test") - dep_ids = configer.get_id_of_refs() - configer.resolve_config(refs={"epoch": 10}) - lr = configer.get_resolved_config() + item = ConfigItem(config, id="test") + conf = item.get_config() + conf["lr"] = 0.0001 + item.update_config(conf) Args: config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. - globals: to support executable string in the config, sometimes we need to provide the global variables - which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `{"collate_fn": "$monai.data.list_data_collate"}`. + id: optional ID name of current config item, defaults to `None`. """ - def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + def __init__(self, config: Any, id: Optional[str] = None) -> None: self.config = config - self.resolved_config = None - self.is_resolved = False self.id = id - self.globals = globals - self.update_config(config=config) def get_id(self) -> Optional[str]: """ - ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. + Get the ID name of current config item, useful to identify config items during parsing. """ return self.id @@ -163,7 +175,6 @@ def get_id(self) -> Optional[str]: def update_config(self, config: Any): """ Update the config content for a config item at runtime. - If having references, need resolve the config later. A typical usage is to modify the initial config content at runtime and set back. Args: @@ -171,94 +182,20 @@ def update_config(self, config: Any): """ self.config = config - self.resolved_config = None - self.is_resolved = False - if not self.get_id_of_refs(): - # if no references, can resolve the config immediately - self.resolve(refs=None) def get_config(self): """ - Get the initial config content of current config item, usually set at the constructor. - It can be useful to dynamically update the config content before resolving. + Get the config content of current config item. + It can be useful to update the config content at runtime. """ return self.config - def get_id_of_refs(self) -> List[str]: - """ - Recursively search all the content of current config item to get the IDs of references. - It's used to detect and resolve all the references before resolving current config item. - For `dict` and `list`, recursively check the sub-items. - For example: `{"args": {"lr": "$@epoch / 1000"}}`, the reference IDs: `["epoch"]`. - - """ - return find_refs_in_config(self.config, id=self.id) - - def resolve(self, refs: Optional[Dict] = None): - """ - If all the references are resolved in `refs`, resolve the config content with them to construct `resolved_config`. - - Args: - refs: all the resolved referring items with ID as keys, default to `None`. - - """ - self.resolved_config = resolve_config_with_refs(self.config, id=self.id, refs=refs, globals=self.globals) - self.is_resolved = True - - def get_resolved_config(self): - """ - Get the resolved config content, constructed in `resolve_config()`. The returned config has no references, - then use it in the program, for example: initial config item `{"intervals": "@epoch / 10"}` and references - `{"epoch": 100}`, the resolved config will be `{"intervals": 10}`. - - """ - return self.resolved_config - - -class Instantiable(ABC): - """ - Base class for instantiable object with module name and arguments. - - """ - - @abstractmethod - def resolve_module_name(self, *args: Any, **kwargs: Any): - """ - Utility function to resolve the target module name. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def resolve_args(self, *args: Any, **kwargs: Any): - """ - Utility function to resolve the arguments. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def is_disabled(self, *args: Any, **kwargs: Any): - """ - Utility function to check whether the target component is disabled. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def instantiate(self, *args: Any, **kwargs: Any): - """ - Instantiate the target component. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - class ConfigComponent(ConfigItem, Instantiable): """ - Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of class - or function, and support to instantiate the component. Example of config item: + Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of `class` + or `function`, and support to instantiate the component. Example of config item: `{"": "LoadImage", "": {"keys": "image"}}` Here we predefined 4 keys: ``, ``, ``, `` for component config: @@ -268,35 +205,24 @@ class ConfigComponent(ConfigItem, Instantiable): - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. The typical usage of the APIs: - - Initialize with config content. - - If no references, `instantiate` the component if having "" or "" keywords and return the instance. - - If having references, get the IDs of its referring components. - - When all the referring components are resolved, resolve the config content with them, execute expressions in - the config and `instantiate`. + - Initialize / update config content. + - `instantiate` the component if having "" or "" keywords and return the instance. .. code-block:: python locator = ComponentLocator(excludes=[""]) - config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + config = {"": "LoadImaged", "": {"keys": ["image", "label"]}} - configer = ConfigComponent(config, id="test_config", locator=locator) - configer.resolve_config(refs={"dataset": Dataset(data=[1, 2])}) - configer.get_resolved_config() + configer = ConfigComponent(config, id="test", locator=locator) dataloader: DataLoader = configer.instantiate() Args: - config: content of a component config item, should be a dict with `` or `` key. - id: ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. - locator: `ComponentLocator` to help locate the module path of `` in the config and instantiate. + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + id: optional ID name of current config item, defaults to `None`. + locator: `ComponentLocator` to help locate the module path of `` in the config and instantiate it. if `None`, will create a new `ComponentLocator` with specified `excludes`. excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. any string of the `excludes` exists in the module name, don't import this module. - globals: to support executable string in the config, sometimes we need to provide the global variables - which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `"collate_fn": "$monai.data.list_data_collate"`. """ @@ -306,18 +232,17 @@ def __init__( id: Optional[str] = None, locator: Optional[ComponentLocator] = None, excludes: Optional[Union[Sequence[str], str]] = None, - globals: Optional[Dict] = None, ) -> None: - super().__init__(config=config, id=id, globals=globals) + super().__init__(config=config, id=id) self.locator = ComponentLocator(excludes=excludes) if locator is None else locator def resolve_module_name(self): """ - Utility function used in `instantiate()` to resolve the target module name from provided config content. + Utility function used in `instantiate()` to resolve the target module name from current config content. The config content must have `` or ``. """ - config = self.get_resolved_config() + config = self.get_config() path = config.get("", None) if path is not None: if "" in config: @@ -341,34 +266,28 @@ def resolve_module_name(self): def resolve_args(self): """ - Utility function used in `instantiate()` to resolve the arguments from config content of target component. + Utility function used in `instantiate()` to resolve the arguments from current config content. """ - return self.get_resolved_config().get("", {}) + return self.get_config().get("", {}) def is_disabled(self): """ - Utility function used in `instantiate()` to check whether the target component is disabled. + Utility function used in `instantiate()` to check whether the current component is `disabled`. """ - return self.get_resolved_config().get("", False) + return self.get_config().get("", False) def instantiate(self, **kwargs) -> object: # type: ignore """ - Instantiate component based on the resolved config content. + Instantiate component based on current config content. The target component must be a `class` or a `function`, otherwise, return `None`. Args: kwargs: args to override / add the config args when instantiation. """ - if not self.is_resolved: - warnings.warn( - "the config content of current component has not been resolved," - " please try to resolve the references first." - ) - return None - if not is_instantiable(self.get_resolved_config()) or self.is_disabled(): + if not is_instantiable(self.get_config()) or self.is_disabled(): # if not a class or function or marked as `disabled`, skip parsing and return `None` return None @@ -376,3 +295,50 @@ def instantiate(self, **kwargs) -> object: # type: ignore args = self.resolve_args() args.update(kwargs) return instantiate(modname, **args) + + +class ConfigExpression(ConfigItem): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents an executable expression + string started with "$" mark, and support to execute based on python `eval()`, more details: + https://docs.python.org/3/library/functions.html#eval. + + An example of config item: `{"test_fn": "$lambda x: x + 100"}}` + + The typical usage of the APIs: + - Initialize / update config content. + - `execute` the config content if it is expression. + + .. code-block:: python + + config = "$monai.data.list_data_collate" + + expression = ConfigExpression(config, id="test", globals={"monai": monai}) + dataloader = DataLoader(..., collate_fn=expression.execute()) + + Args: + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + id: optional ID name of current config item, defaults to `None`. + globals: to execute expression string, sometimes we need to provide the global variables which are + referred in the expression string. for example: `globals={"monai": monai}` will be useful for + config `{"collate_fn": "$monai.data.list_data_collate"}`. + + """ + + def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + super().__init__(config=config, id=id) + self.globals = globals + + def execute(self, locals: Optional[Dict] = None): + """ + Excute current config content and return the result if it is expression, based on python `eval()`. + For more details: https://docs.python.org/3/library/functions.html#eval. + + Args: + locals: besides `globals`, may also have some local variables used in the expression at runtime. + + """ + value = self.get_config() + if not is_expression(value): + return None + return eval(value[1:], self.globals, locals) diff --git a/monai/apps/manifest/utils.py b/monai/apps/manifest/utils.py new file mode 100644 index 0000000000..19d26e2d2d --- /dev/null +++ b/monai/apps/manifest/utils.py @@ -0,0 +1,36 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Union + + +def is_instantiable(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config represents a `class` or `function` to instantiate + with specified "" or "". + + Args: + config: input config content to check. + + """ + return isinstance(config, dict) and ("" in config or "" in config) + + +def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config is executable expression string. + If True, the string should start with "$" mark. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$") diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index e54cb7ac9d..8f1448bb06 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,14 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_item import ComponentLocator, ConfigComponent, ConfigItem from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import ( - find_refs_in_config, - is_expression, - is_instantiable, - match_refs_pattern, - resolve_config_with_refs, - resolve_refs_pattern, -) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py deleted file mode 100644 index bc7b516886..0000000000 --- a/monai/apps/mmars/utils.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Callable, Dict, List, Optional, Union - - -def match_refs_pattern(value: str) -> List[str]: - """ - Match regular expression for the input string to find the references. - The reference part starts with "@", like: "@XXX#YYY#ZZZ". - - Args: - value: input value to match regular expression. - - """ - refs: List[str] = [] - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - if is_expression(value) or value == item: - # only check when string starts with "$" or the whole content is "@XXX" - ref_obj_id = item[1:] - if ref_obj_id not in refs: - refs.append(ref_obj_id) - return refs - - -def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None) -> str: - """ - Match regular expression for the input string to update content with the references. - The reference part starts with "@", like: "@XXX#YYY#ZZZ". - References dictionary must contain the referring IDs as keys. - - Args: - value: input value to match regular expression. - refs: all the referring components with ids as keys, default to `None`. - globals: predefined global variables to execute code string with `eval()`. - - """ - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - ref_id = item[1:] - if is_expression(value): - # replace with local code and execute later - value = value.replace(item, f"refs['{ref_id}']") - elif value == item: - if ref_id not in refs: - raise KeyError(f"can not find expected ID '{ref_id}' in the references.") - value = refs[ref_id] - if is_expression(value): - # execute the code string with python `eval()` - value = eval(value[1:], globals, {"refs": refs}) - return value - - -def find_refs_in_config( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[List[str]] = None, - match_fn: Callable = match_refs_pattern, -) -> List[str]: - """ - Recursively search all the content of input config item to get the ids of references. - References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: - `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config - is instantiable, treat it as reference because must instantiate it before resolving current config. - For `dict` and `list`, recursively check the sub-items. - - Args: - config: input config content to search. - id: ID name for the input config, default to `None`. - refs: list of the ID name of existing references, default to `None`. - match_fn: callable function to match config item for references, take `config` as parameter. - - """ - refs_: List[str] = [] if refs is None else refs - if isinstance(config, str): - refs_ += match_fn(value=config) - - if isinstance(config, list): - for i, v in enumerate(config): - sub_id = f"{id}#{i}" if id is not None else f"{i}" - if is_instantiable(v): - refs_.append(sub_id) - refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) - if isinstance(config, dict): - for k, v in config.items(): - sub_id = f"{id}#{k}" if id is not None else f"{k}" - if is_instantiable(v): - refs_.append(sub_id) - refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) - return refs_ - - -def resolve_config_with_refs( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[Dict] = None, - globals: Optional[Dict] = None, - match_fn: Callable = resolve_refs_pattern, -): - """ - With all the references in `refs`, resolve the config content with them and return new config. - - Args: - config: input config content to resolve. - id: ID name for the input config, default to `None`. - refs: all the referring components with ids, default to `None`. - globals: predefined global variables to execute code string with `eval()`. - match_fn: callable function to match config item for references, take `config`, - `refs` and `globals` as parameters. - - """ - refs_: Dict = {} if refs is None else refs - if isinstance(config, str): - config = match_fn(config, refs, globals) - if isinstance(config, list): - # all the items in the list should be replaced with the references - ret_list: List = [] - for i, v in enumerate(config): - sub = f"{id}#{i}" if id is not None else f"{i}" - ret_list.append( - refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn) - ) - return ret_list - if isinstance(config, dict): - # all the items in the dict should be replaced with the references - ret_dict: Dict = {} - for k, v in config.items(): - sub = f"{id}#{k}" if id is not None else f"{k}" - ret_dict.update( - {k: refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn)} - ) - return ret_dict - return config - - -def is_instantiable(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config represents a `class` or `function` to instantiate - with specified "" or "". - - Args: - config: input config content to check. - - """ - return isinstance(config, dict) and ("" in config or "" in config) - - -def is_expression(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config is executable expression string. - If True, the string should start with "$" mark. - - Args: - config: input config content to check. - - """ - return isinstance(config, str) and config.startswith("$") diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py index adff25457a..eafb2152d1 100644 --- a/tests/test_component_locator.py +++ b/tests/test_component_locator.py @@ -12,7 +12,7 @@ import unittest from pydoc import locate -from monai.apps.mmars import ComponentLocator +from monai.apps.manifest import ComponentLocator from monai.utils import optional_import _, has_ignite = optional_import("ignite") diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 32c121dc9b..2650cc930d 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -11,26 +11,26 @@ import unittest from functools import partial -from typing import Callable, Iterator +from typing import Callable import torch from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, is_instantiable +from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import _, has_tv = optional_import("torchvision") -TEST_CASE_1 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +TEST_CASE_1 = [{"lr": 0.001}, 0.0001] + +TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test python `` -TEST_CASE_2 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test `` -TEST_CASE_3 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test unresolved reference -TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}] +TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] # test non-monai modules and excludes TEST_CASE_5 = [ {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, @@ -42,49 +42,25 @@ {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] -# test references of dict config -TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["dataset"]] -# test references of list config -TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] -# test references of execute code -TEST_CASE_10 = [ - {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, - ["dataset", "trans"], -] -# test references of lambda function -TEST_CASE_11 = [ - {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, - ["num_epochs", "init_lr"], -] -# test instance with no references -TEST_CASE_12 = ["transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {}, LoadImaged] -# test dataloader refers to `@dataset`, here we don't test recursive references, test that in `ConfigResolver` -TEST_CASE_13 = [ - "dataloader", - {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, - {"dataset": Dataset(data=[1, 2])}, - DataLoader, -] -# test references in code execution -TEST_CASE_14 = [ - "optimizer", - {"": "torch.optim.Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, - {"model": torch.nn.PReLU(), "learning_rate": 1e-4}, - torch.optim.Adam, -] -# test replace references with code execution result -TEST_CASE_15 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_16 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] -# test lambda function, should not execute the lambda function, just change the string with referring objects -TEST_CASE_17 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] +TEST_CASE_8 = ["collate_fn", "$monai.data.list_data_collate"] +# test lambda function, should not execute the lambda function, just change the string +TEST_CASE_9 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] class TestConfigItem(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_item(self, test_input, expected): + item = ConfigItem(config=test_input) + conf = item.get_config() + conf["lr"] = 0.0001 + item.update_config(config=conf) + self.assertEqual(item.get_config()["lr"], expected) + @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) + [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) ) - def test_instantiate(self, test_input, output_type): + def test_component(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, locator=locator) ret = configer.instantiate() @@ -96,39 +72,21 @@ def test_instantiate(self, test_input, output_type): if isinstance(ret, LoadImaged): self.assertEqual(ret.keys[0], "image") - @parameterized.expand([TEST_CASE_4]) - def test_raise_error(self, test_input): - with self.assertRaises(KeyError): # has unresolved keys - configer = ConfigItem(id="test", config=test_input) - configer.resolve() - - @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) - def test_referring_ids(self, test_input, ref_ids): - configer = ConfigItem(id="test", config=test_input) # also test default locator - ret = configer.get_id_of_refs() - self.assertListEqual(ret, ref_ids) - - @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) - def test_resolve_references(self, id, test_input, refs, output_type): - configer = ConfigComponent( - id=id, config=test_input, locator=None, excludes=["utils"], globals={"monai": monai, "torch": torch} - ) - configer.resolve(refs=refs) - ret = configer.get_resolved_config() - if is_instantiable(ret): - ret = configer.instantiate(**{}) # also test kwargs - self.assertTrue(isinstance(ret, output_type)) + @parameterized.expand([TEST_CASE_8, TEST_CASE_9]) + def test_expression(self, id, test_input): + configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) + var = 100 + ret = configer.execute(locals={"var": var}) + self.assertTrue(isinstance(ret, Callable)) def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} - refs = {"dataset": Dataset(data=[1, 2])} + config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} configer = ConfigComponent(config=config, locator=None) init_config = configer.get_config() # modify config content at runtime init_config[""]["batch_size"] = 4 configer.update_config(config=init_config) - configer.resolve(refs=refs) ret = configer.instantiate() self.assertTrue(isinstance(ret, DataLoader)) self.assertEqual(ret.batch_size, 4) From a20beb86910d5b0cba9fbbcecbebd3ae5342b15b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 16 Feb 2022 19:10:15 +0000 Subject: [PATCH 21/25] update docstring Signed-off-by: Wenqi Li --- monai/apps/manifest/config_item.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index 121c28335e..6e88a64fc6 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -26,29 +26,31 @@ class Instantiable(ABC): """ Base class for instantiable object with module name and arguments. + .. code-block:: python + + if not is_disabled(): + instantiate(module_name=resolve_module_name(), args=resolve_args()) + """ @abstractmethod def resolve_module_name(self, *args: Any, **kwargs: Any): """ - Utility function to resolve the target module name. - + Resolve the target module name, it should return an object class (or function) to be instantiated. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @abstractmethod def resolve_args(self, *args: Any, **kwargs: Any): """ - Utility function to resolve the arguments. - + Resolve the arguments, it should return arguments to be passed to the object when instantiating. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def is_disabled(self, *args: Any, **kwargs: Any): + def is_disabled(self, *args: Any, **kwargs: Any) -> bool: """ - Utility function to check whether the target component is disabled. - + Return a boolean flag to indicate whether the object should be instantiated. """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @@ -56,7 +58,6 @@ def is_disabled(self, *args: Any, **kwargs: Any): def instantiate(self, *args: Any, **kwargs: Any): """ Instantiate the target component. - """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") @@ -111,7 +112,7 @@ def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dic def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: """ - Get the full module name of the class / function with specified name. + Get the full module name of the class or function with specified ``name``. If target component name exists in multiple packages or modules, return a list of full module names. Args: From 9d1ca96e6c2e192c9c220c15e4fb714336d85a72 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 16 Feb 2022 23:11:42 +0000 Subject: [PATCH 22/25] updating ConfigComponent Signed-off-by: Wenqi Li --- monai/apps/manifest/config_item.py | 82 +++++++++++------------------- 1 file changed, 30 insertions(+), 52 deletions(-) diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index 6e88a64fc6..37a369a3e9 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -131,34 +131,15 @@ def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: class ConfigItem: """ - Base class of every item of the whole config content. - When recursively parsing a complicated config content, every item (like: dict, list, string, int, float, etc.) - can be treated as an "config item" then construct a `ConfigItem`. - For example, below are 4 config items when recursively parsing: - - a dict: `{"transform_keys": ["image", "label"]}` - - a list: `["image", "label"]` - - a string: `"image"` - - a string: `"label"` - - `ConfigItem` can set optional unique ID name to identify itself. - - A typical usage of the APIs: - - Initialize / update with config content. - - Get the config content at runtime and modify something in it. - - Update the config content with new one. + Basic data structure to represent a configuration item. - .. code-block:: python - - config = {"lr": 0.001} - - item = ConfigItem(config, id="test") - conf = item.get_config() - conf["lr"] = 0.0001 - item.update_config(conf) + A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. + It has a build-in `config` property to store the configuration object. Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: optional ID name of current config item, defaults to `None`. + config: content of a config item, can be objects of any types, + a configuration resolver may interpret the content to generate a configuration object. + id: optional ID name of the current config item, defaults to `None`. """ @@ -175,11 +156,11 @@ def get_id(self) -> Optional[str]: def update_config(self, config: Any): """ - Update the config content for a config item at runtime. - A typical usage is to modify the initial config content at runtime and set back. + Replace the content of `self.config` with new `config`. + A typical usage is to modify the initial config content at runtime. Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + config: content of a `ConfigItem`. """ self.config = config @@ -187,7 +168,6 @@ def update_config(self, config: Any): def get_config(self): """ Get the config content of current config item. - It can be useful to update the config content at runtime. """ return self.config @@ -195,35 +175,33 @@ def get_config(self): class ConfigComponent(ConfigItem, Instantiable): """ - Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of `class` - or `function`, and support to instantiate the component. Example of config item: - `{"": "LoadImage", "": {"keys": "image"}}` + Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to + represent a component of `class` or `function` and supports instantiation. - Here we predefined 4 keys: ``, ``, ``, `` for component config: - - '' - class / function name in the modules of packages. - - '' - directly specify the module path, based on PYTHONPATH, ignore '' if specified. - - '' - arguments to initialize the component instance. - - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. + Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: - The typical usage of the APIs: - - Initialize / update config content. - - `instantiate` the component if having "" or "" keywords and return the instance. + - class or function identifier of the python module, specified by one of the two keys. + - ``""``: indicates build-in python classes or functions such as "LoadImageDict". + - ``""``: full module name, such as "monai.transforms.LoadImageDict". + - ``""``: input arguments to the python module. + - ``""``: a boolean flag to indicate whether to skip the instantiation. .. code-block:: python - locator = ComponentLocator(excludes=[""]) + locator = ComponentLocator(excludes=["modules_to_exclude"]) config = {"": "LoadImaged", "": {"keys": ["image", "label"]}} configer = ConfigComponent(config, id="test", locator=locator) - dataloader: DataLoader = configer.instantiate() + image_loader = configer.instantiate() + print(image_loader) # Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: optional ID name of current config item, defaults to `None`. - locator: `ComponentLocator` to help locate the module path of `` in the config and instantiate it. - if `None`, will create a new `ComponentLocator` with specified `excludes`. - excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. any string of the `excludes` - exists in the module name, don't import this module. + config: content of a config item. + id: optional name of the current config item, defaults to `None`. + locator: a `ComponentLocator` to convert a module name string into the actual python module. + if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. + excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. + See also: :py:class:`monai.apps.manifest.ComponentLocator`. """ @@ -243,14 +221,14 @@ def resolve_module_name(self): The config content must have `` or ``. """ - config = self.get_config() - path = config.get("", None) + config = dict(self.get_config()) + path = config.get("") if path is not None: if "" in config: - warnings.warn(f"should not set both '' and '', default to use '': {path}.") + warnings.warn(f"both '' and '', default to use '': {path}.") return path - name = config.get("", None) + name = config.get("") if name is None: raise ValueError("must provide `` or `` of target component to instantiate.") From 3ac88274899c7ebac3b04550655dd4760e44fb3a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 17 Feb 2022 00:49:04 +0000 Subject: [PATCH 23/25] revise confi* Signed-off-by: Wenqi Li --- docs/source/apps.rst | 4 +- monai/apps/__init__.py | 2 +- monai/apps/manifest/__init__.py | 1 - monai/apps/manifest/config_item.py | 90 ++++++++++++++++++++---------- monai/apps/manifest/utils.py | 36 ------------ 5 files changed, 62 insertions(+), 71 deletions(-) delete mode 100644 monai/apps/manifest/utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index e02246af56..6535ad82b7 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,8 +29,8 @@ Clara MMARs :annotation: -Model Package -------------- +Model Manifest +-------------- .. autoclass:: ComponentLocator :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 9ed6a90609..df085bddea 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, is_expression, is_instantiable +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index 6f68ed1226..d8919a5249 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -10,4 +10,3 @@ # limitations under the License. from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem -from .utils import is_expression, is_instantiable diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index 37a369a3e9..a8e8ff134c 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -14,9 +14,8 @@ import warnings from abc import ABC, abstractmethod from importlib import import_module -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union -from monai.apps.manifest.utils import is_expression, is_instantiable from monai.utils import ensure_tuple, instantiate __all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] @@ -139,7 +138,7 @@ class ConfigItem: Args: config: content of a config item, can be objects of any types, a configuration resolver may interpret the content to generate a configuration object. - id: optional ID name of the current config item, defaults to `None`. + id: optional name of the current config item, defaults to `None`. """ @@ -184,12 +183,17 @@ class ConfigComponent(ConfigItem, Instantiable): - ``""``: indicates build-in python classes or functions such as "LoadImageDict". - ``""``: full module name, such as "monai.transforms.LoadImageDict". - ``""``: input arguments to the python module. - - ``""``: a boolean flag to indicate whether to skip the instantiation. + - ``""``: a flag to indicate whether to skip the instantiation. .. code-block:: python locator = ComponentLocator(excludes=["modules_to_exclude"]) - config = {"": "LoadImaged", "": {"keys": ["image", "label"]}} + config = { + "": "LoadImaged", + "": { + "keys": ["image", "label"] + } + } configer = ConfigComponent(config, id="test", locator=locator) image_loader = configer.instantiate() @@ -198,7 +202,7 @@ class ConfigComponent(ConfigItem, Instantiable): Args: config: content of a config item. id: optional name of the current config item, defaults to `None`. - locator: a `ComponentLocator` to convert a module name string into the actual python module. + locator: a ``ComponentLocator`` to convert a module name string into the actual python module. if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. See also: :py:class:`monai.apps.manifest.ComponentLocator`. @@ -215,10 +219,22 @@ def __init__( super().__init__(config=config, id=id) self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + @staticmethod + def is_instantiable(config: Any) -> bool: + """ + Check whether this config represents a `class` or `function` that is to be instantiated. + + Args: + config: input config content to check. + + """ + return isinstance(config, Mapping) and ("" in config or "" in config) + def resolve_module_name(self): """ - Utility function used in `instantiate()` to resolve the target module name from current config content. - The config content must have `` or ``. + Resolve the target module name from current config content. + The config content must have ``""`` or ``""``. + When both are specified, ``""`` will be used. """ config = dict(self.get_config()) @@ -250,23 +266,24 @@ def resolve_args(self): """ return self.get_config().get("", {}) - def is_disabled(self): + def is_disabled(self, *_args, **_kwargs) -> bool: """ - Utility function used in `instantiate()` to check whether the current component is `disabled`. + Utility function used in `instantiate()` to check whether to skip the instantiation. """ - return self.get_config().get("", False) + _is_disabled = self.get_config().get("", False) + return _is_disabled.lower().strip() == "false" if isinstance(_is_disabled, str) else bool(_is_disabled) def instantiate(self, **kwargs) -> object: # type: ignore """ - Instantiate component based on current config content. + Instantiate component based on ``self.config`` content. The target component must be a `class` or a `function`, otherwise, return `None`. Args: kwargs: args to override / add the config args when instantiation. """ - if not is_instantiable(self.get_config()) or self.is_disabled(): + if not ConfigComponent.is_instantiable(self.get_config()) or self.is_disabled(): # if not a class or function or marked as `disabled`, skip parsing and return `None` return None @@ -278,29 +295,28 @@ def instantiate(self, **kwargs) -> object: # type: ignore class ConfigExpression(ConfigItem): """ - Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents an executable expression - string started with "$" mark, and support to execute based on python `eval()`, more details: - https://docs.python.org/3/library/functions.html#eval. + Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression + (execute based on ``eval()``). + + See also: - An example of config item: `{"test_fn": "$lambda x: x + 100"}}` + - https://docs.python.org/3/library/functions.html#eval. - The typical usage of the APIs: - - Initialize / update config content. - - `execute` the config content if it is expression. + For example: .. code-block:: python - config = "$monai.data.list_data_collate" + import monai + from monai.apps.manifest import ConfigExpression + config = "$monai.__version__" expression = ConfigExpression(config, id="test", globals={"monai": monai}) - dataloader = DataLoader(..., collate_fn=expression.execute()) + print(expression.execute()) Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: optional ID name of current config item, defaults to `None`. - globals: to execute expression string, sometimes we need to provide the global variables which are - referred in the expression string. for example: `globals={"monai": monai}` will be useful for - config `{"collate_fn": "$monai.data.list_data_collate"}`. + config: content of a config item. + id: optional name of current config item, defaults to `None`. + globals: additional global context to evaluate the string. """ @@ -308,16 +324,28 @@ def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict super().__init__(config=config, id=id) self.globals = globals - def execute(self, locals: Optional[Dict] = None): + def execute(self, local_vars: Optional[Dict] = None): """ Excute current config content and return the result if it is expression, based on python `eval()`. For more details: https://docs.python.org/3/library/functions.html#eval. Args: - locals: besides `globals`, may also have some local variables used in the expression at runtime. + local_vars: besides ``globals``, may also have some local variables used in the expression at runtime. """ value = self.get_config() - if not is_expression(value): + if not ConfigExpression.is_expression(value): return None - return eval(value[1:], self.globals, locals) + return eval(value[1:], self.globals, local_vars) + + @staticmethod + def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an executable expression string. + Currently A string starts with ``"$"`` character is interpreted as an expression. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$") diff --git a/monai/apps/manifest/utils.py b/monai/apps/manifest/utils.py deleted file mode 100644 index 19d26e2d2d..0000000000 --- a/monai/apps/manifest/utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Union - - -def is_instantiable(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config represents a `class` or `function` to instantiate - with specified "" or "". - - Args: - config: input config content to check. - - """ - return isinstance(config, dict) and ("" in config or "" in config) - - -def is_expression(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config is executable expression string. - If True, the string should start with "$" mark. - - Args: - config: input config content to check. - - """ - return isinstance(config, str) and config.startswith("$") From 6d0aa5d2393aae7fa8c654acfba13ca10824c4fb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 11:54:38 +0800 Subject: [PATCH 24/25] [DLMED] fix unit tests Signed-off-by: Nic Ma --- monai/apps/manifest/config_item.py | 24 ++++++++++++++---------- tests/test_config_item.py | 17 ++++++++++------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index a8e8ff134c..b47e231b0e 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -18,7 +18,7 @@ from monai.utils import ensure_tuple, instantiate -__all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] +__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] class Instantiable(ABC): @@ -109,7 +109,7 @@ def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dic pass return table - def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: + def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: """ Get the full module name of the class or function with specified ``name``. If target component name exists in multiple packages or modules, return a list of full module names. @@ -118,6 +118,8 @@ def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: name: name of the expected class or function. """ + if not isinstance(name, str): + raise ValueError(f"`name` must be a valid string, but got: {name}.") if self._components_table is None: # init component and module mapping table self._components_table = self._find_classes_or_functions(self._find_module_names()) @@ -240,13 +242,15 @@ def resolve_module_name(self): config = dict(self.get_config()) path = config.get("") if path is not None: + if not isinstance(path, str): + raise ValueError(f"'' must be a string, but got: {path}.") if "" in config: warnings.warn(f"both '' and '', default to use '': {path}.") return path name = config.get("") - if name is None: - raise ValueError("must provide `` or `` of target component to instantiate.") + if not isinstance(name, str): + raise ValueError("must provide a string for `` or `` of target component to instantiate.") module = self.locator.get_component_module_name(name) if module is None: @@ -266,13 +270,13 @@ def resolve_args(self): """ return self.get_config().get("", {}) - def is_disabled(self, *_args, **_kwargs) -> bool: + def is_disabled(self) -> bool: # type: ignore """ Utility function used in `instantiate()` to check whether to skip the instantiation. """ _is_disabled = self.get_config().get("", False) - return _is_disabled.lower().strip() == "false" if isinstance(_is_disabled, str) else bool(_is_disabled) + return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) def instantiate(self, **kwargs) -> object: # type: ignore """ @@ -283,7 +287,7 @@ def instantiate(self, **kwargs) -> object: # type: ignore kwargs: args to override / add the config args when instantiation. """ - if not ConfigComponent.is_instantiable(self.get_config()) or self.is_disabled(): + if not self.is_instantiable(self.get_config()) or self.is_disabled(): # if not a class or function or marked as `disabled`, skip parsing and return `None` return None @@ -324,19 +328,19 @@ def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict super().__init__(config=config, id=id) self.globals = globals - def execute(self, local_vars: Optional[Dict] = None): + def execute(self, locals: Optional[Dict] = None): """ Excute current config content and return the result if it is expression, based on python `eval()`. For more details: https://docs.python.org/3/library/functions.html#eval. Args: - local_vars: besides ``globals``, may also have some local variables used in the expression at runtime. + locals: besides ``globals``, may also have some local symbols used in the expression at runtime. """ value = self.get_config() if not ConfigExpression.is_expression(value): return None - return eval(value[1:], self.globals, local_vars) + return eval(value[1:], self.globals, locals) @staticmethod def is_expression(config: Union[Dict, List, str]) -> bool: diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 2650cc930d..108e5d7aa6 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -31,21 +31,23 @@ TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test `` TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] +# test `` +TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] # test non-monai modules and excludes -TEST_CASE_5 = [ +TEST_CASE_6 = [ {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] -TEST_CASE_6 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] +TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] # test args contains "name" field -TEST_CASE_7 = [ +TEST_CASE_8 = [ {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_8 = ["collate_fn", "$monai.data.list_data_collate"] +TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] # test lambda function, should not execute the lambda function, just change the string -TEST_CASE_9 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] +TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] class TestConfigItem(unittest.TestCase): @@ -58,7 +60,8 @@ def test_item(self, test_input, expected): self.assertEqual(item.get_config()["lr"], expected) @parameterized.expand( - [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) + [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + + ([TEST_CASE_8] if has_tv else []) ) def test_component(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) @@ -72,7 +75,7 @@ def test_component(self, test_input, output_type): if isinstance(ret, LoadImaged): self.assertEqual(ret.keys[0], "image") - @parameterized.expand([TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) def test_expression(self, id, test_input): configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) var = 100 From 926848657646a59dcf7231d93818bc5d09ca2202 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 20:30:25 +0800 Subject: [PATCH 25/25] [DLMED] update function name Signed-off-by: Nic Ma --- monai/apps/manifest/config_item.py | 2 +- tests/test_config_item.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py index b47e231b0e..1f0c06c8fc 100644 --- a/monai/apps/manifest/config_item.py +++ b/monai/apps/manifest/config_item.py @@ -328,7 +328,7 @@ def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict super().__init__(config=config, id=id) self.globals = globals - def execute(self, locals: Optional[Dict] = None): + def evaluate(self, locals: Optional[Dict] = None): """ Excute current config content and return the result if it is expression, based on python `eval()`. For more details: https://docs.python.org/3/library/functions.html#eval. diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 108e5d7aa6..b2c2fec6c6 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -79,7 +79,7 @@ def test_component(self, test_input, output_type): def test_expression(self, id, test_input): configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) var = 100 - ret = configer.execute(locals={"var": var}) + ret = configer.evaluate(locals={"var": var}) self.assertTrue(isinstance(ret, Callable)) def test_lazy_instantiation(self):