diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..f6c6ecb283 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,6 +29,16 @@ Clara MMARs :annotation: +Model Package +------------- + +.. autoclass:: ConfigParser + :members: + +.. autoclass:: ModuleScanner + :members: + + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 1df6d74f9d..241abac497 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,15 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar +from .mmars import ( + MODEL_DESC, + ConfigParser, + ModuleScanner, + RemoteMMARKeys, + download_mmar, + get_class, + get_model_spec, + instantiate_class, + load_from_mmar, +) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 396be2e87d..dd93c39d00 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .config_parser import ConfigParser, ModuleScanner from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys +from .utils import get_class, instantiate_class diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py new file mode 100644 index 0000000000..ef69dd837d --- /dev/null +++ b/monai/apps/mmars/config_parser.py @@ -0,0 +1,108 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import pkgutil +from typing import Dict, Sequence + +from monai.apps.mmars.utils import instantiate_class + + +class ModuleScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = importlib.import_module(pkg) + + for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): + if any(name in modname for name in self.modules): + try: + module = importlib.import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_module_name(self, class_name): + return self._class_table.get(class_name, None) + + +class ConfigParser: + """ + Parse dictionary format config and build components. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + + def build_component(self, config: Dict) -> object: + """ + Build component instance based on the provided dictonary config. + Supported keys for the config: + - 'name' - class name in the modules of packages. + - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. + - 'args' - arguments to initialize the component instance. + - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. + + Args: + config: dictionary config to define a component. + + Raises: + ValueError: must provide `path` or `name` of class to build component. + ValueError: can not find component class. + + """ + if not isinstance(config, dict): + raise ValueError("config of component must be a dictionary.") + + if config.get("disabled") is True: + # if marked as `disabled`, skip parsing + return None + + class_args = config.get("args", {}) + class_path = self._get_class_path(config) + return instantiate_class(class_path, **class_args) + + def _get_class_path(self, config): + class_path = config.get("path", None) + if class_path is None: + class_name = config.get("name", None) + if class_name is None: + raise ValueError("must provide `path` or `name` of class to build component.") + module_name = self.module_scanner.get_module_name(class_name) + if module_name is None: + raise ValueError(f"can not find component class '{class_name}'.") + class_path = f"{module_name}.{class_name}" + + return class_path diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py new file mode 100644 index 0000000000..f6c156c9bd --- /dev/null +++ b/monai/apps/mmars/utils.py @@ -0,0 +1,61 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def get_class(class_path: str): + """ + Get the class from specified class path. + + Args: + class_path (str): full path of the class. + + Raises: + ValueError: invalid class_path, missing the module name. + ValueError: class does not exist. + ValueError: module does not exist. + + """ + if len(class_path.split(".")) < 2: + raise ValueError(f"invalid class_path: {class_path}, missing the module name.") + module_name, class_name = class_path.rsplit(".", 1) + + try: + module_ = importlib.import_module(module_name) + + try: + class_ = getattr(module_, class_name) + except AttributeError as e: + raise ValueError(f"class {class_name} does not exist.") from e + + except AttributeError as e: + raise ValueError(f"module {module_name} does not exist.") from e + + return class_ + + +def instantiate_class(class_path: str, **kwargs): + """ + Method for creating an instance for the specified class. + + Args: + class_path: full path of the class. + kwargs: arguments to initialize the class instance. + + Raises: + ValueError: class has paramenters error. + """ + + try: + return get_class(class_path)(**kwargs) + except TypeError as e: + raise ValueError(f"class {class_path} has parameters error.") from e diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..c8d041d3b9 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps import ConfigParser +from monai.transforms import LoadImaged + +TEST_CASE_1 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test python `path` +TEST_CASE_2 = [ + dict(pkgs=[], modules=[]), + {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test `disabled` +TEST_CASE_3 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, + None, +] +# test non-monai modules +TEST_CASE_4 = [ + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] + + +class TestConfigParser(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_type(self, input_param, test_input, output_type): + configer = ConfigParser(**input_param) + result = configer.build_component(test_input) + if result is not None: + self.assertTrue(isinstance(result, output_type)) + if isinstance(result, LoadImaged): + self.assertEqual(result.keys[0], "image") + else: + # test `disabled` works fine + self.assertEqual(result, output_type) + + +if __name__ == "__main__": + unittest.main()