diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..6535ad82b7 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,6 +29,22 @@ Clara MMARs :annotation: +Model Manifest +-------------- + +.. autoclass:: ComponentLocator + :members: + +.. autoclass:: ConfigComponent + :members: + +.. autoclass:: ConfigExpression + :members: + +.. autoclass:: ConfigItem + :members: + + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 893f7877d2..df085bddea 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py new file mode 100644 index 0000000000..d8919a5249 --- /dev/null +++ b/monai/apps/manifest/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py new file mode 100644 index 0000000000..1f0c06c8fc --- /dev/null +++ b/monai/apps/manifest/config_item.py @@ -0,0 +1,355 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import sys +import warnings +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +from monai.utils import ensure_tuple, instantiate + +__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] + + +class Instantiable(ABC): + """ + Base class for instantiable object with module name and arguments. + + .. code-block:: python + + if not is_disabled(): + instantiate(module_name=resolve_module_name(), args=resolve_args()) + + """ + + @abstractmethod + def resolve_module_name(self, *args: Any, **kwargs: Any): + """ + Resolve the target module name, it should return an object class (or function) to be instantiated. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def resolve_args(self, *args: Any, **kwargs: Any): + """ + Resolve the arguments, it should return arguments to be passed to the object when instantiating. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def is_disabled(self, *args: Any, **kwargs: Any) -> bool: + """ + Return a boolean flag to indicate whether the object should be instantiated. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any): + """ + Instantiate the target component. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + +class ComponentLocator: + """ + Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. + It's used to locate the module path for provided component name. + + Args: + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + + """ + + MOD_START = "monai" + + def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + self.excludes = [] if excludes is None else ensure_tuple(excludes) + self._components_table: Optional[Dict[str, List]] = None + + def _find_module_names(self) -> List[str]: + """ + Find all the modules start with MOD_START and don't contain any of `excludes`. + + """ + return [ + m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) + ] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + """ + Find all the classes and functions in the modules with specified `modnames`. + + Args: + modnames: names of the target modules to find all the classes and functions. + + """ + table: Dict[str, List] = {} + # all the MONAI modules are already loaded by `load_submodules` + for modname in ensure_tuple(modnames): + try: + # scan all the classes and functions in the module + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass + return table + + def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: + """ + Get the full module name of the class or function with specified ``name``. + If target component name exists in multiple packages or modules, return a list of full module names. + + Args: + name: name of the expected class or function. + + """ + if not isinstance(name, str): + raise ValueError(f"`name` must be a valid string, but got: {name}.") + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods + + +class ConfigItem: + """ + Basic data structure to represent a configuration item. + + A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. + It has a build-in `config` property to store the configuration object. + + Args: + config: content of a config item, can be objects of any types, + a configuration resolver may interpret the content to generate a configuration object. + id: optional name of the current config item, defaults to `None`. + + """ + + def __init__(self, config: Any, id: Optional[str] = None) -> None: + self.config = config + self.id = id + + def get_id(self) -> Optional[str]: + """ + Get the ID name of current config item, useful to identify config items during parsing. + + """ + return self.id + + def update_config(self, config: Any): + """ + Replace the content of `self.config` with new `config`. + A typical usage is to modify the initial config content at runtime. + + Args: + config: content of a `ConfigItem`. + + """ + self.config = config + + def get_config(self): + """ + Get the config content of current config item. + + """ + return self.config + + +class ConfigComponent(ConfigItem, Instantiable): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to + represent a component of `class` or `function` and supports instantiation. + + Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: + + - class or function identifier of the python module, specified by one of the two keys. + - ``""``: indicates build-in python classes or functions such as "LoadImageDict". + - ``""``: full module name, such as "monai.transforms.LoadImageDict". + - ``""``: input arguments to the python module. + - ``""``: a flag to indicate whether to skip the instantiation. + + .. code-block:: python + + locator = ComponentLocator(excludes=["modules_to_exclude"]) + config = { + "": "LoadImaged", + "": { + "keys": ["image", "label"] + } + } + + configer = ConfigComponent(config, id="test", locator=locator) + image_loader = configer.instantiate() + print(image_loader) # + + Args: + config: content of a config item. + id: optional name of the current config item, defaults to `None`. + locator: a ``ComponentLocator`` to convert a module name string into the actual python module. + if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. + excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. + See also: :py:class:`monai.apps.manifest.ComponentLocator`. + + """ + + def __init__( + self, + config: Any, + id: Optional[str] = None, + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + ) -> None: + super().__init__(config=config, id=id) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + + @staticmethod + def is_instantiable(config: Any) -> bool: + """ + Check whether this config represents a `class` or `function` that is to be instantiated. + + Args: + config: input config content to check. + + """ + return isinstance(config, Mapping) and ("" in config or "" in config) + + def resolve_module_name(self): + """ + Resolve the target module name from current config content. + The config content must have ``""`` or ``""``. + When both are specified, ``""`` will be used. + + """ + config = dict(self.get_config()) + path = config.get("") + if path is not None: + if not isinstance(path, str): + raise ValueError(f"'' must be a string, but got: {path}.") + if "" in config: + warnings.warn(f"both '' and '', default to use '': {path}.") + return path + + name = config.get("") + if not isinstance(name, str): + raise ValueError("must provide a string for `` or `` of target component to instantiate.") + + module = self.locator.get_component_module_name(name) + if module is None: + raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set its module path in `` directly." + ) + module = module[0] + return f"{module}.{name}" + + def resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments from current config content. + + """ + return self.get_config().get("", {}) + + def is_disabled(self) -> bool: # type: ignore + """ + Utility function used in `instantiate()` to check whether to skip the instantiation. + + """ + _is_disabled = self.get_config().get("", False) + return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) + + def instantiate(self, **kwargs) -> object: # type: ignore + """ + Instantiate component based on ``self.config`` content. + The target component must be a `class` or a `function`, otherwise, return `None`. + + Args: + kwargs: args to override / add the config args when instantiation. + + """ + if not self.is_instantiable(self.get_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` + return None + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) + + +class ConfigExpression(ConfigItem): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression + (execute based on ``eval()``). + + See also: + + - https://docs.python.org/3/library/functions.html#eval. + + For example: + + .. code-block:: python + + import monai + from monai.apps.manifest import ConfigExpression + + config = "$monai.__version__" + expression = ConfigExpression(config, id="test", globals={"monai": monai}) + print(expression.execute()) + + Args: + config: content of a config item. + id: optional name of current config item, defaults to `None`. + globals: additional global context to evaluate the string. + + """ + + def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + super().__init__(config=config, id=id) + self.globals = globals + + def evaluate(self, locals: Optional[Dict] = None): + """ + Excute current config content and return the result if it is expression, based on python `eval()`. + For more details: https://docs.python.org/3/library/functions.html#eval. + + Args: + locals: besides ``globals``, may also have some local symbols used in the expression at runtime. + + """ + value = self.get_config() + if not ConfigExpression.is_expression(value): + return None + return eval(value[1:], self.globals, locals) + + @staticmethod + def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an executable expression string. + Currently A string starts with ``"$"`` character is interpreted as an expression. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$") diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index edd3fef887..636ea15c8d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -70,6 +70,7 @@ get_full_type_name, get_package_version, get_torch_version_tuple, + instantiate, load_submodules, look_up_option, min_version, diff --git a/monai/utils/module.py b/monai/utils/module.py index b21828fbbb..8b7745c3ee 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,13 +10,15 @@ # limitations under the License. import enum +import inspect import os import re import sys import warnings -from functools import wraps +from functools import partial, wraps from importlib import import_module from pkgutil import walk_packages +from pydoc import locate from re import match from types import FunctionType from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast @@ -36,6 +38,7 @@ "optional_import", "require_pkg", "load_submodules", + "instantiate", "get_full_type_name", "get_package_version", "get_torch_version_tuple", @@ -193,7 +196,36 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod +def instantiate(path: str, **kwargs): + """ + Create an object instance or partial function from a class or function represented by string. + `kwargs` will be part of the input arguments to the class constructor or function. + The target component must be a class or a function, if not, return the component directly. + + Args: + path: full path of the target class or function component. + kwargs: arguments to initialize the class instance or set default args + for `partial` function. + + """ + + component = locate(path) + if component is None: + raise ModuleNotFoundError(f"Cannot locate '{path}'.") + if inspect.isclass(component): + return component(**kwargs) + if inspect.isfunction(component): + return partial(component, **kwargs) + + warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") + return component + + def get_full_type_name(typeobj): + """ + Utility to get the full path name of a class or object type. + + """ module = typeobj.__module__ if module is None or module == str.__class__.__module__: return typeobj.__name__ # Avoid reporting __builtin__ diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py new file mode 100644 index 0000000000..eafb2152d1 --- /dev/null +++ b/tests/test_component_locator.py @@ -0,0 +1,35 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from pydoc import locate + +from monai.apps.manifest import ComponentLocator +from monai.utils import optional_import + +_, has_ignite = optional_import("ignite") + + +class TestComponentLocator(unittest.TestCase): + def test_locate(self): + locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) + # test init mapping table and get the module path of component + self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") + self.assertGreater(len(locator._components_table), 0) + for _, mods in locator._components_table.items(): + for i in mods: + self.assertGreater(len(mods), 0) + # ensure we can locate all the items by `name` + self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config_item.py b/tests/test_config_item.py new file mode 100644 index 0000000000..b2c2fec6c6 --- /dev/null +++ b/tests/test_config_item.py @@ -0,0 +1,99 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import Callable + +import torch +from parameterized import parameterized + +import monai +from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from monai.data import DataLoader, Dataset +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import optional_import + +_, has_tv = optional_import("torchvision") + +TEST_CASE_1 = [{"lr": 0.001}, 0.0001] + +TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +# test python `` +TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +# test `` +TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] +# test `` +TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] +# test non-monai modules and excludes +TEST_CASE_6 = [ + {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] +# test args contains "name" field +TEST_CASE_8 = [ + {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + RandTorchVisiond, +] +# test execute some function in args, test pre-imported global packages `monai` +TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] +# test lambda function, should not execute the lambda function, just change the string +TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] + + +class TestConfigItem(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_item(self, test_input, expected): + item = ConfigItem(config=test_input) + conf = item.get_config() + conf["lr"] = 0.0001 + item.update_config(config=conf) + self.assertEqual(item.get_config()["lr"], expected) + + @parameterized.expand( + [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + + ([TEST_CASE_8] if has_tv else []) + ) + def test_component(self, test_input, output_type): + locator = ComponentLocator(excludes=["metrics"]) + configer = ConfigComponent(id="test", config=test_input, locator=locator) + ret = configer.instantiate() + if test_input.get("", False): + # test `` works fine + self.assertEqual(ret, None) + return + self.assertTrue(isinstance(ret, output_type)) + if isinstance(ret, LoadImaged): + self.assertEqual(ret.keys[0], "image") + + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) + def test_expression(self, id, test_input): + configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) + var = 100 + ret = configer.evaluate(locals={"var": var}) + self.assertTrue(isinstance(ret, Callable)) + + def test_lazy_instantiation(self): + config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} + configer = ConfigComponent(config=config, locator=None) + init_config = configer.get_config() + # modify config content at runtime + init_config[""]["batch_size"] = 4 + configer.update_config(config=init_config) + + ret = configer.instantiate() + self.assertTrue(isinstance(ret, DataLoader)) + self.assertEqual(ret.batch_size, 4) + + +if __name__ == "__main__": + unittest.main()