From 3c9d2d1602888d8d7ecc3e273628239f9d391f7f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Jan 2022 16:31:31 +0800 Subject: [PATCH 01/30] 3482 Build component instance from dictionary config (#3518) * [DLMED] add ConfigParser Signed-off-by: Nic Ma * [DLMED] add more doc-string Signed-off-by: Nic Ma * [DLMED] add unit tests Signed-off-by: Nic Ma * [DLMED] fix CI error Signed-off-by: Nic Ma * [DLMED] fix test error Signed-off-by: Nic Ma * [DLMED] skip for windows Signed-off-by: Nic Ma * [DLMED] fix windows test Signed-off-by: Nic Ma --- docs/source/apps.rst | 10 +++ monai/apps/__init__.py | 12 +++- monai/apps/mmars/__init__.py | 2 + monai/apps/mmars/config_parser.py | 108 ++++++++++++++++++++++++++++++ monai/apps/mmars/utils.py | 61 +++++++++++++++++ tests/test_config_parser.py | 60 +++++++++++++++++ 6 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 monai/apps/mmars/config_parser.py create mode 100644 monai/apps/mmars/utils.py create mode 100644 tests/test_config_parser.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f4f7aff2d2..f6c6ecb283 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,6 +29,16 @@ Clara MMARs :annotation: +Model Package +------------- + +.. autoclass:: ConfigParser + :members: + +.. autoclass:: ModuleScanner + :members: + + `Utilities` ----------- diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 893f7877d2..f021588a9f 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,15 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar +from .mmars import ( + MODEL_DESC, + ConfigParser, + ModuleScanner, + RemoteMMARKeys, + download_mmar, + get_class, + get_model_spec, + instantiate_class, + load_from_mmar, +) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 8f1448bb06..d6d70d1d70 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .config_parser import ConfigParser, ModuleScanner from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys +from .utils import get_class, instantiate_class diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py new file mode 100644 index 0000000000..ef69dd837d --- /dev/null +++ b/monai/apps/mmars/config_parser.py @@ -0,0 +1,108 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
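# A minimal sketch, assuming monai is installed, of the class-name lookup that the ModuleScanner
# defined below provides; the module name it returns is what _get_class_path later joins with the
# class name. The exact module path shown is only an example.
from monai.apps import ModuleScanner

scanner = ModuleScanner(pkgs=["monai"], modules=["transforms"])
module_name = scanner.get_module_name("LoadImaged")  # e.g. "monai.transforms.io.dictionary"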
+ +import importlib +import inspect +import pkgutil +from typing import Dict, Sequence + +from monai.apps.mmars.utils import instantiate_class + + +class ModuleScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = importlib.import_module(pkg) + + for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): + if any(name in modname for name in self.modules): + try: + module = importlib.import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_module_name(self, class_name): + return self._class_table.get(class_name, None) + + +class ConfigParser: + """ + Parse dictionary format config and build components. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + + def build_component(self, config: Dict) -> object: + """ + Build component instance based on the provided dictonary config. + Supported keys for the config: + - 'name' - class name in the modules of packages. + - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. + - 'args' - arguments to initialize the component instance. + - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. + + Args: + config: dictionary config to define a component. + + Raises: + ValueError: must provide `path` or `name` of class to build component. + ValueError: can not find component class. + + """ + if not isinstance(config, dict): + raise ValueError("config of component must be a dictionary.") + + if config.get("disabled") is True: + # if marked as `disabled`, skip parsing + return None + + class_args = config.get("args", {}) + class_path = self._get_class_path(config) + return instantiate_class(class_path, **class_args) + + def _get_class_path(self, config): + class_path = config.get("path", None) + if class_path is None: + class_name = config.get("name", None) + if class_name is None: + raise ValueError("must provide `path` or `name` of class to build component.") + module_name = self.module_scanner.get_module_name(class_name) + if module_name is None: + raise ValueError(f"can not find component class '{class_name}'.") + class_path = f"{module_name}.{class_name}" + + return class_path diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py new file mode 100644 index 0000000000..f6c156c9bd --- /dev/null +++ b/monai/apps/mmars/utils.py @@ -0,0 +1,61 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def get_class(class_path: str): + """ + Get the class from specified class path. + + Args: + class_path (str): full path of the class. + + Raises: + ValueError: invalid class_path, missing the module name. + ValueError: class does not exist. + ValueError: module does not exist. + + """ + if len(class_path.split(".")) < 2: + raise ValueError(f"invalid class_path: {class_path}, missing the module name.") + module_name, class_name = class_path.rsplit(".", 1) + + try: + module_ = importlib.import_module(module_name) + + try: + class_ = getattr(module_, class_name) + except AttributeError as e: + raise ValueError(f"class {class_name} does not exist.") from e + + except AttributeError as e: + raise ValueError(f"module {module_name} does not exist.") from e + + return class_ + + +def instantiate_class(class_path: str, **kwargs): + """ + Method for creating an instance for the specified class. + + Args: + class_path: full path of the class. + kwargs: arguments to initialize the class instance. + + Raises: + ValueError: class has paramenters error. + """ + + try: + return get_class(class_path)(**kwargs) + except TypeError as e: + raise ValueError(f"class {class_path} has parameters error.") from e diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..c8d041d3b9 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
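# A minimal usage sketch of the ConfigParser added above, assuming monai is installed so that
# "LoadImaged" is discoverable under monai.transforms; the test cases below exercise the same
# call patterns.
from monai.apps import ConfigParser

parser = ConfigParser(pkgs=["monai"], modules=["transforms"])
# "name" is resolved to its defining module via ModuleScanner, then instantiated with "args"
loader = parser.build_component({"name": "LoadImaged", "args": {"keys": ["image"]}})
# "path" bypasses the scanner and gives the full class path directly
loader = parser.build_component({"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}})
# "disabled": True skips building and returns None
skipped = parser.build_component({"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}})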
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.apps import ConfigParser +from monai.transforms import LoadImaged + +TEST_CASE_1 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test python `path` +TEST_CASE_2 = [ + dict(pkgs=[], modules=[]), + {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, + LoadImaged, +] +# test `disabled` +TEST_CASE_3 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, + None, +] +# test non-monai modules +TEST_CASE_4 = [ + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] + + +class TestConfigParser(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_type(self, input_param, test_input, output_type): + configer = ConfigParser(**input_param) + result = configer.build_component(test_input) + if result is not None: + self.assertTrue(isinstance(result, output_type)) + if isinstance(result, LoadImaged): + self.assertEqual(result.keys[0], "image") + else: + # test `disabled` works fine + self.assertEqual(result, output_type) + + +if __name__ == "__main__": + unittest.main() From 85f679772e0aa57d126d62ab69d411c18094954e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 19 Jan 2022 19:56:45 +0800 Subject: [PATCH 02/30] [DLMED] add ConfigResolver ConfigComponent Signed-off-by: Nic Ma --- monai/apps/mmars/__init__.py | 5 +- monai/apps/mmars/config_parser.py | 130 +++++++++--------------- monai/apps/mmars/config_resolver.py | 151 ++++++++++++++++++++++++++++ monai/apps/mmars/utils.py | 49 +++++++++ tests/test_config_parser.py | 3 +- 5 files changed, 250 insertions(+), 88 deletions(-) create mode 100644 monai/apps/mmars/config_resolver.py diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index d6d70d1d70..f357635598 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_parser import ConfigParser, ModuleScanner +from .config_parser import ConfigParser +from .config_resolver import ModuleScanner, ConfigComponent, ConfigResolver from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import get_class, instantiate_class +from .utils import get_class, instantiate_class, search_configs_with_objs, update_configs_with_objs diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index ef69dd837d..b0cc5f30bf 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -9,48 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib -import inspect -import pkgutil -from typing import Dict, Sequence - -from monai.apps.mmars.utils import instantiate_class - - -class ModuleScanner: - """ - Scan all the available classes in the specified packages and modules. - Map the all the class names and the module names in a table. - - Args: - pkgs: the expected packages to scan. - modules: the expected modules in the packages to scan. 
- - """ - - def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): - self.pkgs = pkgs - self.modules = modules - self._class_table = self._create_classes_table() - - def _create_classes_table(self): - class_table = {} - for pkg in self.pkgs: - package = importlib.import_module(pkg) - - for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): - if any(name in modname for name in self.modules): - try: - module = importlib.import_module(modname) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and obj.__module__ == modname: - class_table[name] = modname - except ModuleNotFoundError: - pass - return class_table - - def get_module_name(self, class_name): - return self._class_table.get(class_name, None) +from typing import Any, Dict, Optional, Sequence +from config_resolver import ConfigComponent, ConfigResolver class ConfigParser: @@ -63,46 +23,46 @@ class ConfigParser: """ - def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): - self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) - - def build_component(self, config: Dict) -> object: - """ - Build component instance based on the provided dictonary config. - Supported keys for the config: - - 'name' - class name in the modules of packages. - - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. - - 'args' - arguments to initialize the component instance. - - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. - - Args: - config: dictionary config to define a component. - - Raises: - ValueError: must provide `path` or `name` of class to build component. - ValueError: can not find component class. - - """ - if not isinstance(config, dict): - raise ValueError("config of component must be a dictionary.") - - if config.get("disabled") is True: - # if marked as `disabled`, skip parsing - return None - - class_args = config.get("args", {}) - class_path = self._get_class_path(config) - return instantiate_class(class_path, **class_args) - - def _get_class_path(self, config): - class_path = config.get("path", None) - if class_path is None: - class_name = config.get("name", None) - if class_name is None: - raise ValueError("must provide `path` or `name` of class to build component.") - module_name = self.module_scanner.get_module_name(class_name) - if module_name is None: - raise ValueError(f"can not find component class '{class_name}'.") - class_path = f"{module_name}.{class_name}" - - return class_path + def __init__(self, pkgs: Sequence[str], modules: Sequence[str], config: Optional[Dict] = None): + self.pkgs = pkgs + self.modules = modules + self.config = {} + if isinstance(config, dict): + self.set_config(config=config) + self.config_resolver: Optional[ConfigResolver] = None + self.resolved = False + + def set_config(self, config: Any, path: Optional[str] = None): + if isinstance(path, str): + keys = path.split(".") + config = self.config + for k in keys[:-1]: + config = config[k] + config[keys[-1]] = config + else: + self.config = config + self.resolved = False + + def get_config(self, config: Dict, path: Optional[str] = None): + if isinstance(path, str): + keys = path.split(".") + config = self.config + for k in keys[:-1]: + config = config[k] + return config[keys[-1]] + return self.config + + def resolve_config(self, resolve_all: bool = False): + self.config_resolver = ConfigResolver() + for k, v in self.config.items(): + # only prepare the components, lazy 
instantiation + # FIXME: only support "@" reference in top level config for now + self.config_resolver.update(ConfigComponent(id=k, config=v, pkgs=self.pkgs, modules=self.modules)) + if resolve_all: + self.config_resolver.resolve_all() + self.resolved = True + + def get_instance(self, id: str): + if self.config_resolver is None or not self.resolved: + self.resolve_config() + return self.config_resolver.resolve_one_object(id=id) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py new file mode 100644 index 0000000000..9595989f5c --- /dev/null +++ b/monai/apps/mmars/config_resolver.py @@ -0,0 +1,151 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Sequence +import importlib +import inspect +import pkgutil + +from monai.apps.mmars.utils import instantiate_class, search_configs_with_objs, update_configs_with_objs + + +class ModuleScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan. + modules: the expected modules in the packages to scan. + + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = importlib.import_module(pkg) + + for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): + if any(name in modname for name in self.modules): + try: + module = importlib.import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_module_name(self, class_name): + return self._class_table.get(class_name, None) + + +class ConfigComponent: + def __init__(self, id: str, config: Dict, pkgs: Sequence[str], modules: Sequence[str]) -> None: + self.id = id + self.config = config + self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + + def get_id(self) -> str: + return self.id + + def get_referenced_ids(self) -> List[str]: + return search_configs_with_objs(self.config, []) + + def get_instance(self, refs: dict): + config = update_configs_with_objs(self.config, refs) + return self.build(config) if isinstance(config, dict) and ("name" in config or "path" in config) else config + + def build(self, config: Dict) -> object: + """ + Build component instance based on the provided dictonary config. + Supported keys for the config: + - 'name' - class name in the modules of packages. + - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. + - 'args' - arguments to initialize the component instance. + - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. 
+ + Args: + config: dictionary config to define a component. + + Raises: + ValueError: must provide `path` or `name` of class to build component. + ValueError: can not find component class. + + """ + if not isinstance(config, dict): + raise ValueError("config of component must be a dictionary.") + + if config.get("disabled") is True: + # if marked as `disabled`, skip parsing + return None + + class_args = config.get("args", {}) + class_path = self._get_class_path(config) + return instantiate_class(class_path, **class_args) + + def _get_class_path(self, config): + class_path = config.get("path", None) + if class_path is None: + class_name = config.get("name", None) + if class_name is None: + raise ValueError("must provide `path` or `name` of class to build component.") + module_name = self.module_scanner.get_module_name(class_name) + if module_name is None: + raise ValueError(f"can not find component class '{class_name}'.") + class_path = f"{module_name}.{class_name}" + + return class_path + + +class ConfigResolver: + def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): + self.resolved = {} + self.components = {} if components is None else components + + def update(self, component: ConfigComponent): + self.components[component.get_id()] = component + + def resolve_one_object(self, id: str) -> bool: + obj = self.components[id] + # check whether the obj has any unresolved refs in its args + ref_ids = obj.get_referenced_ids() + if not ref_ids: + # this object does not reference others + resolved_obj = obj.get_instance([]) + else: + # see whether all refs are resolved + refs = {} + for comp_id in ref_ids: + if comp_id not in self.resolved: + # this referenced object is not resolved + if comp_id not in self.components: + raise RuntimeError(f"the reference component `{comp_id}` is not in config.") + # resolve the dependency first + self.resolve_one_object(id=comp_id) + refs[comp_id] = self.resolved[comp_id] + # all referenced objects are resolved already + resolved_obj = obj.get_instance(refs) + + self.resolved[id] = resolved_obj + return resolved_obj + + def resolve_all(self): + for v in self.components.values(): + self.resolve_one_object(obj=v) + + def get_resolved(self, id: str): + return self.resolved[id] diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index f6c156c9bd..e768215fde 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -10,6 +10,8 @@ # limitations under the License. 
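# A sketch of how the "@" reference mechanism added in this commit is meant to be used, assuming
# monai is installed and provides Dataset/DataLoader under monai.data; "@dataset" in the
# dataloader args is replaced by the already-built "dataset" component before the DataLoader is
# instantiated. The config values here are made up for illustration.
from monai.apps import ConfigParser

parser = ConfigParser(
    pkgs=["monai"],
    modules=["data"],
    config={
        "dataset": {"name": "Dataset", "args": {"data": [{"image": "img.nii"}]}},
        "dataloader": {"name": "DataLoader", "args": {"dataset": "@dataset", "batch_size": 2}},
    },
)
dataloader = parser.get_instance("dataloader")  # resolves "dataset" first, then builds the loader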
import importlib +import re +from typing import List, Union def get_class(class_path: str): @@ -59,3 +61,50 @@ def instantiate_class(class_path: str, **kwargs): return get_class(class_path)(**kwargs) except TypeError as e: raise ValueError(f"class {class_path} has parameters error.") from e + + +def search_configs_with_objs(configs: Union[dict, list, str], refs: List[str]): + pattern = re.compile(r'@\w*') + if isinstance(configs, list): + for i in configs: + refs = search_configs_with_objs(i, refs) + elif isinstance(configs, dict): + for _, v in configs.items(): + refs = search_configs_with_objs(v, refs) + elif isinstance(configs, str): + result = pattern.findall(configs) + for item in result: + # only parse `@` for: `@object`, `lambda ...`, `#lambda ...` + if configs.startswith("#") or configs.startswith("lambda") or configs == item: + ref_obj_id = item[1:] + if ref_obj_id not in refs: + refs.append(ref_obj_id) + return refs + + +def update_configs_with_objs(configs: Union[dict, list, str], refs: dict): + pattern = re.compile(r'@\w*') + if isinstance(configs, list): + configs = [update_configs_with_objs(i, refs) for i in configs] + elif isinstance(configs, dict): + configs = {k: update_configs_with_objs(v, refs) for k, v in configs.items()} + elif isinstance(configs, str): + result = pattern.findall(configs) + for item in result: + ref_obj_id = item[1:] + # only parse `@` for: `@object`, `lambda ...`, `#lambda ...` + if configs.startswith("lambda") or configs.startswith("#lambda"): + # if using @object in a lambda function, only support to convert the item to f-string + configs = configs.replace(item, f"{refs[ref_obj_id]}") + elif configs.startswith("#"): + # replace with local code and execute soon + configs = configs.replace(item, f"refs['{ref_obj_id}']") + elif configs == item: + configs = refs[ref_obj_id] + + if isinstance(configs, str): + if configs.startswith("#"): + configs = eval(configs[1:]) + elif configs.startswith("lambda"): + configs = eval(configs) + return configs diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index c8d041d3b9..6b077f5a76 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -46,7 +46,8 @@ class TestConfigParser(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_type(self, input_param, test_input, output_type): configer = ConfigParser(**input_param) - result = configer.build_component(test_input) + configer.set_config({"test": test_input}) + result = configer.get_instance("test") if result is not None: self.assertTrue(isinstance(result, output_type)) if isinstance(result, LoadImaged): From ba629b3aae5f8ec4d526415a4db84e5f57952fd0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 19 Jan 2022 20:02:03 +0800 Subject: [PATCH 03/30] [DLMED] add to doc Signed-off-by: Nic Ma --- docs/source/apps.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f6c6ecb283..783709a515 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -38,6 +38,12 @@ Model Package .. autoclass:: ModuleScanner :members: +.. autoclass:: ConfigComponent + :members: + +.. 
autoclass:: ConfigResolver + :members: + `Utilities` ----------- From da2a4bc25784839faf06308ec5a3941c53666680 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 11:23:07 +0800 Subject: [PATCH 04/30] [DLMED] add details unit tests for ConfigComponent Signed-off-by: Nic Ma --- monai/apps/__init__.py | 4 + monai/apps/mmars/config_parser.py | 75 +++++++++----- monai/apps/mmars/config_resolver.py | 50 +++++----- monai/apps/mmars/utils.py | 74 +++++++------- tests/test_config_component.py | 148 ++++++++++++++++++++++++++++ tests/test_config_parser.py | 61 ------------ 6 files changed, 269 insertions(+), 143 deletions(-) create mode 100644 tests/test_config_component.py delete mode 100644 tests/test_config_parser.py diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index f021588a9f..a38e0c77c4 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -12,7 +12,9 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import ( MODEL_DESC, + ConfigComponent, ConfigParser, + ConfigResolver, ModuleScanner, RemoteMMARKeys, download_mmar, @@ -20,5 +22,7 @@ get_model_spec, instantiate_class, load_from_mmar, + search_configs_with_objs, + update_configs_with_objs, ) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index b0cc5f30bf..b935d44811 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence -from config_resolver import ConfigComponent, ConfigResolver +import importlib +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver, ModuleScanner class ConfigParser: @@ -18,46 +19,72 @@ class ConfigParser: Parse dictionary format config and build components. Args: - pkgs: the expected packages to scan. - modules: the expected modules in the packages to scan. + pkgs: the expected packages to scan modules and parse class names in the config. + modules: the expected modules in the packages to scan for all the classes. + for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + global_imports: pre-import packages as global variables to execute `eval` commands. + for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. + config: config content to parse. 
""" - def __init__(self, pkgs: Sequence[str], modules: Sequence[str], config: Optional[Dict] = None): - self.pkgs = pkgs - self.modules = modules - self.config = {} - if isinstance(config, dict): + def __init__( + self, + pkgs: Sequence[str], + modules: Sequence[str], + global_imports: Optional[Sequence[str]] = None, + config: Optional[Any] = None, + ): + self.config = None + if config is not None: self.set_config(config=config) + self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + self.global_imports = {} + if global_imports is not None: + for i in global_imports: + self.global_imports[i] = importlib.import_module(i) self.config_resolver: Optional[ConfigResolver] = None self.resolved = False + def _get_last_config_and_key(config: Union[Dict, List], path: str) -> Tuple[Union[Dict, List], str]: + keys = path.split("#") + for k in keys[:-1]: + config = config[k] if isinstance(config, dict) else config[int(k)] + key = keys[-1] if isinstance(config, dict) else int(keys[-1]) + return config, key + def set_config(self, config: Any, path: Optional[str] = None): if isinstance(path, str): - keys = path.split(".") - config = self.config - for k in keys[:-1]: - config = config[k] - config[keys[-1]] = config + conf_, key = self._get_last_config_and_key(self.config, path) + conf_[key] = config else: self.config = config self.resolved = False - def get_config(self, config: Dict, path: Optional[str] = None): + def get_config(self, path: Optional[str] = None): if isinstance(path, str): - keys = path.split(".") - config = self.config - for k in keys[:-1]: - config = config[k] - return config[keys[-1]] + conf_, key = self._get_last_config_and_key(self.config, path) + return conf_[key] return self.config + def _do_scan(self, config, path: Optional[str] = None): + if isinstance(config, dict): + for k, v in config.items(): + sub_path = k if path is None else f"{path}#{k}" + self._do_scan(config=v, path=sub_path) + if isinstance(config, list): + for i, v in enumerate(config): + sub_path = i if path is None else f"{path}#{i}" + self._do_scan(config=v, path=sub_path) + if path is not None: + self.config_resolver.update( + ConfigComponent(id=path, config=config, module_scanner=self.module_scanner, globals=self.global_imports) + ) + def resolve_config(self, resolve_all: bool = False): self.config_resolver = ConfigResolver() - for k, v in self.config.items(): - # only prepare the components, lazy instantiation - # FIXME: only support "@" reference in top level config for now - self.config_resolver.update(ConfigComponent(id=k, config=v, pkgs=self.pkgs, modules=self.modules)) + self._do_scan(config=self.config) + if resolve_all: self.config_resolver.resolve_all() self.resolved = True diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 9595989f5c..fa92a06dc3 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence import importlib import inspect import pkgutil @@ -39,7 +39,8 @@ def _create_classes_table(self): package = importlib.import_module(pkg) for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): - if any(name in modname for name in self.modules): + # if no modules specified, load all modules in the package + if len(self.modules) == 0 or any(name in modname for name in self.modules): try: module = importlib.import_module(modname) for name, obj in inspect.getmembers(module): @@ -49,34 +50,38 @@ def _create_classes_table(self): pass return class_table - def get_module_name(self, class_name): + def get_class_module_name(self, class_name): return self._class_table.get(class_name, None) class ConfigComponent: - def __init__(self, id: str, config: Dict, pkgs: Sequence[str], modules: Sequence[str]) -> None: + def __init__(self, id: str, config: Any, module_scanner: ModuleScanner, globals: Optional[Dict] = None) -> None: self.id = id self.config = config - self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + self.module_scanner = module_scanner + self.globals = globals def get_id(self) -> str: return self.id + def get_config(self): + return self.config + def get_referenced_ids(self) -> List[str]: - return search_configs_with_objs(self.config, []) + return search_configs_with_objs(self.config, [], id=self.id) def get_instance(self, refs: dict): - config = update_configs_with_objs(self.config, refs) - return self.build(config) if isinstance(config, dict) and ("name" in config or "path" in config) else config + config = update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) + return self.build(config) if isinstance(config, dict) and ("" in config or "" in config) else config - def build(self, config: Dict) -> object: + def build(self, config: Optional[Dict] = None) -> object: """ Build component instance based on the provided dictonary config. - Supported keys for the config: - - 'name' - class name in the modules of packages. - - 'path' - directly specify the class path, based on PYTHONPATH, ignore 'name' if specified. - - 'args' - arguments to initialize the component instance. - - 'disabled' - if defined `'disabled': true`, will skip the buiding, useful for development or tuning. + Supported special keys for the config: + - '' - class name in the modules of packages. + - '' - directly specify the class path, based on PYTHONPATH, ignore '' if specified. + - '' - arguments to initialize the component instance. + - '' - if defined `'': true`, will skip the buiding, useful for development or tuning. Args: config: dictionary config to define a component. @@ -86,24 +91,25 @@ def build(self, config: Dict) -> object: ValueError: can not find component class. 
""" + config = self.config if config is None else config if not isinstance(config, dict): - raise ValueError("config of component must be a dictionary.") + raise ValueError("only dictionary config can be built as component instance.") - if config.get("disabled") is True: + if config.get("") is True: # if marked as `disabled`, skip parsing - return None + return config - class_args = config.get("args", {}) + class_args = config.get("", {}) class_path = self._get_class_path(config) return instantiate_class(class_path, **class_args) def _get_class_path(self, config): - class_path = config.get("path", None) + class_path = config.get("", None) if class_path is None: - class_name = config.get("name", None) + class_name = config.get("", None) if class_name is None: - raise ValueError("must provide `path` or `name` of class to build component.") - module_name = self.module_scanner.get_module_name(class_name) + raise ValueError("must provide `` or `` of class to build component.") + module_name = self.module_scanner.get_class_module_name(class_name) if module_name is None: raise ValueError(f"can not find component class '{class_name}'.") class_path = f"{module_name}.{class_name}" diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index e768215fde..8336b4460b 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -10,8 +10,9 @@ # limitations under the License. import importlib +from optparse import Option import re -from typing import List, Union +from typing import Dict, List, Optional, Union def get_class(class_path: str): @@ -63,48 +64,49 @@ def instantiate_class(class_path: str, **kwargs): raise ValueError(f"class {class_path} has parameters error.") from e -def search_configs_with_objs(configs: Union[dict, list, str], refs: List[str]): - pattern = re.compile(r'@\w*') - if isinstance(configs, list): - for i in configs: - refs = search_configs_with_objs(i, refs) - elif isinstance(configs, dict): - for _, v in configs.items(): - refs = search_configs_with_objs(v, refs) - elif isinstance(configs, str): - result = pattern.findall(configs) +def search_configs_with_objs(config: Union[Dict, List, str], refs: List[str], id: str): + pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" + if isinstance(config, list): + for i, v in enumerate(config): + sub_id = f"{id}#{i}" + # all the items in the list should be marked as dependent reference + refs.append(sub_id) + refs = search_configs_with_objs(v, refs, sub_id) + if isinstance(config, dict): + for k, v in config.items(): + sub_id = f"{id}#{k}" + # all the items in the dict should be marked as dependent reference + refs.append(sub_id) + refs = search_configs_with_objs(v, refs, sub_id) + if isinstance(config, str): + result = pattern.findall(config) for item in result: - # only parse `@` for: `@object`, `lambda ...`, `#lambda ...` - if configs.startswith("#") or configs.startswith("lambda") or configs == item: + if config.startswith("$") or config == item: ref_obj_id = item[1:] if ref_obj_id not in refs: refs.append(ref_obj_id) return refs -def update_configs_with_objs(configs: Union[dict, list, str], refs: dict): - pattern = re.compile(r'@\w*') - if isinstance(configs, list): - configs = [update_configs_with_objs(i, refs) for i in configs] - elif isinstance(configs, dict): - configs = {k: update_configs_with_objs(v, refs) for k, v in configs.items()} - elif isinstance(configs, str): - result = pattern.findall(configs) +def update_configs_with_objs(config: Union[Dict, List, str], refs: dict, id: str, 
globals: Optional[Dict] = None): + pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" + if isinstance(config, list): + # all the items in the list should be replaced with the reference + config = [refs[f"{id}#{i}"] for i in range(len(config))] + if isinstance(config, dict): + # all the items in the dict should be replaced with the reference + config = {k: refs[f"{id}#{k}"] for k, _ in config.items()} + if isinstance(config, str): + result = pattern.findall(config) for item in result: ref_obj_id = item[1:] - # only parse `@` for: `@object`, `lambda ...`, `#lambda ...` - if configs.startswith("lambda") or configs.startswith("#lambda"): - # if using @object in a lambda function, only support to convert the item to f-string - configs = configs.replace(item, f"{refs[ref_obj_id]}") - elif configs.startswith("#"): + if config.startswith("$"): # replace with local code and execute soon - configs = configs.replace(item, f"refs['{ref_obj_id}']") - elif configs == item: - configs = refs[ref_obj_id] - - if isinstance(configs, str): - if configs.startswith("#"): - configs = eval(configs[1:]) - elif configs.startswith("lambda"): - configs = eval(configs) - return configs + config = config.replace(item, f"refs['{ref_obj_id}']") + elif config == item: + config = refs[ref_obj_id] + + if isinstance(config, str): + if config.startswith("$"): + config = eval(config[1:], globals, {"refs": refs}) + return config diff --git a/tests/test_config_component.py b/tests/test_config_component.py new file mode 100644 index 0000000000..1aa5e1572f --- /dev/null +++ b/tests/test_config_component.py @@ -0,0 +1,148 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
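# A small sketch of the "$" expression syntax handled by update_configs_with_objs above, using a
# made-up reference id "model": every "@ref" in a "$"-prefixed string is first swapped for the
# resolved object, then the remainder is evaluated as Python (mirroring TEST_CASE_13 below).
import torch

from monai.apps.mmars.utils import update_configs_with_objs

refs = {"model": torch.nn.PReLU()}
params = update_configs_with_objs("$@model.parameters()", refs=refs, id="optimizer", globals={"torch": torch})
# `params` is the generator returned by refs["model"].parameters(), ready to feed an optimizer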
+ +from typing import Callable, Iterator +import unittest + +import torch +import monai +from parameterized import parameterized + +from monai.apps import ConfigComponent, ModuleScanner +from monai.data import Dataset, DataLoader +from monai.transforms import LoadImaged, RandTorchVisiond + +TEST_CASE_1 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": {"keys": ["image"]}}, + LoadImaged, +] +# test python `` +TEST_CASE_2 = [ + dict(pkgs=[], modules=[]), + {"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, + LoadImaged, +] +# test `` +TEST_CASE_3 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": True, "": {"keys": ["image"]}}, + dict, +] +# test non-monai modules +TEST_CASE_4 = [ + dict(pkgs=["torch.optim", "monai"], modules=["adam"]), + {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +# test args contains "name" field +TEST_CASE_5 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + RandTorchVisiond, +] +# test refs of dict config +TEST_CASE_6 = [ + {"dataset": "@dataset", "batch_size": 2}, + ["test#dataset", "dataset", "test#batch_size"], +] +# test refs of list config +TEST_CASE_7 = [ + {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, + ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], +] +# test refs of execute code +TEST_CASE_8 = [ + {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, + ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], +] + +# test refs of lambda function +TEST_CASE_9 = [ + {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, + ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], +] +# test instance with no reference +TEST_CASE_10 = [ + "transform#1", + {"": "LoadImaged", "": {"keys": ["image"]}}, + {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, + LoadImaged, +] +# test dataloader refers to `@dataset`, here we don't test recursive reference, test that in `ConfigResolver` +TEST_CASE_11 = [ + "dataloader", + {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, + {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, + DataLoader, +] +# test reference in code execution +TEST_CASE_12 = [ + "optimizer", + {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, + {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +# test replace reference with code execution result +TEST_CASE_13 = [ + "optimizer##params", + "$@model.parameters()", + {"model": torch.nn.PReLU()}, + Iterator, +] +# test execute some function in args, test pre-imported global packages `monai` +TEST_CASE_14 = [ + "dataloader##collate_fn", + "$monai.data.list_data_collate", + {}, + Callable, +] +# test lambda function, should not execute the lambda function, just change the string with reference objects +TEST_CASE_15 = [ + "dataloader##collate_fn", + "$lambda x: monai.data.list_data_collate(x) + 100", + {}, + Callable, +] + + +class TestConfigComponent(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_build(self, input_param, test_input, output_type): + scanner = 
ModuleScanner(**input_param) + configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) + ret = configer.build() + self.assertTrue(isinstance(ret, output_type)) + if isinstance(ret, LoadImaged): + self.assertEqual(ret.keys[0], "image") + if isinstance(ret, dict): + # test `` works fine + self.assertDictEqual(ret, test_input) + + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + def test_reference_ids(self, test_input, ref_ids): + scanner = ModuleScanner(pkgs=[], modules=[]) + configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) + ret = configer.get_referenced_ids() + self.assertListEqual(ret, ref_ids) + + @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]) + def test_get_instance(self, id, test_input, refs, output_type): + scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + configer = ConfigComponent( + id=id, config=test_input, module_scanner=scanner, globals={"monai": monai, "torch": torch} + ) + ret = configer.get_instance(refs) + self.assertTrue(isinstance(ret, output_type)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py deleted file mode 100644 index 6b077f5a76..0000000000 --- a/tests/test_config_parser.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import torch -from parameterized import parameterized - -from monai.apps import ConfigParser -from monai.transforms import LoadImaged - -TEST_CASE_1 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"name": "LoadImaged", "args": {"keys": ["image"]}}, - LoadImaged, -] -# test python `path` -TEST_CASE_2 = [ - dict(pkgs=[], modules=[]), - {"path": "monai.transforms.LoadImaged", "args": {"keys": ["image"]}}, - LoadImaged, -] -# test `disabled` -TEST_CASE_3 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"name": "LoadImaged", "disabled": True, "args": {"keys": ["image"]}}, - None, -] -# test non-monai modules -TEST_CASE_4 = [ - dict(pkgs=["torch.optim", "monai"], modules=["adam"]), - {"name": "Adam", "args": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] - - -class TestConfigParser(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) - def test_type(self, input_param, test_input, output_type): - configer = ConfigParser(**input_param) - configer.set_config({"test": test_input}) - result = configer.get_instance("test") - if result is not None: - self.assertTrue(isinstance(result, output_type)) - if isinstance(result, LoadImaged): - self.assertEqual(result.keys[0], "image") - else: - # test `disabled` works fine - self.assertEqual(result, output_type) - - -if __name__ == "__main__": - unittest.main() From 7e07e7675a4634c42d35b2ef6d65b1e69622fb9d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 12:22:44 +0800 Subject: [PATCH 05/30] [DLMED] add unit tests for ConfigResolver Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 66 ++++++++++++++++------------- tests/test_config_component.py | 5 ++- tests/test_config_resolver.py | 56 ++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 tests/test_config_resolver.py diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index fa92a06dc3..c91513d953 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -70,9 +70,8 @@ def get_config(self): def get_referenced_ids(self) -> List[str]: return search_configs_with_objs(self.config, [], id=self.id) - def get_instance(self, refs: dict): - config = update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) - return self.build(config) if isinstance(config, dict) and ("" in config or "" in config) else config + def get_updated_config(self, refs: dict): + return update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) def build(self, config: Optional[Dict] = None) -> object: """ @@ -92,10 +91,9 @@ def build(self, config: Optional[Dict] = None) -> object: """ config = self.config if config is None else config - if not isinstance(config, dict): - raise ValueError("only dictionary config can be built as component instance.") - - if config.get("") is True: + if not isinstance(config, dict) \ + or ("" not in config and "" not in config) \ + or config.get("") is True: # if marked as `disabled`, skip parsing return config @@ -119,39 +117,49 @@ def _get_class_path(self, config): class ConfigResolver: def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): - self.resolved = {} + self.resolved_configs = {} + self.resolved_components = {} self.components = {} if components is None else components def update(self, component: ConfigComponent): self.components[component.get_id()] = component - def 
resolve_one_object(self, id: str) -> bool: - obj = self.components[id] + def resolve_one_component(self, id: str, instantiate: bool = True) -> bool: + com = self.components[id] # check whether the obj has any unresolved refs in its args - ref_ids = obj.get_referenced_ids() - if not ref_ids: - # this object does not reference others - resolved_obj = obj.get_instance([]) - else: + ref_ids = com.get_referenced_ids() + refs = {} + if len(ref_ids) > 0: # see whether all refs are resolved - refs = {} for comp_id in ref_ids: - if comp_id not in self.resolved: - # this referenced object is not resolved + if comp_id not in self.resolved_components: + # this referenced component is not resolved if comp_id not in self.components: raise RuntimeError(f"the reference component `{comp_id}` is not in config.") # resolve the dependency first - self.resolve_one_object(id=comp_id) - refs[comp_id] = self.resolved[comp_id] - # all referenced objects are resolved already - resolved_obj = obj.get_instance(refs) + self.resolve_one_component(id=comp_id, instantiate=True) + refs[comp_id] = self.resolved_components[comp_id] + # all referenced components are resolved already + updated_config = com.get_updated_config(refs) + resolved_com = None - self.resolved[id] = resolved_obj - return resolved_obj + if instantiate: + resolved_com = com.build(updated_config) + self.resolved_configs[id] = updated_config + self.resolved_components[id] = resolved_com - def resolve_all(self): - for v in self.components.values(): - self.resolve_one_object(obj=v) + return updated_config, resolved_com - def get_resolved(self, id: str): - return self.resolved[id] + def resolve_all(self): + for k in self.components.keys(): + self.resolve_one_component(id=k) + + def get_resolved_compnent(self, id: str): + if id not in self.resolved_components: + self.resolve_one_component(id=id, instantiate=True) + return self.resolved_components[id] + + def get_resolved_config(self, id: str): + if id not in self.resolved_configs: + self.resolve_one_component(id=id, instantiate=False) + return self.resolved_configs[id] diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 1aa5e1572f..18c4eaaa41 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -135,12 +135,13 @@ def test_reference_ids(self, test_input, ref_ids): self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]) - def test_get_instance(self, id, test_input, refs, output_type): + def test_update_reference(self, id, test_input, refs, output_type): scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) configer = ConfigComponent( id=id, config=test_input, module_scanner=scanner, globals={"monai": monai, "torch": torch} ) - ret = configer.get_instance(refs) + config = configer.get_updated_config(refs) + ret = configer.build(config) self.assertTrue(isinstance(ret, output_type)) diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py new file mode 100644 index 0000000000..923a87638e --- /dev/null +++ b/tests/test_config_resolver.py @@ -0,0 +1,56 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.command.config import config +from typing import Callable, Iterator +import unittest + +import torch +import monai +from parameterized import parameterized + +from monai.apps import ConfigComponent, ConfigResolver, ModuleScanner +from monai.data import Dataset, DataLoader +from monai.transforms import LoadImaged + +# test instance with no reference +TEST_CASE_1 = [ + { + "transform#1": {"": "LoadImaged", "": {"keys": ["image"]}}, + "transform#1#": "LoadImaged", + "transform#1#": {"keys": ["image"]}, + "transform#1##keys": ["image"], + "transform#1##keys#0": "image", + }, + "transform#1", + LoadImaged, +] + + +class TestConfigComponent(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_get_instance(self, configs, expected_id, output_type): + scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + resolver = ConfigResolver() + for k, v in configs.items(): + resolver.update(ConfigComponent( + id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch} + )) + config, ins = resolver.resolve_one_component(expected_id) + self.assertTrue(isinstance(ins, output_type)) + # test lazy instantiation + config[""] = False + ins = ConfigComponent(id=expected_id, module_scanner=scanner, config=config).build() + self.assertTrue(isinstance(ins, output_type)) + + +if __name__ == "__main__": + unittest.main() From c579c593a2261958c67c72d21b970cc0e777c179 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 15:58:57 +0800 Subject: [PATCH 06/30] [DLMED] add details unit tests for ConfigParser, ConfigResolver Signed-off-by: Nic Ma --- monai/apps/mmars/config_parser.py | 48 +++++++++++++--------- monai/apps/mmars/config_resolver.py | 42 +++++++++++++++---- tests/test_config_component.py | 36 +++++++++------- tests/test_config_parser.py | 64 +++++++++++++++++++++++++++++ tests/test_config_resolver.py | 53 ++++++++++++++++++++---- 5 files changed, 193 insertions(+), 50 deletions(-) create mode 100644 tests/test_config_parser.py diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index b935d44811..8ced6003bb 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -46,50 +46,60 @@ def __init__( self.config_resolver: Optional[ConfigResolver] = None self.resolved = False - def _get_last_config_and_key(config: Union[Dict, List], path: str) -> Tuple[Union[Dict, List], str]: - keys = path.split("#") + def _get_last_config_and_key(self, config: Union[Dict, List], id: str) -> Tuple[Union[Dict, List], str]: + keys = id.split("#") for k in keys[:-1]: config = config[k] if isinstance(config, dict) else config[int(k)] key = keys[-1] if isinstance(config, dict) else int(keys[-1]) return config, key - def set_config(self, config: Any, path: Optional[str] = None): - if isinstance(path, str): - conf_, key = self._get_last_config_and_key(self.config, path) + def set_config(self, config: Any, id: Optional[str] = None): + if isinstance(id, str): + conf_, key = self._get_last_config_and_key(config=self.config, id=id) conf_[key] = config else: self.config = config 
self.resolved = False - def get_config(self, path: Optional[str] = None): - if isinstance(path, str): - conf_, key = self._get_last_config_and_key(self.config, path) + def get_config(self, id: Optional[str] = None): + if isinstance(id, str): + conf_, key = self._get_last_config_and_key(config=self.config, id=id) return conf_[key] return self.config - def _do_scan(self, config, path: Optional[str] = None): + def _do_parse(self, config, id: Optional[str] = None): if isinstance(config, dict): for k, v in config.items(): - sub_path = k if path is None else f"{path}#{k}" - self._do_scan(config=v, path=sub_path) + sub_id = k if id is None else f"{id}#{k}" + self._do_parse(config=v, id=sub_id) if isinstance(config, list): for i, v in enumerate(config): - sub_path = i if path is None else f"{path}#{i}" - self._do_scan(config=v, path=sub_path) - if path is not None: - self.config_resolver.update( - ConfigComponent(id=path, config=config, module_scanner=self.module_scanner, globals=self.global_imports) + sub_id = i if id is None else f"{id}#{i}" + self._do_parse(config=v, id=sub_id) + if id is not None: + self.config_resolver.add( + ConfigComponent(id=id, config=config, module_scanner=self.module_scanner, globals=self.global_imports) ) def resolve_config(self, resolve_all: bool = False): self.config_resolver = ConfigResolver() - self._do_scan(config=self.config) + self._do_parse(config=self.config) if resolve_all: self.config_resolver.resolve_all() self.resolved = True - def get_instance(self, id: str): + def get_resolved_config(self, id: str): if self.config_resolver is None or not self.resolved: self.resolve_config() - return self.config_resolver.resolve_one_object(id=id) + return self.config_resolver.get_resolved_config(id=id) + + def get_resolved_component(self, id: str): + if self.config_resolver is None or not self.resolved: + self.resolve_config() + return self.config_resolver.get_resolved_compnent(id=id) + + def build(self, config: Dict): + return ConfigComponent( + id=None, config=config, module_scanner=self.module_scanner, globals=self.global_imports + ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index c91513d953..3e570d65e3 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -13,6 +13,9 @@ import importlib import inspect import pkgutil +from xmlrpc.client import FastMarshaller + +from torch import warnings from monai.apps.mmars.utils import instantiate_class, search_configs_with_objs, update_configs_with_objs @@ -73,6 +76,20 @@ def get_referenced_ids(self) -> List[str]: def get_updated_config(self, refs: dict): return update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) + def _check_dependency(self, config): + if isinstance(config, list): + for i in config: + if self._check_dependency(i): + return True + if isinstance(config, dict): + for v in config.values(): + if self._check_dependency(v): + return True + if isinstance(config, str): + if config.startswith("&") or "@" in config: + return True + return False + def build(self, config: Optional[Dict] = None) -> object: """ Build component instance based on the provided dictonary config. 
@@ -91,6 +108,10 @@ def build(self, config: Optional[Dict] = None) -> object: """ config = self.config if config is None else config + if self._check_dependency(config=config): + warnings.warn("config content has other dependencies or executable string, skip `build`.") + return config + if not isinstance(config, dict) \ or ("" not in config and "" not in config) \ or config.get("") is True: @@ -121,10 +142,13 @@ def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): self.resolved_components = {} self.components = {} if components is None else components - def update(self, component: ConfigComponent): - self.components[component.get_id()] = component + def add(self, component: ConfigComponent): + id = component.get_id() + if id in self.components: + raise ValueError(f"id '{id}' is already added.") + self.components[id] = component - def resolve_one_component(self, id: str, instantiate: bool = True) -> bool: + def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: com = self.components[id] # check whether the obj has any unresolved refs in its args ref_ids = com.get_referenced_ids() @@ -137,7 +161,7 @@ def resolve_one_component(self, id: str, instantiate: bool = True) -> bool: if comp_id not in self.components: raise RuntimeError(f"the reference component `{comp_id}` is not in config.") # resolve the dependency first - self.resolve_one_component(id=comp_id, instantiate=True) + self._resolve_one_component(id=comp_id, instantiate=True) refs[comp_id] = self.resolved_components[comp_id] # all referenced components are resolved already updated_config = com.get_updated_config(refs) @@ -152,14 +176,16 @@ def resolve_one_component(self, id: str, instantiate: bool = True) -> bool: def resolve_all(self): for k in self.components.keys(): - self.resolve_one_component(id=k) + self._resolve_one_component(id=k, instantiate=True) def get_resolved_compnent(self, id: str): if id not in self.resolved_components: - self.resolve_one_component(id=id, instantiate=True) + self._resolve_one_component(id=id, instantiate=True) return self.resolved_components[id] def get_resolved_config(self, id: str): if id not in self.resolved_configs: - self.resolve_one_component(id=id, instantiate=False) - return self.resolved_configs[id] + config, _ = self._resolve_one_component(id=id, instantiate=False) + else: + config = self.resolved_configs[id] + return config diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 18c4eaaa41..2ec5dfd46c 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -37,76 +37,82 @@ {"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict, ] -# test non-monai modules +# test unresolved dependency TEST_CASE_4 = [ + dict(pkgs=["monai"], modules=["transforms"]), + {"": "LoadImaged", "": {"keys": ["@key_name"]}}, + dict, +] +# test non-monai modules +TEST_CASE_5 = [ dict(pkgs=["torch.optim", "monai"], modules=["adam"]), {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] # test args contains "name" field -TEST_CASE_5 = [ +TEST_CASE_6 = [ dict(pkgs=["monai"], modules=["transforms"]), {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] # test refs of dict config -TEST_CASE_6 = [ +TEST_CASE_7 = [ {"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"], ] # test refs of list config -TEST_CASE_7 = [ +TEST_CASE_8 = [ {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, 
["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], ] # test refs of execute code -TEST_CASE_8 = [ +TEST_CASE_9 = [ {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], ] # test refs of lambda function -TEST_CASE_9 = [ +TEST_CASE_10 = [ {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], ] # test instance with no reference -TEST_CASE_10 = [ +TEST_CASE_11 = [ "transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, LoadImaged, ] # test dataloader refers to `@dataset`, here we don't test recursive reference, test that in `ConfigResolver` -TEST_CASE_11 = [ +TEST_CASE_12 = [ "dataloader", {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, DataLoader, ] # test reference in code execution -TEST_CASE_12 = [ +TEST_CASE_13 = [ "optimizer", {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] # test replace reference with code execution result -TEST_CASE_13 = [ +TEST_CASE_14 = [ "optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator, ] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_14 = [ +TEST_CASE_15 = [ "dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable, ] # test lambda function, should not execute the lambda function, just change the string with reference objects -TEST_CASE_15 = [ +TEST_CASE_16 = [ "dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, @@ -115,7 +121,7 @@ class TestConfigComponent(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_build(self, input_param, test_input, output_type): scanner = ModuleScanner(**input_param) configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) @@ -127,14 +133,14 @@ def test_build(self, input_param, test_input, output_type): # test `` works fine self.assertDictEqual(ret, test_input) - @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) def test_reference_ids(self, test_input, ref_ids): scanner = ModuleScanner(pkgs=[], modules=[]) configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) ret = configer.get_referenced_ids() self.assertListEqual(ret, ref_ids) - @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]) + @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) def test_update_reference(self, id, test_input, refs, output_type): scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) configer = ConfigComponent( diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 0000000000..2a987241b0 --- /dev/null +++ 
b/tests/test_config_parser.py @@ -0,0 +1,64 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from parameterized import parameterized + +from monai.apps import ConfigParser +from monai.data import Dataset, DataLoader +from monai.transforms import Compose, LoadImaged, RandTorchVisiond + +# test the resolved and parsed instances +TEST_CASE_1 = [ + { + "transform": { + "": "Compose", + "": {"transforms": [ + {"": "LoadImaged", "": {"keys": "image"}}, + {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + ]} + }, + "dataset": {"": "Dataset", "": {"data": [1, 2], "transform": "@transform"}}, + "dataloader": { + "": "DataLoader", + "": {"dataset": "@dataset", "batch_size": 2, "collate_fn": "monai.data.list_data_collate"}, + }, + }, + ["transform", "transform##transforms#0", "transform##transforms#1", "dataset", "dataloader"], + [Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader], +] + + +class TestConfigComponent(unittest.TestCase): + def test_config_content(self): + parser = ConfigParser(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + test_config = {"preprocessing": [{"name": "LoadImage"}], "dataset": {"name": "Dataset"}} + parser.set_config(config=test_config) + self.assertEqual(str(parser.get_config()), str(test_config)) + parser.set_config(config={"name": "CacheDataset"}, id="preprocessing#0#datasets") + self.assertDictEqual(parser.get_config(id="preprocessing#0#datasets"), {"name": "CacheDataset"}) + + @parameterized.expand([TEST_CASE_1]) + def test_parse(self, config, expected_ids, output_types): + parser = ConfigParser( + pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"], global_imports=["monai"], config=config + ) + for id, cls in zip(expected_ids, output_types): + config = parser.get_resolved_config(id) + # test lazy instantiation + self.assertTrue(isinstance(config, dict)) + self.assertTrue(isinstance(parser.build(config), cls)) + # test get instance directly + self.assertTrue(isinstance(parser.get_resolved_component(id), cls)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py index 923a87638e..3d35efdda0 100644 --- a/tests/test_config_resolver.py +++ b/tests/test_config_resolver.py @@ -9,8 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from distutils.command.config import config -from typing import Callable, Iterator import unittest import torch @@ -18,12 +16,13 @@ from parameterized import parameterized from monai.apps import ConfigComponent, ConfigResolver, ModuleScanner -from monai.data import Dataset, DataLoader -from monai.transforms import LoadImaged +from monai.data import DataLoader +from monai.transforms import LoadImaged, RandTorchVisiond # test instance with no reference TEST_CASE_1 = [ { + # all the recursively parsed config items "transform#1": {"": "LoadImaged", "": {"keys": ["image"]}}, "transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}, @@ -33,19 +32,57 @@ "transform#1", LoadImaged, ] +# test depends on other component and executable code +TEST_CASE_2 = [ + { + # all the recursively parsed config items + "dataloader": { + "": "DataLoader", "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"} + }, + "dataset": {"": "Dataset", "": {"data": [1, 2]}}, + "dataloader#": "DataLoader", + "dataloader#": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, + "dataloader##dataset": "@dataset", + "dataloader##collate_fn": "$monai.data.list_data_collate", + "dataset#": "Dataset", + "dataset#": {"data": [1, 2]}, + "dataset##data": [1, 2], + "dataset##data#0": 1, + "dataset##data#1": 2, + }, + "dataloader", + DataLoader, +] +# test config has key `name` +TEST_CASE_3 = [ + { + # all the recursively parsed config items + "transform#1": { + "": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25} + }, + "transform#1#": "RandTorchVisiond", + "transform#1#": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, + "transform#1##keys": "image", + "transform#1##name": "ColorJitter", + "transform#1##brightness": 0.25, + }, + "transform#1", + RandTorchVisiond, +] class TestConfigComponent(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_get_instance(self, configs, expected_id, output_type): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_resolve(self, configs, expected_id, output_type): scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) resolver = ConfigResolver() for k, v in configs.items(): - resolver.update(ConfigComponent( + resolver.add(ConfigComponent( id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch} )) - config, ins = resolver.resolve_one_component(expected_id) + ins = resolver.get_resolved_compnent(expected_id) self.assertTrue(isinstance(ins, output_type)) + config = resolver.get_resolved_config(expected_id) # test lazy instantiation config[""] = False ins = ConfigComponent(id=expected_id, module_scanner=scanner, config=config).build() From 9283623861a1717149a6af01f2fc434d062f1c27 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 16:03:23 +0800 Subject: [PATCH 07/30] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 3e570d65e3..3fd050dc3a 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -13,9 +13,7 @@ import importlib import inspect import pkgutil -from xmlrpc.client import FastMarshaller - -from torch import warnings +import warnings from monai.apps.mmars.utils import instantiate_class, search_configs_with_objs, update_configs_with_objs From 
d4ebe57069f89b528e81a808fd649f0044a95f5a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 17:43:22 +0800 Subject: [PATCH 08/30] [DLMED] add details doc-strings Signed-off-by: Nic Ma --- monai/apps/mmars/config_parser.py | 94 ++++++++++++++++++-- monai/apps/mmars/config_resolver.py | 133 ++++++++++++++++++++++++++-- 2 files changed, 212 insertions(+), 15 deletions(-) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 8ced6003bb..57b9c366a5 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -16,14 +16,18 @@ class ConfigParser: """ - Parse dictionary format config and build components. + Parse a nested config and build components. + A typical usage is a config dictionary contains all the necessary components to define training workflow in JSON. + For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. Args: pkgs: the expected packages to scan modules and parse class names in the config. modules: the expected modules in the packages to scan for all the classes. for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. - global_imports: pre-import packages as global variables to execute `eval` commands. + global_imports: pre-import packages as global variables to execute the python `eval` commands. for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. + default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` + are MONAI mininum requirements. config: config content to parse. """ @@ -32,7 +36,7 @@ def __init__( self, pkgs: Sequence[str], modules: Sequence[str], - global_imports: Optional[Sequence[str]] = None, + global_imports: Optional[Dict[str, str]] = {"monai": "monai", "torch": "torch", "np": "numpy"}, config: Optional[Any] = None, ): self.config = None @@ -41,12 +45,21 @@ def __init__( self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) self.global_imports = {} if global_imports is not None: - for i in global_imports: - self.global_imports[i] = importlib.import_module(i) + for k, v in global_imports.items(): + self.global_imports[k] = importlib.import_module(v) self.config_resolver: Optional[ConfigResolver] = None self.resolved = False def _get_last_config_and_key(self, config: Union[Dict, List], id: str) -> Tuple[Union[Dict, List], str]: + """ + Utility to get the last config item and the id from the whole config content with nested id name. + + Args: + config: the whole config content. + id: nested id name to get the last item, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. + + """ keys = id.split("#") for k in keys[:-1]: config = config[k] if isinstance(config, dict) else config[int(k)] @@ -54,6 +67,15 @@ def _get_last_config_and_key(self, config: Union[Dict, List], id: str) -> Tuple[ return config, key def set_config(self, config: Any, id: Optional[str] = None): + """ + Set config content for the parser, if `id` provided, `config` will used to replace the config item with `id`. + + Args: + config: target config content to set. + id: nested id name to specify the target position, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. 
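
As a rough illustration of the "#"-joined id convention described here (a simplified sketch, not the method in this patch; plain "name"/"args" keys are used only for readability), looking an item up by id just walks the nested config, indexing dicts by key and lists by integer position:

    def get_item(config, id):
        # e.g. get_item(cfg, "transforms#0") or get_item(cfg, "transforms#0#args#keys")
        for k in id.split("#"):
            config = config[k] if isinstance(config, dict) else config[int(k)]
        return config

    cfg = {"transforms": [{"name": "LoadImaged", "args": {"keys": ["image"]}}]}
    assert get_item(cfg, "transforms#0#args#keys#0") == "image"
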
+ + """ if isinstance(id, str): conf_, key = self._get_last_config_and_key(config=self.config, id=id) conf_[key] = config @@ -62,12 +84,35 @@ def set_config(self, config: Any, id: Optional[str] = None): self.resolved = False def get_config(self, id: Optional[str] = None): + """ + Get config content from the parser, if `id` provided, get the config item with `id`. + + Args: + id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. + + """ if isinstance(id, str): conf_, key = self._get_last_config_and_key(config=self.config, id=id) return conf_[key] return self.config def _do_parse(self, config, id: Optional[str] = None): + """ + Recursively parse the nested config content, add every config item as component to the resolver. + For example, `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items: + - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]` + - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}` + - `id="preprocessing#0#", config="LoadImage"` + - `id="preprocessing#0#", config={"keys": "image"}` + - `id="preprocessing#0##keys", config="image"` + + Args: + config: config content to parse. + id: id name of current config item, nested ids are joined by "#" mark. defaults to None. + for example: "transforms#5", "transforms#5##keys", etc. + + """ if isinstance(config, dict): for k, v in config.items(): sub_id = k if id is None else f"{id}#{k}" @@ -81,7 +126,14 @@ def _do_parse(self, config, id: Optional[str] = None): ConfigComponent(id=id, config=config, module_scanner=self.module_scanner, globals=self.global_imports) ) - def resolve_config(self, resolve_all: bool = False): + def parse_config(self, resolve_all: bool = False): + """ + Parse the config content, add every config item as component to the resolver. + + Args: + resolve_all: if True, resolve all the components and build instances directly. + + """ self.config_resolver = ConfigResolver() self._do_parse(config=self.config) @@ -90,16 +142,40 @@ def resolve_config(self, resolve_all: bool = False): self.resolved = True def get_resolved_config(self, id: str): + """ + Get the resolved instance component, if not resolved, try to resolve it first. + + Args: + id: id name of expected config component, nested ids are joined by "#" mark. + for example: "transforms#5", "transforms#5##keys", etc. + + """ if self.config_resolver is None or not self.resolved: - self.resolve_config() + self.parse_config() return self.config_resolver.get_resolved_config(id=id) def get_resolved_component(self, id: str): + """ + Get the resolved config component, if not resolved, try to resolve it first. + It can be used to modify the config again and support lazy instantiation. + + Args: + id: id name of expected config component, nested ids are joined by "#" mark. + for example: "transforms#5", "transforms#5##keys", etc. + + """ if self.config_resolver is None or not self.resolved: - self.resolve_config() - return self.config_resolver.get_resolved_compnent(id=id) + self.parse_config() + return self.config_resolver.get_resolved_component(id=id) def build(self, config: Dict): + """ + Build a config to instance if no dependencies, usually used for lazy instantiation or ad-hoc build. + + Args: + config: dictionary config content to build. 
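
The recursive parse sketched in the `_do_parse` docstring above can be approximated by a few lines of standalone code (illustration only; the real method also wraps each item in a `ConfigComponent`, and the helper name `flatten_ids` is made up):

    def flatten_ids(config, id=None, out=None):
        """Collect every nested config item under a '#'-joined id (illustration only)."""
        out = {} if out is None else out
        if id is not None:
            out[id] = config
        if isinstance(config, dict):
            for k, v in config.items():
                flatten_ids(v, k if id is None else f"{id}#{k}", out)
        elif isinstance(config, list):
            for i, v in enumerate(config):
                flatten_ids(v, str(i) if id is None else f"{id}#{i}", out)
        return out

    # flatten_ids({"preprocessing": [{"name": "LoadImage"}]}) produces the ids
    # "preprocessing", "preprocessing#0" and "preprocessing#0#name".
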
+ + """ return ConfigComponent( id=None, config=config, module_scanner=self.module_scanner, globals=self.global_imports ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 3fd050dc3a..8fb02491bb 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -24,8 +24,9 @@ class ModuleScanner: Map the all the class names and the module names in a table. Args: - pkgs: the expected packages to scan. - modules: the expected modules in the packages to scan. + pkgs: the expected packages to scan modules and parse class names in the config. + modules: the expected modules in the packages to scan for all the classes. + for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. """ @@ -52,10 +53,47 @@ def _create_classes_table(self): return class_table def get_class_module_name(self, class_name): + """ + Get the module name of the class with specified class name. + + Args: + class_name: name of the expected class. + + """ return self._class_table.get(class_name, None) class ConfigComponent: + """ + Utility class to manage every component in the config with a unique `id` name. + When recursively parsing a complicated config dictioanry, every item should be treated as a `ConfigComponent`. + For example: + - `{"preprocessing": [{"name": "LoadImage", "args": {"keys": "image"}}]}` + - `{"name": "LoadImage", "args": {"keys": "image"}}` + - `"name": "LoadImage"` + - `"keys": "image"` + + It can search the config content and find out all the dependencies, then build the config to instance + when all the dependencies are resolved. + + Here we predefined several special marks to parse the config content: + - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. + now we have 4 keys: ``, ``, ``, ``. + - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. + - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. + - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". + + Args: + id: id name of current config component, for nested config items, use `#` to join ids. + for list component, use index from `0` as id. + for example: `transform`, `transform#5`, `transform#5##keys`, etc. + config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. + module_scanner: ModuleScanner to help get the class name or path in the config and build instance. + globals: to support executable string in the config, sometimes we need to provide the global variables + which are referred in the executable string. for example: `globals={"monai": monai} will be useful + for config `"collate_fn": "$monai.data.list_data_collate"`. + + """ def __init__(self, id: str, config: Any, module_scanner: ModuleScanner, globals: Optional[Dict] = None) -> None: self.id = id self.config = config @@ -63,18 +101,49 @@ def __init__(self, id: str, config: Any, module_scanner: ModuleScanner, globals: self.globals = globals def get_id(self) -> str: + """ + Get the id name of current config component. + + """ return self.id def get_config(self): + """ + Get the raw config content of current config component. + + """ return self.config def get_referenced_ids(self) -> List[str]: + """ + Recursively search all the content of current config compoent to get the ids of dependencies. 
+ Must build all the dependencies before build current config component. + For `dict` and `list`, treat every item as a dependency. + For example, for `{"name": "DataLoader", "args": {"dataset": "@dataset"}}`, the dependency ids: + `["name", "args", "args#dataset", "dataset"]`. + + """ return search_configs_with_objs(self.config, [], id=self.id) def get_updated_config(self, refs: dict): + """ + If all the dependencies are ready in `refs`, update the config content with them and return new config. + It can be used for lazy instantiation. + + Args: + refs: all the dependent components with ids. + + """ return update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) def _check_dependency(self, config): + """ + Check whether current config still has unresolved dependencies or executable string code. + + Args: + config: config content to check. + + """ if isinstance(config, list): for i in config: if self._check_dependency(i): @@ -95,13 +164,13 @@ def build(self, config: Optional[Dict] = None) -> object: - '' - class name in the modules of packages. - '' - directly specify the class path, based on PYTHONPATH, ignore '' if specified. - '' - arguments to initialize the component instance. - - '' - if defined `'': true`, will skip the buiding, useful for development or tuning. + - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. Args: - config: dictionary config to define a component. + config: dictionary config that defines a component. Raises: - ValueError: must provide `path` or `name` of class to build component. + ValueError: must provide `` or `` of class to build component. ValueError: can not find component class. """ @@ -121,6 +190,13 @@ def build(self, config: Optional[Dict] = None) -> object: return instantiate_class(class_path, **class_args) def _get_class_path(self, config): + """ + Get the path of class specified in the config content. + + Args: + config: dictionary config that defines a component. + + """ class_path = config.get("", None) if class_path is None: class_name = config.get("", None) @@ -135,18 +211,43 @@ def _get_class_path(self, config): class ConfigResolver: + """ + Utility class to resolve the dependencies between config components and build instance for specified `id`. + + Args: + components: config components to resolve, if None, can also `add()` component in runtime. + + """ def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): self.resolved_configs = {} self.resolved_components = {} self.components = {} if components is None else components def add(self, component: ConfigComponent): + """ + Add a component to the resolution graph. + + Args: + component: a config component to resolve. + + """ id = component.get_id() if id in self.components: raise ValueError(f"id '{id}' is already added.") self.components[id] = component def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: + """ + Resolve one component with specified id name. + If has unresolved dependencies, recursively resolve the dependencies first. + + Args: + id: id name of expected component to resolve. + instantiate: after resolving all the dependencies, whether to build instance. + if False, can support lazy instantiation with the resolved config later. + default to `True`. 
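
For a concrete, simplified picture of this resolution order (an illustrative config, not one from the patch, written with plain "name"/"args" style keys for readability):

    config = {
        "dataset": {"name": "Dataset", "args": {"data": [1, 2]}},
        "dataloader": {"name": "DataLoader", "args": {"dataset": "@dataset", "batch_size": 2}},
    }
    # Resolving "dataloader" first finds its dependency id "dataset", resolves and builds
    # that component, substitutes the instance for the "@dataset" string, and only then
    # instantiates the DataLoader with the updated arguments.
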
+ + """ com = self.components[id] # check whether the obj has any unresolved refs in its args ref_ids = com.get_referenced_ids() @@ -173,15 +274,35 @@ def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: return updated_config, resolved_com def resolve_all(self): + """ + Resolve all the components and build instances. + + """ for k in self.components.keys(): self._resolve_one_component(id=k, instantiate=True) - def get_resolved_compnent(self, id: str): + def get_resolved_component(self, id: str): + """ + Get the resolved instance component with specified id name. + If not resolved, try to resolve it first. + + Args: + id: id name of the expected component. + + """ if id not in self.resolved_components: self._resolve_one_component(id=id, instantiate=True) return self.resolved_components[id] def get_resolved_config(self, id: str): + """ + Get the resolved config component with specified id name, then can be used for lazy instantiation. + If not resolved, try to resolve it with `instantiation=False` first. + + Args: + id: id name of the expected config component. + + """ if id not in self.resolved_configs: config, _ = self._resolve_one_component(id=id, instantiate=False) else: From 577630133bdb60021d949cce8ab65d1c4e5049cd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 17:46:53 +0800 Subject: [PATCH 09/30] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 8fb02491bb..ed53b10fda 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -68,9 +68,9 @@ class ConfigComponent: Utility class to manage every component in the config with a unique `id` name. When recursively parsing a complicated config dictioanry, every item should be treated as a `ConfigComponent`. For example: - - `{"preprocessing": [{"name": "LoadImage", "args": {"keys": "image"}}]}` - - `{"name": "LoadImage", "args": {"keys": "image"}}` - - `"name": "LoadImage"` + - `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` + - `{"": "LoadImage", "": {"keys": "image"}}` + - `"": "LoadImage"` - `"keys": "image"` It can search the config content and find out all the dependencies, then build the config to instance @@ -119,8 +119,8 @@ def get_referenced_ids(self) -> List[str]: Recursively search all the content of current config compoent to get the ids of dependencies. Must build all the dependencies before build current config component. For `dict` and `list`, treat every item as a dependency. - For example, for `{"name": "DataLoader", "args": {"dataset": "@dataset"}}`, the dependency ids: - `["name", "args", "args#dataset", "dataset"]`. + For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: + `["", "", "#dataset", "dataset"]`. 
""" return search_configs_with_objs(self.config, [], id=self.id) From 8cc4d3e7ce35efa1e3e3aabfc0cefa778d0fc918 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jan 2022 09:50:43 +0000 Subject: [PATCH 10/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/mmars/config_resolver.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index ed53b10fda..e553fb7876 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -75,14 +75,14 @@ class ConfigComponent: It can search the config content and find out all the dependencies, then build the config to instance when all the dependencies are resolved. - + Here we predefined several special marks to parse the config content: - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. now we have 4 keys: ``, ``, ``, ``. - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". - + Args: id: id name of current config component, for nested config items, use `#` to join ids. for list component, use index from `0` as id. @@ -192,7 +192,7 @@ def build(self, config: Optional[Dict] = None) -> object: def _get_class_path(self, config): """ Get the path of class specified in the config content. - + Args: config: dictionary config that defines a component. @@ -213,7 +213,7 @@ def _get_class_path(self, config): class ConfigResolver: """ Utility class to resolve the dependencies between config components and build instance for specified `id`. - + Args: components: config components to resolve, if None, can also `add()` component in runtime. From 0f66ad8ccf3c373e195d08e3b79152d56594bc3f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 20:37:05 +0800 Subject: [PATCH 11/30] [DLMED] update to dependencies Signed-off-by: Nic Ma --- monai/apps/mmars/config_resolver.py | 32 +++++++++--------- monai/apps/mmars/utils.py | 52 +++++++++++++++++++++-------- tests/test_config_component.py | 26 +++++++-------- tests/test_config_parser.py | 4 ++- tests/test_config_resolver.py | 4 +-- 5 files changed, 72 insertions(+), 46 deletions(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index e553fb7876..81324a97fb 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -114,27 +114,27 @@ def get_config(self): """ return self.config - def get_referenced_ids(self) -> List[str]: + def get_dependent_ids(self) -> List[str]: """ Recursively search all the content of current config compoent to get the ids of dependencies. - Must build all the dependencies before build current config component. + It's used to build all the dependencies before build current config component. For `dict` and `list`, treat every item as a dependency. For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: `["", "", "#dataset", "dataset"]`. 
""" - return search_configs_with_objs(self.config, [], id=self.id) + return search_configs_with_objs(config=self.config, id=self.id, deps=[]) - def get_updated_config(self, refs: dict): + def get_updated_config(self, deps: dict): """ - If all the dependencies are ready in `refs`, update the config content with them and return new config. + If all the dependencies are ready in `deps`, update the config content with them and return new config. It can be used for lazy instantiation. Args: - refs: all the dependent components with ids. + deps: all the dependent components with ids. """ - return update_configs_with_objs(config=self.config, refs=refs, id=self.id, globals=self.globals) + return update_configs_with_objs(config=self.config, deps=deps, id=self.id, globals=self.globals) def _check_dependency(self, config): """ @@ -249,21 +249,21 @@ def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: """ com = self.components[id] - # check whether the obj has any unresolved refs in its args - ref_ids = com.get_referenced_ids() - refs = {} + # check whether the obj has any unresolved deps in its args + ref_ids = com.get_dependent_ids() + deps = {} if len(ref_ids) > 0: - # see whether all refs are resolved + # see whether all deps are resolved for comp_id in ref_ids: if comp_id not in self.resolved_components: - # this referenced component is not resolved + # this dependent component is not resolved if comp_id not in self.components: - raise RuntimeError(f"the reference component `{comp_id}` is not in config.") + raise RuntimeError(f"the dependent component `{comp_id}` is not in config.") # resolve the dependency first self._resolve_one_component(id=comp_id, instantiate=True) - refs[comp_id] = self.resolved_components[comp_id] - # all referenced components are resolved already - updated_config = com.get_updated_config(refs) + deps[comp_id] = self.resolved_components[comp_id] + # all dependent components are resolved already + updated_config = com.get_updated_config(deps) resolved_com = None if instantiate: diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 8336b4460b..83a4c9faad 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -64,49 +64,73 @@ def instantiate_class(class_path: str, **kwargs): raise ValueError(f"class {class_path} has parameters error.") from e -def search_configs_with_objs(config: Union[Dict, List, str], refs: List[str], id: str): +def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: List[str] = []): + """ + Recursively search all the content of input config compoent to get the ids of dependencies. + It's used to build all the dependencies before build current config component. + For `dict` and `list`, treat every item as a dependency. + For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: + `["", "", "#dataset", "dataset"]`. + + Args: + config: input config content to search. + id: id name for the input config. + deps: list of the id name of existing dependencies, default to empty. 
+ + """ pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" # all the items in the list should be marked as dependent reference - refs.append(sub_id) - refs = search_configs_with_objs(v, refs, sub_id) + deps.append(sub_id) + deps = search_configs_with_objs(v, sub_id, deps) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" # all the items in the dict should be marked as dependent reference - refs.append(sub_id) - refs = search_configs_with_objs(v, refs, sub_id) + deps.append(sub_id) + deps = search_configs_with_objs(v, sub_id, deps) if isinstance(config, str): result = pattern.findall(config) for item in result: if config.startswith("$") or config == item: ref_obj_id = item[1:] - if ref_obj_id not in refs: - refs.append(ref_obj_id) - return refs + if ref_obj_id not in deps: + deps.append(ref_obj_id) + return deps -def update_configs_with_objs(config: Union[Dict, List, str], refs: dict, id: str, globals: Optional[Dict] = None): +def update_configs_with_objs(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): + """ + With all the dependencies in `deps`, update the config content with them and return new config. + It can be used for lazy instantiation. + + Args: + config: input config content to update. + deps: all the dependent components with ids. + id: id name for the input config. + globals: predefined global variables to execute code string with `eval()`. + + """ pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): # all the items in the list should be replaced with the reference - config = [refs[f"{id}#{i}"] for i in range(len(config))] + config = [deps[f"{id}#{i}"] for i in range(len(config))] if isinstance(config, dict): # all the items in the dict should be replaced with the reference - config = {k: refs[f"{id}#{k}"] for k, _ in config.items()} + config = {k: deps[f"{id}#{k}"] for k, _ in config.items()} if isinstance(config, str): result = pattern.findall(config) for item in result: ref_obj_id = item[1:] if config.startswith("$"): # replace with local code and execute soon - config = config.replace(item, f"refs['{ref_obj_id}']") + config = config.replace(item, f"deps['{ref_obj_id}']") elif config == item: - config = refs[ref_obj_id] + config = deps[ref_obj_id] if isinstance(config, str): if config.startswith("$"): - config = eval(config[1:], globals, {"refs": refs}) + config = eval(config[1:], globals, {"deps": deps}) return config diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 2ec5dfd46c..7a1b88abd0 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -55,49 +55,49 @@ {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] -# test refs of dict config +# test dependencies of dict config TEST_CASE_7 = [ {"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"], ] -# test refs of list config +# test dependencies of list config TEST_CASE_8 = [ {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], ] -# test refs of execute code +# test dependencies of execute code TEST_CASE_9 = [ {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, ["test#dataset", "dataset", "test#transforms", 
"test#transforms#0", "trans"], ] -# test refs of lambda function +# test dependencies of lambda function TEST_CASE_10 = [ {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], ] -# test instance with no reference +# test instance with no dependencies TEST_CASE_11 = [ "transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, LoadImaged, ] -# test dataloader refers to `@dataset`, here we don't test recursive reference, test that in `ConfigResolver` +# test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` TEST_CASE_12 = [ "dataloader", {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, DataLoader, ] -# test reference in code execution +# test dependencies in code execution TEST_CASE_13 = [ "optimizer", {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] -# test replace reference with code execution result +# test replace dependencies with code execution result TEST_CASE_14 = [ "optimizer##params", "$@model.parameters()", @@ -111,7 +111,7 @@ {}, Callable, ] -# test lambda function, should not execute the lambda function, just change the string with reference objects +# test lambda function, should not execute the lambda function, just change the string with dependent objects TEST_CASE_16 = [ "dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", @@ -134,19 +134,19 @@ def test_build(self, input_param, test_input, output_type): self.assertDictEqual(ret, test_input) @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) - def test_reference_ids(self, test_input, ref_ids): + def test_dependent_ids(self, test_input, ref_ids): scanner = ModuleScanner(pkgs=[], modules=[]) configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) - ret = configer.get_referenced_ids() + ret = configer.get_dependent_ids() self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) - def test_update_reference(self, id, test_input, refs, output_type): + def test_update_dependencies(self, id, test_input, deps, output_type): scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) configer = ConfigComponent( id=id, config=test_input, module_scanner=scanner, globals={"monai": monai, "torch": torch} ) - config = configer.get_updated_config(refs) + config = configer.get_updated_config(deps) ret = configer.build(config) self.assertTrue(isinstance(ret, output_type)) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 2a987241b0..ba3a6b63d2 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -49,7 +49,9 @@ def test_config_content(self): @parameterized.expand([TEST_CASE_1]) def test_parse(self, config, expected_ids, output_types): parser = ConfigParser( - pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"], global_imports=["monai"], config=config + pkgs=["torch.optim", "monai"], + modules=["data", "transforms", "adam"],global_imports={"monai": "monai"}, + config=config, ) for id, cls in 
zip(expected_ids, output_types): config = parser.get_resolved_config(id) diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py index 3d35efdda0..3b6d3d4cb3 100644 --- a/tests/test_config_resolver.py +++ b/tests/test_config_resolver.py @@ -19,7 +19,7 @@ from monai.data import DataLoader from monai.transforms import LoadImaged, RandTorchVisiond -# test instance with no reference +# test instance with no dependencies TEST_CASE_1 = [ { # all the recursively parsed config items @@ -80,7 +80,7 @@ def test_resolve(self, configs, expected_id, output_type): resolver.add(ConfigComponent( id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch} )) - ins = resolver.get_resolved_compnent(expected_id) + ins = resolver.get_resolved_component(expected_id) self.assertTrue(isinstance(ins, output_type)) config = resolver.get_resolved_config(expected_id) # test lazy instantiation From 9d38e444b40b8f19e04643eef00e7943e8483f0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jan 2022 12:37:56 +0000 Subject: [PATCH 12/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/mmars/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 83a4c9faad..24a985a47e 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -71,7 +71,7 @@ def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: List For `dict` and `list`, treat every item as a dependency. For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: `["", "", "#dataset", "dataset"]`. - + Args: config: input config content to search. id: id name for the input config. From 5646b0ab220970c91d835fcd1a951330cd17f3a1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 21:45:53 +0800 Subject: [PATCH 13/30] [DLMED] add circle dependency check Signed-off-by: Nic Ma --- monai/apps/mmars/__init__.py | 2 +- monai/apps/mmars/config_parser.py | 3 ++- monai/apps/mmars/config_resolver.py | 39 ++++++++++++++++++++--------- monai/apps/mmars/utils.py | 12 +++++---- tests/test_config_component.py | 32 ++++++----------------- tests/test_config_parser.py | 19 +++++++++----- tests/test_config_resolver.py | 33 ++++++++++++++++++------ 7 files changed, 83 insertions(+), 57 deletions(-) diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index f357635598..e6d97b02c3 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. 
from .config_parser import ConfigParser -from .config_resolver import ModuleScanner, ConfigComponent, ConfigResolver +from .config_resolver import ConfigComponent, ConfigResolver, ModuleScanner from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys from .utils import get_class, instantiate_class, search_configs_with_objs, update_configs_with_objs diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 57b9c366a5..0b6049cd3a 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -11,6 +11,7 @@ import importlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver, ModuleScanner @@ -178,4 +179,4 @@ def build(self, config: Dict): """ return ConfigComponent( id=None, config=config, module_scanner=self.module_scanner, globals=self.global_imports - ).build() + ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 81324a97fb..2356f9a33b 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Sequence import importlib import inspect import pkgutil import warnings +from typing import Any, Dict, List, Optional, Sequence from monai.apps.mmars.utils import instantiate_class, search_configs_with_objs, update_configs_with_objs @@ -94,6 +94,7 @@ class ConfigComponent: for config `"collate_fn": "$monai.data.list_data_collate"`. """ + def __init__(self, id: str, config: Any, module_scanner: ModuleScanner, globals: Optional[Dict] = None) -> None: self.id = id self.config = config @@ -123,7 +124,7 @@ def get_dependent_ids(self) -> List[str]: `["", "", "#dataset", "dataset"]`. """ - return search_configs_with_objs(config=self.config, id=self.id, deps=[]) + return search_configs_with_objs(config=self.config, id=self.id) def get_updated_config(self, deps: dict): """ @@ -179,9 +180,11 @@ def build(self, config: Optional[Dict] = None) -> object: warnings.warn("config content has other dependencies or executable string, skip `build`.") return config - if not isinstance(config, dict) \ - or ("" not in config and "" not in config) \ - or config.get("") is True: + if ( + not isinstance(config, dict) + or ("" not in config and "" not in config) + or config.get("") is True + ): # if marked as `disabled`, skip parsing return config @@ -218,6 +221,7 @@ class ConfigResolver: components: config components to resolve, if None, can also `add()` component in runtime. """ + def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): self.resolved_configs = {} self.resolved_components = {} @@ -236,7 +240,9 @@ def add(self, component: ConfigComponent): raise ValueError(f"id '{id}' is already added.") self.components[id] = component - def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: + def _resolve_one_component( + self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None + ) -> bool: """ Resolve one component with specified id name. If has unresolved dependencies, recursively resolve the dependencies first. @@ -246,21 +252,30 @@ def _resolve_one_component(self, id: str, instantiate: bool = True) -> bool: instantiate: after resolving all the dependencies, whether to build instance. 
if False, can support lazy instantiation with the resolved config later. default to `True`. + waiting_list: list of components wait to resolve dependencies. it's used to detect circular dependencies + when resolving dependencies like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. """ + if waiting_list is None: + waiting_list = [] + waiting_list.append(id) com = self.components[id] - # check whether the obj has any unresolved deps in its args - ref_ids = com.get_dependent_ids() + dep_ids = com.get_dependent_ids() + # if current component has dependency already in the waiting list, that's circular dependencies + for d in dep_ids: + if d in waiting_list: + raise ValueError(f"detected circular dependencies for id='{d}' in the config content.") + deps = {} - if len(ref_ids) > 0: - # see whether all deps are resolved - for comp_id in ref_ids: + if len(dep_ids) > 0: + # # check whether the component has any unresolved deps + for comp_id in dep_ids: if comp_id not in self.resolved_components: # this dependent component is not resolved if comp_id not in self.components: raise RuntimeError(f"the dependent component `{comp_id}` is not in config.") # resolve the dependency first - self._resolve_one_component(id=comp_id, instantiate=True) + self._resolve_one_component(id=comp_id, instantiate=True, waiting_list=waiting_list) deps[comp_id] = self.resolved_components[comp_id] # all dependent components are resolved already updated_config = com.get_updated_config(deps) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 24a985a47e..db94bd6f15 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -10,8 +10,8 @@ # limitations under the License. import importlib -from optparse import Option import re +from optparse import Option from typing import Dict, List, Optional, Union @@ -64,7 +64,7 @@ def instantiate_class(class_path: str, **kwargs): raise ValueError(f"class {class_path} has parameters error.") from e -def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: List[str] = []): +def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None): """ Recursively search all the content of input config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. @@ -75,10 +75,12 @@ def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: List Args: config: input config content to search. id: id name for the input config. - deps: list of the id name of existing dependencies, default to empty. + deps: list of the id name of existing dependencies, default to None. """ - pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" + if deps is None: + deps = [] + pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" @@ -113,7 +115,7 @@ def update_configs_with_objs(config: Union[Dict, List, str], deps: dict, id: str globals: predefined global variables to execute code string with `eval()`. 
""" - pattern = re.compile(r'@\w*[\#\w]*') # match ref as args: "@XXX#YYY#ZZZ" + pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): # all the items in the list should be replaced with the reference config = [deps[f"{id}#{i}"] for i in range(len(config))] diff --git a/tests/test_config_component.py b/tests/test_config_component.py index 7a1b88abd0..e837c69e0d 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Iterator import unittest +from typing import Callable, Iterator import torch -import monai from parameterized import parameterized +import monai from monai.apps import ConfigComponent, ModuleScanner -from monai.data import Dataset, DataLoader +from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond TEST_CASE_1 = [ @@ -56,10 +56,7 @@ RandTorchVisiond, ] # test dependencies of dict config -TEST_CASE_7 = [ - {"dataset": "@dataset", "batch_size": 2}, - ["test#dataset", "dataset", "test#batch_size"], -] +TEST_CASE_7 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] # test dependencies of list config TEST_CASE_8 = [ {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, @@ -98,26 +95,11 @@ torch.optim.Adam, ] # test replace dependencies with code execution result -TEST_CASE_14 = [ - "optimizer##params", - "$@model.parameters()", - {"model": torch.nn.PReLU()}, - Iterator, -] +TEST_CASE_14 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_15 = [ - "dataloader##collate_fn", - "$monai.data.list_data_collate", - {}, - Callable, -] +TEST_CASE_15 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] # test lambda function, should not execute the lambda function, just change the string with dependent objects -TEST_CASE_16 = [ - "dataloader##collate_fn", - "$lambda x: monai.data.list_data_collate(x) + 100", - {}, - Callable, -] +TEST_CASE_16 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] class TestConfigComponent(unittest.TestCase): diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index ba3a6b63d2..e6bf8a83cc 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -10,10 +10,11 @@ # limitations under the License. 
import unittest + from parameterized import parameterized from monai.apps import ConfigParser -from monai.data import Dataset, DataLoader +from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond # test the resolved and parsed instances @@ -21,10 +22,15 @@ { "transform": { "": "Compose", - "": {"transforms": [ - {"": "LoadImaged", "": {"keys": "image"}}, - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, - ]} + "": { + "transforms": [ + {"": "LoadImaged", "": {"keys": "image"}}, + { + "": "RandTorchVisiond", + "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, + }, + ] + }, }, "dataset": {"": "Dataset", "": {"data": [1, 2], "transform": "@transform"}}, "dataloader": { @@ -50,7 +56,8 @@ def test_config_content(self): def test_parse(self, config, expected_ids, output_types): parser = ConfigParser( pkgs=["torch.optim", "monai"], - modules=["data", "transforms", "adam"],global_imports={"monai": "monai"}, + modules=["data", "transforms", "adam"], + global_imports={"monai": "monai"}, config=config, ) for id, cls in zip(expected_ids, output_types): diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py index 3b6d3d4cb3..e1a90e5885 100644 --- a/tests/test_config_resolver.py +++ b/tests/test_config_resolver.py @@ -10,11 +10,12 @@ # limitations under the License. import unittest +from distutils.command.config import config import torch -import monai from parameterized import parameterized +import monai from monai.apps import ConfigComponent, ConfigResolver, ModuleScanner from monai.data import DataLoader from monai.transforms import LoadImaged, RandTorchVisiond @@ -37,7 +38,8 @@ { # all the recursively parsed config items "dataloader": { - "": "DataLoader", "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"} + "": "DataLoader", + "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, }, "dataset": {"": "Dataset", "": {"data": [1, 2]}}, "dataloader#": "DataLoader", @@ -58,7 +60,8 @@ { # all the recursively parsed config items "transform#1": { - "": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25} + "": "RandTorchVisiond", + "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, }, "transform#1#": "RandTorchVisiond", "transform#1#": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, @@ -77,17 +80,33 @@ def test_resolve(self, configs, expected_id, output_type): scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) resolver = ConfigResolver() for k, v in configs.items(): - resolver.add(ConfigComponent( - id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch} - )) + resolver.add( + ConfigComponent(id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch}) + ) + ins = resolver.get_resolved_component(expected_id) + self.assertTrue(isinstance(ins, output_type)) + # test resolve all + resolver.resolved_configs = {} + resolver.resolved_components = {} + resolver.resolve_all() ins = resolver.get_resolved_component(expected_id) self.assertTrue(isinstance(ins, output_type)) - config = resolver.get_resolved_config(expected_id) # test lazy instantiation + config = resolver.get_resolved_config(expected_id) config[""] = False ins = ConfigComponent(id=expected_id, module_scanner=scanner, config=config).build() self.assertTrue(isinstance(ins, output_type)) + def test_circular_dependencies(self): + 
scanner = ModuleScanner(pkgs=[], modules=[]) + resolver = ConfigResolver() + configs = {"A": "@B", "B": "@C", "C": "@A"} + for k, v in configs.items(): + resolver.add(ConfigComponent(id=k, config=v, module_scanner=scanner)) + for k in ["A", "B", "C"]: + with self.assertRaises(ValueError): + resolver.get_resolved_component(k) + if __name__ == "__main__": unittest.main() From cfb2bdd54029a8f3c4857e91902167d3cef2d380 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 24 Jan 2022 22:27:43 +0800 Subject: [PATCH 14/30] [DLMED] fix flake8 error Signed-off-by: Nic Ma --- monai/apps/mmars/config_parser.py | 16 +++++------ monai/apps/mmars/config_resolver.py | 8 +++--- monai/apps/mmars/utils.py | 42 ++++++++++++++--------------- tests/test_config_resolver.py | 1 - 4 files changed, 32 insertions(+), 35 deletions(-) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 0b6049cd3a..b92a359bf9 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -10,7 +10,7 @@ # limitations under the License. import importlib -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Union from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver, ModuleScanner @@ -37,21 +37,21 @@ def __init__( self, pkgs: Sequence[str], modules: Sequence[str], - global_imports: Optional[Dict[str, str]] = {"monai": "monai", "torch": "torch", "np": "numpy"}, + global_imports: Optional[Dict[str, Any]] = None, config: Optional[Any] = None, ): self.config = None if config is not None: self.set_config(config=config) self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) - self.global_imports = {} + self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} if global_imports is not None: for k, v in global_imports.items(): self.global_imports[k] = importlib.import_module(v) - self.config_resolver: Optional[ConfigResolver] = None + self.config_resolver: ConfigResolver = ConfigResolver() self.resolved = False - def _get_last_config_and_key(self, config: Union[Dict, List], id: str) -> Tuple[Union[Dict, List], str]: + def _get_last_config_and_key(self, config: Union[Dict, List], id: str): """ Utility to get the last config item and the id from the whole config content with nested id name. @@ -77,7 +77,7 @@ def set_config(self, config: Any, id: Optional[str] = None): for example: "transforms#5", "transforms#5##keys", etc. """ - if isinstance(id, str): + if isinstance(id, str) and isinstance(self.config, (dict, list)): conf_, key = self._get_last_config_and_key(config=self.config, id=id) conf_[key] = config else: @@ -93,7 +93,7 @@ def get_config(self, id: Optional[str] = None): for example: "transforms#5", "transforms#5##keys", etc. 
""" - if isinstance(id, str): + if isinstance(id, str) and isinstance(self.config, (dict, list)): conf_, key = self._get_last_config_and_key(config=self.config, id=id) return conf_[key] return self.config @@ -178,5 +178,5 @@ def build(self, config: Dict): """ return ConfigComponent( - id=None, config=config, module_scanner=self.module_scanner, globals=self.global_imports + id="", config=config, module_scanner=self.module_scanner, globals=self.global_imports ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 2356f9a33b..ec96e70c9f 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -223,8 +223,8 @@ class ConfigResolver: """ def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): - self.resolved_configs = {} - self.resolved_components = {} + self.resolved_configs: Dict[str, str] = {} + self.resolved_components: Dict[str, Any] = {} self.components = {} if components is None else components def add(self, component: ConfigComponent): @@ -240,9 +240,7 @@ def add(self, component: ConfigComponent): raise ValueError(f"id '{id}' is already added.") self.components[id] = component - def _resolve_one_component( - self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None - ) -> bool: + def _resolve_one_component(self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None): """ Resolve one component with specified id name. If has unresolved dependencies, recursively resolve the dependencies first. diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index db94bd6f15..51828362b7 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -11,7 +11,6 @@ import importlib import re -from optparse import Option from typing import Dict, List, Optional, Union @@ -64,7 +63,7 @@ def instantiate_class(class_path: str, **kwargs): raise ValueError(f"class {class_path} has parameters error.") from e -def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None): +def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: """ Recursively search all the content of input config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. @@ -78,29 +77,28 @@ def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Opti deps: list of the id name of existing dependencies, default to None. 
""" - if deps is None: - deps = [] + deps_: List[str] = [] if deps is None else deps pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" # all the items in the list should be marked as dependent reference - deps.append(sub_id) - deps = search_configs_with_objs(v, sub_id, deps) + deps_.append(sub_id) + deps_ = search_configs_with_objs(v, sub_id, deps_) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" # all the items in the dict should be marked as dependent reference - deps.append(sub_id) - deps = search_configs_with_objs(v, sub_id, deps) + deps_.append(sub_id) + deps_ = search_configs_with_objs(v, sub_id, deps_) if isinstance(config, str): result = pattern.findall(config) for item in result: if config.startswith("$") or config == item: ref_obj_id = item[1:] - if ref_obj_id not in deps: - deps.append(ref_obj_id) - return deps + if ref_obj_id not in deps_: + deps_.append(ref_obj_id) + return deps_ def update_configs_with_objs(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): @@ -118,21 +116,23 @@ def update_configs_with_objs(config: Union[Dict, List, str], deps: dict, id: str pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" if isinstance(config, list): # all the items in the list should be replaced with the reference - config = [deps[f"{id}#{i}"] for i in range(len(config))] + return [deps[f"{id}#{i}"] for i in range(len(config))] if isinstance(config, dict): # all the items in the dict should be replaced with the reference - config = {k: deps[f"{id}#{k}"] for k, _ in config.items()} + return {k: deps[f"{id}#{k}"] for k, _ in config.items()} if isinstance(config, str): result = pattern.findall(config) + config_: str = config for item in result: ref_obj_id = item[1:] - if config.startswith("$"): + if config_.startswith("$"): # replace with local code and execute soon - config = config.replace(item, f"deps['{ref_obj_id}']") - elif config == item: - config = deps[ref_obj_id] - - if isinstance(config, str): - if config.startswith("$"): - config = eval(config[1:], globals, {"deps": deps}) + config_ = config_.replace(item, f"deps['{ref_obj_id}']") + elif config_ == item: + config_ = deps[ref_obj_id] + + if isinstance(config_, str): + if config_.startswith("$"): + config_ = eval(config_[1:], globals, {"deps": deps}) + return config_ return config diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py index e1a90e5885..52c4d6f76e 100644 --- a/tests/test_config_resolver.py +++ b/tests/test_config_resolver.py @@ -10,7 +10,6 @@ # limitations under the License. 
import unittest
-from distutils.command.config import config

 import torch
 from parameterized import parameterized

From 55a210e0255187548a6e5ce2fc3a9415415e98af Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 24 Jan 2022 23:13:40 +0800
Subject: [PATCH 15/30] [DLMED] fix min tests

Signed-off-by: Nic Ma
---
 tests/test_config_component.py | 7 ++++++-
 tests/test_config_parser.py | 5 +++++
 tests/test_config_resolver.py | 5 ++++-
 3 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/tests/test_config_component.py b/tests/test_config_component.py
index e837c69e0d..c11cb62f21 100644
--- a/tests/test_config_component.py
+++ b/tests/test_config_component.py
@@ -19,6 +19,9 @@
 from monai.apps import ConfigComponent, ModuleScanner
 from monai.data import DataLoader, Dataset
 from monai.transforms import LoadImaged, RandTorchVisiond
+from monai.utils import optional_import
+
+_, has_tv = optional_import("torchvision")

 TEST_CASE_1 = [
     dict(pkgs=["monai"], modules=["transforms"]),
@@ -103,7 +106,9 @@


 class TestConfigComponent(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    @parameterized.expand(
+        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] + ([TEST_CASE_6] if has_tv else [])
+    )
     def test_build(self, input_param, test_input, output_type):
         scanner = ModuleScanner(**input_param)
         configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner)
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index e6bf8a83cc..0cb1ec92b7 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -10,12 +10,16 @@
 # limitations under the License.

 import unittest
+from unittest import skipUnless

 from parameterized import parameterized

 from monai.apps import ConfigParser
 from monai.data import DataLoader, Dataset
 from monai.transforms import Compose, LoadImaged, RandTorchVisiond
+from monai.utils import optional_import
+
+_, has_tv = optional_import("torchvision")

 # test the resolved and parsed instances
 TEST_CASE_1 = [
@@ -53,6 +57,7 @@ def test_config_content(self):
         self.assertDictEqual(parser.get_config(id="preprocessing#0#datasets"), {"name": "CacheDataset"})

     @parameterized.expand([TEST_CASE_1])
+    @skipUnless(has_tv, "Requires torchvision.")
     def test_parse(self, config, expected_ids, output_types):
         parser = ConfigParser(
             pkgs=["torch.optim", "monai"],
diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py
index 52c4d6f76e..07a1a9a784 100644
--- a/tests/test_config_resolver.py
+++ b/tests/test_config_resolver.py
@@ -18,6 +18,9 @@
 from monai.apps import ConfigComponent, ConfigResolver, ModuleScanner
 from monai.data import DataLoader
 from monai.transforms import LoadImaged, RandTorchVisiond
+from monai.utils import optional_import
+
+_, has_tv = optional_import("torchvision")

 # test instance with no dependencies
 TEST_CASE_1 = [
@@ -74,7 +77,7 @@


 class TestConfigComponent(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else []))
     def test_resolve(self, configs, expected_id, output_type):
         scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"])
         resolver = ConfigResolver()

From d598be600d2363ceef94eb247cfbbe06908c4ecc Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 24 Jan 2022 23:31:30 +0800
Subject: [PATCH 16/30] [DLMED] fix docs

Signed-off-by: Nic Ma
---
 monai/apps/mmars/config_resolver.py | 2 +- 1
file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index ec96e70c9f..63667b9feb 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -78,7 +78,7 @@ class ConfigComponent: Here we predefined several special marks to parse the config content: - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. - now we have 4 keys: ``, ``, ``, ``. + now we have 4 keys: ``, ``, ``, ``. - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". From 1bee7f183f970994ac40d97452638c29deb9b498 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 25 Jan 2022 12:42:17 +0800 Subject: [PATCH 17/30] [DLMED] update according to comments Signed-off-by: Nic Ma --- docs/source/apps.rst | 3 - monai/apps/__init__.py | 7 +- monai/apps/mmars/__init__.py | 4 +- monai/apps/mmars/config_parser.py | 9 +-- monai/apps/mmars/config_resolver.py | 65 +++-------------- monai/apps/mmars/utils.py | 58 ++-------------- monai/utils/__init__.py | 3 + monai/utils/module.py | 104 +++++++++++++++++++++++++++- tests/test_config_component.py | 16 ++--- tests/test_config_resolver.py | 14 ++-- 10 files changed, 143 insertions(+), 140 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 783709a515..dea6c56718 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -35,9 +35,6 @@ Model Package .. autoclass:: ConfigParser :members: -.. autoclass:: ModuleScanner - :members: - .. autoclass:: ConfigComponent :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index a38e0c77c4..11c1cf784b 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -15,14 +15,11 @@ ConfigComponent, ConfigParser, ConfigResolver, - ModuleScanner, RemoteMMARKeys, download_mmar, - get_class, get_model_spec, - instantiate_class, load_from_mmar, - search_configs_with_objs, - update_configs_with_objs, + search_configs_with_deps, + update_configs_with_deps, ) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index e6d97b02c3..ec00d6ee10 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. 
from .config_parser import ConfigParser -from .config_resolver import ConfigComponent, ConfigResolver, ModuleScanner +from .config_resolver import ConfigComponent, ConfigResolver from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import get_class, instantiate_class, search_configs_with_objs, update_configs_with_objs +from .utils import search_configs_with_deps, update_configs_with_deps diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index b92a359bf9..767ad763c3 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -12,7 +12,8 @@ import importlib from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver, ModuleScanner +from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver +from monai.utils.module import ClassScanner class ConfigParser: @@ -43,7 +44,7 @@ def __init__( self.config = None if config is not None: self.set_config(config=config) - self.module_scanner = ModuleScanner(pkgs=pkgs, modules=modules) + self.class_scanner = ClassScanner(pkgs=pkgs, modules=modules) self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} if global_imports is not None: for k, v in global_imports.items(): @@ -124,7 +125,7 @@ def _do_parse(self, config, id: Optional[str] = None): self._do_parse(config=v, id=sub_id) if id is not None: self.config_resolver.add( - ConfigComponent(id=id, config=config, module_scanner=self.module_scanner, globals=self.global_imports) + ConfigComponent(id=id, config=config, class_scanner=self.class_scanner, globals=self.global_imports) ) def parse_config(self, resolve_all: bool = False): @@ -178,5 +179,5 @@ def build(self, config: Dict): """ return ConfigComponent( - id="", config=config, module_scanner=self.module_scanner, globals=self.global_imports + id="", config=config, class_scanner=self.class_scanner, globals=self.global_imports ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py index 63667b9feb..438ef5f55b 100644 --- a/monai/apps/mmars/config_resolver.py +++ b/monai/apps/mmars/config_resolver.py @@ -9,58 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib -import inspect -import pkgutil import warnings -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional -from monai.apps.mmars.utils import instantiate_class, search_configs_with_objs, update_configs_with_objs - - -class ModuleScanner: - """ - Scan all the available classes in the specified packages and modules. - Map the all the class names and the module names in a table. - - Args: - pkgs: the expected packages to scan modules and parse class names in the config. - modules: the expected modules in the packages to scan for all the classes. - for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. 
- - """ - - def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): - self.pkgs = pkgs - self.modules = modules - self._class_table = self._create_classes_table() - - def _create_classes_table(self): - class_table = {} - for pkg in self.pkgs: - package = importlib.import_module(pkg) - - for _, modname, _ in pkgutil.walk_packages(path=package.__path__, prefix=package.__name__ + "."): - # if no modules specified, load all modules in the package - if len(self.modules) == 0 or any(name in modname for name in self.modules): - try: - module = importlib.import_module(modname) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and obj.__module__ == modname: - class_table[name] = modname - except ModuleNotFoundError: - pass - return class_table - - def get_class_module_name(self, class_name): - """ - Get the module name of the class with specified class name. - - Args: - class_name: name of the expected class. - - """ - return self._class_table.get(class_name, None) +from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps +from monai.utils.module import ClassScanner, instantiate_class class ConfigComponent: @@ -88,17 +41,17 @@ class ConfigComponent: for list component, use index from `0` as id. for example: `transform`, `transform#5`, `transform#5##keys`, etc. config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - module_scanner: ModuleScanner to help get the class name or path in the config and build instance. + class_scanner: ClassScanner to help get the class name or path in the config and build instance. globals: to support executable string in the config, sometimes we need to provide the global variables which are referred in the executable string. for example: `globals={"monai": monai} will be useful for config `"collate_fn": "$monai.data.list_data_collate"`. """ - def __init__(self, id: str, config: Any, module_scanner: ModuleScanner, globals: Optional[Dict] = None) -> None: + def __init__(self, id: str, config: Any, class_scanner: ClassScanner, globals: Optional[Dict] = None) -> None: self.id = id self.config = config - self.module_scanner = module_scanner + self.class_scanner = class_scanner self.globals = globals def get_id(self) -> str: @@ -124,7 +77,7 @@ def get_dependent_ids(self) -> List[str]: `["", "", "#dataset", "dataset"]`. """ - return search_configs_with_objs(config=self.config, id=self.id) + return search_configs_with_deps(config=self.config, id=self.id) def get_updated_config(self, deps: dict): """ @@ -135,7 +88,7 @@ def get_updated_config(self, deps: dict): deps: all the dependent components with ids. 
""" - return update_configs_with_objs(config=self.config, deps=deps, id=self.id, globals=self.globals) + return update_configs_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) def _check_dependency(self, config): """ @@ -205,7 +158,7 @@ def _get_class_path(self, config): class_name = config.get("", None) if class_name is None: raise ValueError("must provide `` or `` of class to build component.") - module_name = self.module_scanner.get_class_module_name(class_name) + module_name = self.class_scanner.get_class_module_name(class_name) if module_name is None: raise ValueError(f"can not find component class '{class_name}'.") class_path = f"{module_name}.{class_name}" diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index 51828362b7..f532ff8cc1 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -9,61 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib import re from typing import Dict, List, Optional, Union -def get_class(class_path: str): - """ - Get the class from specified class path. - - Args: - class_path (str): full path of the class. - - Raises: - ValueError: invalid class_path, missing the module name. - ValueError: class does not exist. - ValueError: module does not exist. - - """ - if len(class_path.split(".")) < 2: - raise ValueError(f"invalid class_path: {class_path}, missing the module name.") - module_name, class_name = class_path.rsplit(".", 1) - - try: - module_ = importlib.import_module(module_name) - - try: - class_ = getattr(module_, class_name) - except AttributeError as e: - raise ValueError(f"class {class_name} does not exist.") from e - - except AttributeError as e: - raise ValueError(f"module {module_name} does not exist.") from e - - return class_ - - -def instantiate_class(class_path: str, **kwargs): - """ - Method for creating an instance for the specified class. - - Args: - class_path: full path of the class. - kwargs: arguments to initialize the class instance. - - Raises: - ValueError: class has paramenters error. - """ - - try: - return get_class(class_path)(**kwargs) - except TypeError as e: - raise ValueError(f"class {class_path} has parameters error.") from e - - -def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: +def search_configs_with_deps(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: """ Recursively search all the content of input config compoent to get the ids of dependencies. It's used to build all the dependencies before build current config component. 
@@ -84,13 +34,13 @@ def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Opti sub_id = f"{id}#{i}" # all the items in the list should be marked as dependent reference deps_.append(sub_id) - deps_ = search_configs_with_objs(v, sub_id, deps_) + deps_ = search_configs_with_deps(v, sub_id, deps_) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" # all the items in the dict should be marked as dependent reference deps_.append(sub_id) - deps_ = search_configs_with_objs(v, sub_id, deps_) + deps_ = search_configs_with_deps(v, sub_id, deps_) if isinstance(config, str): result = pattern.findall(config) for item in result: @@ -101,7 +51,7 @@ def search_configs_with_objs(config: Union[Dict, List, str], id: str, deps: Opti return deps_ -def update_configs_with_objs(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): +def update_configs_with_deps(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): """ With all the dependencies in `deps`, update the config content with them and return new config. It can be used for lazy instantiation. diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index dec9bdea65..a1e9390c30 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -61,14 +61,17 @@ zip_with, ) from .module import ( + ClassScanner, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, exact_version, export, + get_class, get_full_type_name, get_package_version, get_torch_version_tuple, + instantiate_class, load_submodules, look_up_option, min_version, diff --git a/monai/utils/module.py b/monai/utils/module.py index c0fc10a7c0..2994b80421 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,6 +10,7 @@ # limitations under the License. import enum +import inspect import os import re import sys @@ -19,7 +20,7 @@ from pkgutil import walk_packages from re import match from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Sequence, Tuple, cast import torch @@ -36,6 +37,9 @@ "optional_import", "require_pkg", "load_submodules", + "ClassScanner", + "get_class", + "instantiate_class", "get_full_type_name", "get_package_version", "get_torch_version_tuple", @@ -193,7 +197,105 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod +class ClassScanner: + """ + Scan all the available classes in the specified packages and modules. + Map the all the class names and the module names in a table. + + Args: + pkgs: the expected packages to scan modules and parse class names in the config. + modules: the expected modules in the packages to scan for all the classes. + for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. 
+ + """ + + def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): + self.pkgs = pkgs + self.modules = modules + self._class_table = self._create_classes_table() + + def _create_classes_table(self): + class_table = {} + for pkg in self.pkgs: + package = import_module(pkg) + + for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): + # if no modules specified, load all modules in the package + if len(self.modules) == 0 or any(name in modname for name in self.modules): + try: + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__module__ == modname: + class_table[name] = modname + except ModuleNotFoundError: + pass + return class_table + + def get_class_module_name(self, class_name): + """ + Get the module name of the class with specified class name. + + Args: + class_name: name of the expected class. + + """ + return self._class_table.get(class_name, None) + + +def get_class(class_path: str): + """ + Get the class from specified class path. + + Args: + class_path (str): full path of the class. + + Raises: + ValueError: invalid class_path, missing the module name. + ValueError: class does not exist. + ValueError: module does not exist. + + """ + if len(class_path.split(".")) < 2: + raise ValueError(f"invalid class_path: {class_path}, missing the module name.") + module_name, class_name = class_path.rsplit(".", 1) + + try: + module_ = import_module(module_name) + + try: + class_ = getattr(module_, class_name) + except AttributeError as e: + raise ValueError(f"class {class_name} does not exist.") from e + + except AttributeError as e: + raise ValueError(f"module {module_name} does not exist.") from e + + return class_ + + +def instantiate_class(class_path: str, **kwargs): + """ + Method for creating an instance for the specified class. + + Args: + class_path: full path of the class. + kwargs: arguments to initialize the class instance. + + Raises: + ValueError: class has paramenters error. + """ + + try: + return get_class(class_path)(**kwargs) + except TypeError as e: + raise ValueError(f"class {class_path} has parameters error.") from e + + def get_full_type_name(typeobj): + """ + Utility to get the full path name of a class or object type. 
+ + """ module = typeobj.__module__ if module is None or module == str.__class__.__module__: return typeobj.__name__ # Avoid reporting __builtin__ diff --git a/tests/test_config_component.py b/tests/test_config_component.py index c11cb62f21..d48a4e274f 100644 --- a/tests/test_config_component.py +++ b/tests/test_config_component.py @@ -16,10 +16,10 @@ from parameterized import parameterized import monai -from monai.apps import ConfigComponent, ModuleScanner +from monai.apps import ConfigComponent from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import +from monai.utils import ClassScanner, optional_import _, has_tv = optional_import("torchvision") @@ -110,8 +110,8 @@ class TestConfigComponent(unittest.TestCase): [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] + ([TEST_CASE_6] if has_tv else []) ) def test_build(self, input_param, test_input, output_type): - scanner = ModuleScanner(**input_param) - configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) + scanner = ClassScanner(**input_param) + configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) ret = configer.build() self.assertTrue(isinstance(ret, output_type)) if isinstance(ret, LoadImaged): @@ -122,16 +122,16 @@ def test_build(self, input_param, test_input, output_type): @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) def test_dependent_ids(self, test_input, ref_ids): - scanner = ModuleScanner(pkgs=[], modules=[]) - configer = ConfigComponent(id="test", config=test_input, module_scanner=scanner) + scanner = ClassScanner(pkgs=[], modules=[]) + configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) ret = configer.get_dependent_ids() self.assertListEqual(ret, ref_ids) @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) configer = ConfigComponent( - id=id, config=test_input, module_scanner=scanner, globals={"monai": monai, "torch": torch} + id=id, config=test_input, class_scanner=scanner, globals={"monai": monai, "torch": torch} ) config = configer.get_updated_config(deps) ret = configer.build(config) diff --git a/tests/test_config_resolver.py b/tests/test_config_resolver.py index 07a1a9a784..95200d0b81 100644 --- a/tests/test_config_resolver.py +++ b/tests/test_config_resolver.py @@ -15,10 +15,10 @@ from parameterized import parameterized import monai -from monai.apps import ConfigComponent, ConfigResolver, ModuleScanner +from monai.apps import ConfigComponent, ConfigResolver from monai.data import DataLoader from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import +from monai.utils import ClassScanner, optional_import _, has_tv = optional_import("torchvision") @@ -79,11 +79,11 @@ class TestConfigComponent(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) def test_resolve(self, configs, expected_id, output_type): - scanner = ModuleScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) resolver = 
ConfigResolver() for k, v in configs.items(): resolver.add( - ConfigComponent(id=k, config=v, module_scanner=scanner, globals={"monai": monai, "torch": torch}) + ConfigComponent(id=k, config=v, class_scanner=scanner, globals={"monai": monai, "torch": torch}) ) ins = resolver.get_resolved_component(expected_id) self.assertTrue(isinstance(ins, output_type)) @@ -96,15 +96,15 @@ def test_resolve(self, configs, expected_id, output_type): # test lazy instantiation config = resolver.get_resolved_config(expected_id) config[""] = False - ins = ConfigComponent(id=expected_id, module_scanner=scanner, config=config).build() + ins = ConfigComponent(id=expected_id, class_scanner=scanner, config=config).build() self.assertTrue(isinstance(ins, output_type)) def test_circular_dependencies(self): - scanner = ModuleScanner(pkgs=[], modules=[]) + scanner = ClassScanner(pkgs=[], modules=[]) resolver = ConfigResolver() configs = {"A": "@B", "B": "@C", "C": "@A"} for k, v in configs.items(): - resolver.add(ConfigComponent(id=k, config=v, module_scanner=scanner)) + resolver.add(ConfigComponent(id=k, config=v, class_scanner=scanner)) for k in ["A", "B", "C"]: with self.assertRaises(ValueError): resolver.get_resolved_component(k) From 3049e280f2424962bb2a69261389fcc0b98e0036 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 27 Jan 2022 16:07:38 +0800 Subject: [PATCH 18/30] [DLMED] add scripts logic Signed-off-by: Nic Ma --- monai/apps/mmars/schema/metadata.json | 71 +++++++++++++++++++ .../mmars/scripts(pseudo code)/__init__.py | 11 +++ .../apps/mmars/scripts(pseudo code)/export.py | 60 ++++++++++++++++ .../mmars/scripts(pseudo code)/inference.py | 69 ++++++++++++++++++ .../scripts(pseudo code)/verify_network.py | 61 ++++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 monai/apps/mmars/schema/metadata.json create mode 100644 monai/apps/mmars/scripts(pseudo code)/__init__.py create mode 100644 monai/apps/mmars/scripts(pseudo code)/export.py create mode 100644 monai/apps/mmars/scripts(pseudo code)/inference.py create mode 100644 monai/apps/mmars/scripts(pseudo code)/verify_network.py diff --git a/monai/apps/mmars/schema/metadata.json b/monai/apps/mmars/schema/metadata.json new file mode 100644 index 0000000000..babee8b30e --- /dev/null +++ b/monai/apps/mmars/schema/metadata.json @@ -0,0 +1,71 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://monai.io/mmar_metadata_schema.json", + "title": "metadata", + "description": "metadata that defines the context information for MMAR.", + "type": "object", + "properties": { + "version": { + "description": "version number of this MMAR.", + "type": "string" + }, + "monai_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "pytorch_version": { + "description": "version number of PyTorch used in this MMAR.", + "type": "string" + }, + "numpy_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "network_data_format": { + "description": "define the input and output data format for network.", + "type": "object", + "properties": { + "inputs": { + "type": "object", + "properties": { + "image": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "the data format for `image`." 
+ }, + "format": { + "type": "string" + }, + "num_channels": { + "type": "integer", + "minimum": 1 + }, + "spatial_shape": { + "type": "array", + "items": { + "type": "integer", + "minumum": 1 + } + }, + "dtype": { + "type": "string" + }, + "value_range": { + "type": "array", + "items": { + "type": "number", + "unuqueItems": true + } + }, + "required": ["num_channels", "spatial_shape", "value_range"] + } + } + } + } + } + }, + "required": ["monai_version", "pytorch_version", "network_data_format"] + } +} diff --git a/monai/apps/mmars/scripts(pseudo code)/__init__.py b/monai/apps/mmars/scripts(pseudo code)/__init__.py new file mode 100644 index 0000000000..7e5ba0883c --- /dev/null +++ b/monai/apps/mmars/scripts(pseudo code)/__init__.py @@ -0,0 +1,11 @@ + +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/mmars/scripts(pseudo code)/export.py b/monai/apps/mmars/scripts(pseudo code)/export.py new file mode 100644 index 0000000000..22f76f90b8 --- /dev/null +++ b/monai/apps/mmars/scripts(pseudo code)/export.py @@ -0,0 +1,60 @@ + +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json + +import torch +from monai.apps import ConfigParser +from ignite.handlers import Checkpoint +from monai.data import save_net_with_metadata +from monai.networks import convert_to_torchscript + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', '-w', type=str, help='file path of the trained model weights', required=True) + parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + args = parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config_dict = json.load(f) + # load meta data + with open(args.meta, "r") as f: + meta_dict = json.load(f) + + net: torch.nn.Module = None + # TODO: parse network definiftion from config file and construct network instance + config_parser = ConfigParser(config_dict) + net = config_parser.get_instance("network") + + checkpoint = torch.load(args.weights) + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={"model": net}, checkpoint=checkpoint) + + # convert to TorchScript model and save with meta data + net = convert_to_torchscript(model=net) + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream="model.ts", + include_config_vals=False, + append_timestamp=False, + meta_values=meta_dict, + more_extra_files={args.config: json.dumps(config_dict).encode()}, + ) + + +if __name__ == '__main__': + main() diff --git a/monai/apps/mmars/scripts(pseudo code)/inference.py b/monai/apps/mmars/scripts(pseudo code)/inference.py new file mode 100644 index 0000000000..33a599f88e --- /dev/null +++ b/monai/apps/mmars/scripts(pseudo code)/inference.py @@ -0,0 +1,69 @@ + +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json + +import torch +from monai.apps import ConfigParser +from monai.data import decollate_batch +from monai.inferers import Inferer +from monai.transforms import Transform +from monai.utils.enums import CommonKeys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + parser.add_argument('--override', '-o', type=str, help='config file that override components', required=False) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + configs = {} + + # load meta data + with open(args.meta, "r") as f: + configs.update(json.load(f)) + # load config file, can override meta data in config + with open(args.config, "r") as f: + configs.update(json.load(f)) + + model: torch.nn.Module = None + dataloader: torch.utils.data.DataLoader = None + inferer: Inferer = None + postprocessing: Transform = None + # TODO: parse inference config file and construct instances + config_parser = ConfigParser(configs) + + # change JSON config content in python code, lazy instantiation + model_conf = config_parser.get_config("model") + model_conf["disabled"] = False + model = config_parser.build(model_conf).to(device) + + # instantialize the components immediately + dataloader = config_parser.get_instance("dataloader") + inferer = config_parser.get_instance("inferer") + postprocessing = config_parser.get_instance("postprocessing") + + model.eval() + with torch.no_grad(): + for d in dataloader: + images = d[CommonKeys.IMAGE].to(device) + # define sliding window size and batch size for windows inference + d[CommonKeys.PRED] = inferer(inputs=images, predictor=model) + # decollate the batch data into a list of dictionaries, then execute postprocessing transforms + [postprocessing(i) for i in decollate_batch(d)] + + +if __name__ == '__main__': + main() diff --git a/monai/apps/mmars/scripts(pseudo code)/verify_network.py b/monai/apps/mmars/scripts(pseudo code)/verify_network.py new file mode 100644 index 0000000000..1fe0b76b6c --- /dev/null +++ b/monai/apps/mmars/scripts(pseudo code)/verify_network.py @@ -0,0 +1,61 @@ + +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import argparse
+import json
+
+import torch
+from monai.apps import ConfigParser
+from monai.utils.type_conversion import get_equivalent_dtype
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True)
+    parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
+    args = parser.parse_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    configs = {}
+
+    # load meta data
+    with open(args.meta, "r") as f:
+        configs.update(json.load(f))
+    # load config file, can override meta data in config
+    with open(args.config, "r") as f:
+        configs.update(json.load(f))
+
+    model: torch.nn.Module = None
+    # TODO: parse inference config file and construct instances
+    config_parser = ConfigParser(configs)
+
+    model = config_parser.get_instance("model")
+    input_channels = config_parser.get_config("network_data_format#inputs#image#num_channels")
+    input_spatial_shape = tuple(config_parser.get_config("network_data_format#inputs#image#spatial_shape"))
+    dtype = config_parser.get_config("network_data_format#inputs#image#dtype")
+    dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)
+
+    output_channels = config_parser.get_config("network_data_format#outputs#pred#num_channels")
+    output_spatial_shape = tuple(config_parser.get_config("network_data_format#outputs#pred#spatial_shape"))
+
+    model.eval()
+    with torch.no_grad():
+        test_data = torch.rand(*(input_channels, *input_spatial_shape), dtype=dtype, device=device)
+        output = model(test_data)
+        if output.shape[0] != output_channels:
+            raise ValueError(f"channel number of output data doesn't match expectation: {output_channels}.")
+        if output.shape[1:] != output_spatial_shape:
+            raise ValueError(f"spatial shape of output data doesn't match expectation: {output_spatial_shape}.")
+
+
+if __name__ == '__main__':
+    main()

From 506b9089637d7dd9ed3496dde9deee8c669f269d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 27 Jan 2022 08:08:28 +0000
Subject: [PATCH 19/30] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/apps/mmars/scripts(pseudo code)/__init__.py | 1 -
 monai/apps/mmars/scripts(pseudo code)/export.py | 5 ++---
 monai/apps/mmars/scripts(pseudo code)/inference.py | 5 ++---
 monai/apps/mmars/scripts(pseudo code)/verify_network.py | 5 ++---
 4 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/monai/apps/mmars/scripts(pseudo code)/__init__.py b/monai/apps/mmars/scripts(pseudo code)/__init__.py
index 7e5ba0883c..14ae193634 100644
--- a/monai/apps/mmars/scripts(pseudo code)/__init__.py
+++ b/monai/apps/mmars/scripts(pseudo code)/__init__.py
@@ -1,4 +1,3 @@
-
 # Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/monai/apps/mmars/scripts(pseudo code)/export.py b/monai/apps/mmars/scripts(pseudo code)/export.py
index 22f76f90b8..59f404da29 100644
--- a/monai/apps/mmars/scripts(pseudo code)/export.py
+++ b/monai/apps/mmars/scripts(pseudo code)/export.py
@@ -1,4 +1,3 @@
-
 # Copyright 2020 - 2021 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,10 +27,10 @@ def main(): args = parser.parse_args() # load config file - with open(args.config, "r") as f: + with open(args.config) as f: config_dict = json.load(f) # load meta data - with open(args.meta, "r") as f: + with open(args.meta) as f: meta_dict = json.load(f) net: torch.nn.Module = None diff --git a/monai/apps/mmars/scripts(pseudo code)/inference.py b/monai/apps/mmars/scripts(pseudo code)/inference.py index 33a599f88e..bbe64740e1 100644 --- a/monai/apps/mmars/scripts(pseudo code)/inference.py +++ b/monai/apps/mmars/scripts(pseudo code)/inference.py @@ -1,4 +1,3 @@ - # Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,10 +31,10 @@ def main(): configs = {} # load meta data - with open(args.meta, "r") as f: + with open(args.meta) as f: configs.update(json.load(f)) # load config file, can override meta data in config - with open(args.config, "r") as f: + with open(args.config) as f: configs.update(json.load(f)) model: torch.nn.Module = None diff --git a/monai/apps/mmars/scripts(pseudo code)/verify_network.py b/monai/apps/mmars/scripts(pseudo code)/verify_network.py index 1fe0b76b6c..db9e247994 100644 --- a/monai/apps/mmars/scripts(pseudo code)/verify_network.py +++ b/monai/apps/mmars/scripts(pseudo code)/verify_network.py @@ -1,4 +1,3 @@ - # Copyright 2020 - 2021 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,10 +27,10 @@ def main(): configs = {} # load meta data - with open(args.meta, "r") as f: + with open(args.meta) as f: configs.update(json.load(f)) # load config file, can override meta data in config - with open(args.config, "r") as f: + with open(args.config) as f: configs.update(json.load(f)) model: torch.nn.Module = None From 9d1e716201b953defd34320ff409ea93450e1db8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Feb 2022 18:33:56 +0800 Subject: [PATCH 20/30] [DLMED] update to latest Signed-off-by: Nic Ma --- monai/apps/__init__.py | 10 +- monai/apps/deepgrow/dataset.py | 2 +- monai/apps/deepgrow/transforms.py | 12 +- monai/apps/mmars/__init__.py | 13 +- monai/apps/mmars/config_item.py | 378 ++++++++ monai/apps/mmars/config_parser.py | 145 +++- monai/apps/mmars/config_resolver.py | 276 ------ monai/apps/mmars/schema/metadata.json | 71 -- .../mmars/scripts(pseudo code)/__init__.py | 10 - .../apps/mmars/scripts(pseudo code)/export.py | 59 -- .../mmars/scripts(pseudo code)/inference.py | 68 -- .../scripts(pseudo code)/verify_network.py | 60 -- monai/apps/mmars/utils.py | 189 ++-- .../pathology/transforms/spatial/array.py | 3 +- monai/apps/utils.py | 38 +- monai/config/deviceconfig.py | 8 +- monai/config/type_definitions.py | 9 +- monai/csrc/ext.cpp | 2 + monai/csrc/resample/pushpull_cpu.cpp | 106 +-- monai/csrc/resample/pushpull_cuda.cu | 113 +-- monai/data/__init__.py | 12 + monai/data/dataloader.py | 2 +- monai/data/dataset.py | 65 +- monai/data/dataset_summary.py | 2 +- monai/data/folder_layout.py | 5 +- monai/data/image_reader.py | 6 +- monai/data/image_writer.py | 804 ++++++++++++++++++ monai/data/nifti_saver.py | 5 + monai/data/nifti_writer.py | 22 +- monai/data/png_saver.py | 6 +- monai/data/png_writer.py | 12 +- monai/data/samplers.py | 2 +- monai/data/test_time_augmentation.py | 123 ++- monai/data/thread_buffer.py | 21 +- monai/data/utils.py | 190 +++-- monai/engines/evaluator.py | 11 +- 
monai/engines/multi_gpu_supervised_trainer.py | 14 +- monai/handlers/lr_schedule_handler.py | 6 - monai/handlers/parameter_scheduler.py | 2 +- monai/handlers/roc_auc.py | 2 +- monai/handlers/stats_handler.py | 40 +- monai/losses/image_dissimilarity.py | 6 +- monai/metrics/confusion_matrix.py | 2 +- monai/metrics/cumulative_average.py | 2 +- monai/metrics/hausdorff_distance.py | 4 +- monai/metrics/meandice.py | 4 +- monai/metrics/metric.py | 6 +- monai/metrics/regression.py | 2 +- monai/metrics/rocauc.py | 2 +- monai/metrics/surface_distance.py | 8 +- monai/metrics/utils.py | 2 +- monai/networks/__init__.py | 2 + monai/networks/blocks/acti_norm.py | 2 +- monai/networks/blocks/patchembedding.py | 2 +- monai/networks/blocks/selfattention.py | 11 +- monai/networks/blocks/upsample.py | 4 +- monai/networks/blocks/warp.py | 2 +- monai/networks/layers/convutils.py | 2 +- monai/networks/layers/spatial_transforms.py | 6 +- monai/networks/nets/densenet.py | 3 + monai/networks/nets/dints.py | 10 +- monai/networks/nets/highresnet.py | 2 +- monai/networks/nets/regunet.py | 6 +- monai/networks/nets/segresnet.py | 6 +- monai/networks/nets/transchex.py | 11 +- monai/networks/nets/vit.py | 6 +- monai/networks/utils.py | 66 +- monai/transforms/__init__.py | 23 +- monai/transforms/compose.py | 17 +- monai/transforms/croppad/array.py | 6 +- monai/transforms/croppad/batch.py | 5 +- monai/transforms/croppad/dictionary.py | 8 +- monai/transforms/intensity/array.py | 48 +- monai/transforms/intensity/dictionary.py | 8 +- monai/transforms/inverse_batch_transform.py | 5 +- monai/transforms/io/array.py | 253 +++--- monai/transforms/io/dictionary.py | 110 +-- monai/transforms/post/array.py | 11 +- monai/transforms/post/dictionary.py | 83 +- monai/transforms/smooth_field/array.py | 4 +- monai/transforms/smooth_field/dictionary.py | 17 +- monai/transforms/spatial/array.py | 406 +++++++-- monai/transforms/spatial/dictionary.py | 205 ++++- monai/transforms/transform.py | 18 +- monai/transforms/utility/array.py | 34 +- monai/transforms/utility/dictionary.py | 14 +- monai/transforms/utils.py | 55 +- .../transforms/utils_create_transform_ims.py | 5 +- .../utils_pytorch_numpy_unification.py | 47 +- monai/utils/__init__.py | 5 +- monai/utils/misc.py | 55 +- monai/utils/module.py | 117 +-- monai/utils/type_conversion.py | 40 +- monai/visualize/class_activation_maps.py | 4 +- monai/visualize/img2tensorboard.py | 3 +- monai/visualize/occlusion_sensitivity.py | 4 +- monai/visualize/utils.py | 13 +- tests/min_tests.py | 2 + tests/test_affined.py | 7 + tests/test_apply_filter.py | 2 +- tests/test_cachedataset.py | 70 +- tests/test_component_locator.py | 35 + tests/test_compose.py | 2 +- tests/test_config_component.py | 142 ---- tests/test_config_item.py | 138 +++ tests/test_cross_validation.py | 11 +- tests/test_data_stats.py | 24 +- tests/test_data_statsd.py | 25 +- tests/test_decathlondataset.py | 11 +- tests/test_download_and_extract.py | 19 +- tests/test_efficientnet.py | 21 +- tests/test_ensemble_evaluator.py | 12 +- tests/test_global_mutual_information_loss.py | 5 +- tests/test_handler_checkpoint_loader.py | 8 - tests/test_handler_checkpoint_saver.py | 5 - tests/test_handler_lr_scheduler.py | 10 +- tests/test_handler_segmentation_saver.py | 8 +- tests/test_handler_stats.py | 67 +- tests/test_image_rw.py | 136 +++ tests/test_integration_classification_2d.py | 13 +- tests/test_integration_fast_train.py | 1 + tests/test_integration_segmentation_3d.py | 17 +- tests/test_integration_workflows.py | 3 - 
tests/test_integration_workflows_gan.py | 3 - tests/test_inverse.py | 13 +- tests/test_inverse_collation.py | 4 +- tests/test_invertd.py | 2 +- tests/test_itk_writer.py | 55 ++ tests/test_lmdbdataset.py | 6 +- tests/test_load_image.py | 16 +- tests/test_load_imaged.py | 2 +- tests/test_lr_finder.py | 18 +- tests/test_mednistdataset.py | 11 +- tests/test_mmar_download.py | 17 +- tests/test_module_list.py | 19 + tests/test_nifti_rw.py | 86 +- tests/test_ori_ras_lps.py | 46 + tests/test_pad_collation.py | 4 +- tests/test_parallel_execution_dist.py | 45 + tests/test_png_rw.py | 28 +- tests/test_rand_elastic_3d.py | 2 +- tests/test_rand_elasticd_3d.py | 2 +- tests/test_rand_rotate.py | 2 + tests/test_rand_rotated.py | 2 + tests/test_rotate.py | 6 +- tests/test_rotated.py | 12 +- tests/test_save_image.py | 15 +- tests/test_save_imaged.py | 12 +- tests/test_save_state.py | 70 ++ tests/test_scale_intensity.py | 2 +- .../test_scale_intensity_range_percentiles.py | 6 +- ...test_scale_intensity_range_percentilesd.py | 2 +- tests/test_separable_filter.py | 2 +- tests/test_spacing.py | 18 +- tests/test_spatial_resample.py | 146 ++++ tests/test_spatial_resampled.py | 113 +++ tests/test_testtimeaugmentation.py | 26 +- tests/test_tile_on_grid.py | 4 +- tests/test_tile_on_grid_dict.py | 4 +- tests/test_utils_pytorch_numpy_unification.py | 14 +- tests/test_vit.py | 14 +- tests/utils.py | 39 +- 162 files changed, 4337 insertions(+), 2071 deletions(-) create mode 100644 monai/apps/mmars/config_item.py delete mode 100644 monai/apps/mmars/config_resolver.py delete mode 100644 monai/apps/mmars/schema/metadata.json delete mode 100644 monai/apps/mmars/scripts(pseudo code)/__init__.py delete mode 100644 monai/apps/mmars/scripts(pseudo code)/export.py delete mode 100644 monai/apps/mmars/scripts(pseudo code)/inference.py delete mode 100644 monai/apps/mmars/scripts(pseudo code)/verify_network.py create mode 100644 monai/data/image_writer.py create mode 100644 tests/test_component_locator.py delete mode 100644 tests/test_config_component.py create mode 100644 tests/test_config_item.py create mode 100644 tests/test_image_rw.py create mode 100644 tests/test_itk_writer.py create mode 100644 tests/test_ori_ras_lps.py create mode 100644 tests/test_parallel_execution_dist.py create mode 100644 tests/test_save_state.py create mode 100644 tests/test_spatial_resample.py create mode 100644 tests/test_spatial_resampled.py diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 11c1cf784b..2c3297d023 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -12,14 +12,20 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import ( MODEL_DESC, + ComponentLocator, ConfigComponent, + ConfigItem, ConfigParser, ConfigResolver, RemoteMMARKeys, download_mmar, + find_refs_in_config, get_model_spec, + is_expression, + is_instantiable, load_from_mmar, - search_configs_with_deps, - update_configs_with_deps, + match_refs_pattern, + resolve_config_with_refs, + resolve_refs_pattern, ) from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 763377763a..721781196b 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -126,8 +126,8 @@ def _default_transforms(image_key, label_key, pixdim): [ LoadImaged(keys=keys), AsChannelFirstd(keys=keys), - Spacingd(keys=keys, pixdim=pixdim, mode=mode), Orientationd(keys=keys, 
axcodes="RAS"), + Spacingd(keys=keys, pixdim=pixdim, mode=mode), ] ) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 310931d236..1cef7ff03d 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -374,7 +374,7 @@ class SpatialCropForegroundd(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to to fetch/store the meta data according + meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to fetch/store the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -474,7 +474,7 @@ class AddGuidanceFromPointsd(Transform): for example, for data with key `image`, the metadata by default is in `image_meta_dict`. the meta data is a dictionary object which contains: filename, original_shape, etc. if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`. - meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to to fetch the meta data according + meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -589,7 +589,7 @@ class SpatialCropGuidanced(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -652,7 +652,7 @@ def __call__(self, data): return d guidance = d[self.guidance] - original_spatial_shape = d[first_key].shape[1:] # type: ignore + original_spatial_shape = d[first_key].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False)) spatial_size = self.spatial_size @@ -787,7 +787,7 @@ class RestoreLabeld(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix} to to fetch the meta data according + meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix} to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. 
@@ -897,7 +897,7 @@ class Fetch2DSliced(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: use `key_{meta_key_postfix}` to to fetch the meta data according to the key data, + meta_key_postfix: use `key_{meta_key_postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index ec00d6ee10..081c2e511d 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,8 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_parser import ConfigParser -from .config_resolver import ConfigComponent, ConfigResolver +from .config_item import ComponentLocator, ConfigComponent, ConfigItem +from .config_parser import ConfigParser, ConfigResolver from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import search_configs_with_deps, update_configs_with_deps +from .utils import ( + find_refs_in_config, + is_expression, + is_instantiable, + match_refs_pattern, + resolve_config_with_refs, + resolve_refs_pattern, +) diff --git a/monai/apps/mmars/config_item.py b/monai/apps/mmars/config_item.py new file mode 100644 index 0000000000..536743c5e9 --- /dev/null +++ b/monai/apps/mmars/config_item.py @@ -0,0 +1,378 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import sys +import warnings +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Any, Dict, List, Optional, Sequence, Union + +from monai.apps.mmars.utils import find_refs_in_config, is_instantiable, resolve_config_with_refs +from monai.utils import ensure_tuple, instantiate + +__all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] + + +class ComponentLocator: + """ + Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. + It's used to locate the module path for provided component name. + + Args: + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + + """ + + MOD_START = "monai" + + def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + self.excludes = [] if excludes is None else ensure_tuple(excludes) + self._components_table: Optional[Dict[str, List]] = None + + def _find_module_names(self) -> List[str]: + """ + Find all the modules start with MOD_START and don't contain any of `excludes`. 
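A minimal usage sketch of `ComponentLocator` (assuming the MONAI submodules have already been imported; the exact module path returned depends on the installed version):

.. code-block:: python

    import monai  # importing monai loads its submodules, so sys.modules can be scanned

    locator = ComponentLocator()
    # look up which module defines a class / function with the given name
    modname = locator.get_component_module_name("LoadImage")
    # expected to be a single module path such as "monai.transforms.io.array",
    # or a list of module paths if the name is defined in more than one module
    print(modname)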
+ + """ + return [ + m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) + ] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + """ + Find all the classes and functions in the modules with specified `modnames`. + + Args: + modnames: names of the target modules to find all the classes and functions. + + """ + table: Dict[str, List] = {} + # all the MONAI modules are already loaded by `load_submodules` + for modname in ensure_tuple(modnames): + try: + # scan all the classes and functions in the module + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass + return table + + def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: + """ + Get the full module name of the class / function with specified name. + If target component name exists in multiple packages or modules, return a list of full module names. + + Args: + name: name of the expected class or function. + + """ + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods + + +class ConfigItem: + """ + Utility class to manage every item of the whole config content. + When recursively parsing a complicated config content, every item (like: dict, list, string, int, float, etc.) + can be treated as an "config item" then construct a `ConfigItem`. + For example, below are 5 config items when recursively parsing: + - a dict: `{"preprocessing": ["@transform1", "@transform2", "$lambda x: x"]}` + - a list: `["@transform1", "@transform2", "$lambda x: x"]` + - a string: `"transform1"` + - a string: `"transform2"` + - a string: `"$lambda x: x"` + + `ConfigItem` can set optional unique ID name, then another config item may refer to it. + The references mean the IDs of other config items used as "@XXX" in the current config item, for example: + config item with ID="A" is a list `[1, 2, 3]`, another config item can be `"args": {"input_list": "@A"}`. + If sub-item in the config is instantiable, also treat it as reference because must instantiate it before + resolving current config. + It can search the config content and find out all the references, and resolve the config content + when all the references are resolved. + + Here we predefined 3 kinds special marks (`#`, `@`, `$`) when parsing the whole config content: + - "XXX#YYY": join nested config IDs, like "transforms#5" is ID name of the 6th transform in a list ID="transforms". + - "@XXX": current config item refers to another config item XXX, like `{"args": {"data": "@dataset"}}` uses + resolved config content of `dataset` as the parameter "data". + - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". + + The typical usage of the APIs: + - Initialize with config content. + - If having references, get the IDs of its referring components. + - When all the referring components are resolved, resolve the config content with them, + and execute expressions in the config. + + .. 
code-block:: python + + config = {"lr": "$@epoch / 1000"} + + configer = ConfigComponent(config, id="test") + dep_ids = configer.get_id_of_refs() + configer.resolve_config(refs={"epoch": 10}) + lr = configer.get_resolved_config() + + Args: + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + id: ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. + globals: to support executable string in the config, sometimes we need to provide the global variables + which are referred in the executable string. for example: `globals={"monai": monai} will be useful + for config `{"collate_fn": "$monai.data.list_data_collate"}`. + + """ + + def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + self.config = config + self.resolved_config = None + self.is_resolved = False + self.id = id + self.globals = globals + self.update_config(config=config) + + def get_id(self) -> Optional[str]: + """ + ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. + + """ + return self.id + + def update_config(self, config: Any): + """ + Update the config content for a config item at runtime. + If having references, need resolve the config later. + A typical usage is to modify the initial config content at runtime and set back. + + Args: + config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. + + """ + self.config = config + self.resolved_config = None + self.is_resolved = False + if not self.get_id_of_refs(): + # if no references, can resolve the config immediately + self.resolve(refs=None) + + def get_config(self): + """ + Get the initial config content of current config item, usually set at the constructor. + It can be useful to dynamically update the config content before resolving. + + """ + return self.config + + def get_id_of_refs(self) -> List[str]: + """ + Recursively search all the content of current config item to get the IDs of references. + It's used to detect and resolve all the references before resolving current config item. + For `dict` and `list`, recursively check the sub-items. + For example: `{"args": {"lr": "$@epoch / 1000"}}`, the reference IDs: `["epoch"]`. + + """ + return find_refs_in_config(self.config, id=self.id) + + def resolve(self, refs: Optional[Dict] = None): + """ + If all the references are resolved in `refs`, resolve the config content with them to construct `resolved_config`. + + Args: + refs: all the resolved referring items with ID as keys, default to `None`. + + """ + self.resolved_config = resolve_config_with_refs(self.config, id=self.id, refs=refs, globals=self.globals) + self.is_resolved = True + + def get_resolved_config(self): + """ + Get the resolved config content, constructed in `resolve_config()`. 
The returned config has no references, + then use it in the program, for example: initial config item `{"intervals": "$@epoch / 10"}` and references + `{"epoch": 100}`, the resolved config will be `{"intervals": 10}`. + + """ + return self.resolved_config + + +class Instantiable(ABC): + """ + Base class for an instantiable object with module name and arguments. + + """ + + @abstractmethod + def resolve_module_name(self, *args: Any, **kwargs: Any): + """ + Utility function to resolve the target module name. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def resolve_args(self, *args: Any, **kwargs: Any): + """ + Utility function to resolve the arguments. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def is_disabled(self, *args: Any, **kwargs: Any): + """ + Utility function to check whether the target component is disabled. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any): + """ + Instantiate the target component. + + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + +class ConfigComponent(ConfigItem, Instantiable): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, this config item represents a target component of class + or function, and supports instantiating the component. Example of config item: + `{"<name>": "LoadImage", "<args>": {"keys": "image"}}` + + Here we predefined 4 keys: `<name>`, `<path>`, `<args>`, `<disabled>` for component config: + - '<name>' - class / function name in the modules of packages. + - '<path>' - directly specify the module path, based on PYTHONPATH, ignore '<name>' if specified. + - '<args>' - arguments to initialize the component instance. + - '<disabled>' - if defined `'<disabled>': True`, will skip the building, useful for development or tuning. + + The typical usage of the APIs: + - Initialize with config content. + - If no references, `instantiate` the component if having "<name>" or "<path>" keywords and return the instance. + - If having references, get the IDs of its referring components. + - When all the referring components are resolved, resolve the config content with them, execute expressions in + the config and `instantiate`. + + .. code-block:: python + + locator = ComponentLocator(excludes=None) + config = {"<name>": "DataLoader", "<args>": {"dataset": "@dataset", "batch_size": 2}} + + configer = ConfigComponent(config, id="test_config", locator=locator) + configer.resolve(refs={"dataset": Dataset(data=[1, 2])}) + configer.get_resolved_config() + dataloader: DataLoader = configer.instantiate() + + Args: + config: content of a component config item, should be a dict with `<name>` or `<path>` key. + id: ID name of current config item, useful to construct referring config items. + for example, config item A may have ID "transforms#A" and config item B refers to A + and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. + `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. + locator: `ComponentLocator` to help locate the module path of `<name>` in the config and instantiate. + if `None`, will create a new `ComponentLocator` with specified `excludes`. + excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. if any string of the `excludes` + exists in the module name, don't import this module. 
+ globals: to support executable string in the config, sometimes we need to provide the global variables + which are referred in the executable string. for example: `globals={"monai": monai}` will be useful + for config `"collate_fn": "$monai.data.list_data_collate"`. + + """ + + def __init__( + self, + config: Any, + id: Optional[str] = None, + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict] = None, + ) -> None: + super().__init__(config=config, id=id, globals=globals) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + + def resolve_module_name(self): + """ + Utility function used in `instantiate()` to resolve the target module name from provided config content. + The config content must have `<name>` or `<path>`. + + """ + config = self.get_resolved_config() + path = config.get("<path>", None) + if path is not None: + if "<name>" in config: + warnings.warn(f"should not set both '<path>' and '<name>', default to use '<path>': {path}.") + return path + + name = config.get("<name>", None) + if name is None: + raise ValueError("must provide `<path>` or `<name>` of target component to instantiate.") + + module = self.locator.get_component_module_name(name) + if module is None: + raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") + if isinstance(module, list): + warnings.warn( + f"there is more than one component with name `{name}`: {module}, use the first one `{module[0]}`." + f" if you want to use others, please set its module path in `<path>` directly." + ) + module = module[0] + return f"{module}.{name}" + + def resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments from the config content of the target component. + + """ + return self.get_resolved_config().get("<args>", {}) + + def is_disabled(self): + """ + Utility function used in `instantiate()` to check whether the target component is disabled. + + """ + return self.get_resolved_config().get("<disabled>", False) + + def instantiate(self, **kwargs) -> object: # type: ignore + """ + Instantiate the component based on the resolved config content. + The target component must be a `class` or a `function`, otherwise, return `None`. + + Args: + kwargs: args to override / add the config args when instantiating. + + """ + if not self.is_resolved: + warnings.warn( + "the config content of current component has not been resolved," + " please try to resolve the references first." + ) + return None + if not is_instantiable(self.get_resolved_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` + return None + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py index 767ad763c3..768a2c89b1 100644 --- a/monai/apps/mmars/config_parser.py +++ b/monai/apps/mmars/config_parser.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -12,8 +12,117 @@ import importlib from typing import Any, Dict, List, Optional, Sequence, Union -from monai.apps.mmars.config_resolver import ConfigComponent, ConfigResolver -from monai.utils.module import ClassScanner +from monai.apps.mmars.config_item import ConfigComponent, ConfigItem, ComponentLocator +from monai.apps.mmars.utils import is_instantiable + +class ConfigResolver: + """ + Utility class to resolve the dependencies between config components and build instance for specified `id`. + + Args: + components: config components to resolve, if None, can also `add()` component in runtime. + + """ + + def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): + self.resolved_configs: Dict[str, str] = {} + self.resolved_components: Dict[str, Any] = {} + self.components = {} if components is None else components + + def add(self, component: ConfigComponent): + """ + Add a component to the resolution graph. + + Args: + component: a config component to resolve. + + """ + id = component.get_id() + if id in self.components: + raise ValueError(f"id '{id}' is already added.") + self.components[id] = component + + def _resolve_one_component(self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None): + """ + Resolve one component with specified id name. + If has unresolved dependencies, recursively resolve the dependencies first. + + Args: + id: id name of expected component to resolve. + instantiate: after resolving all the dependencies, whether to build instance. + if False, can support lazy instantiation with the resolved config later. + default to `True`. + waiting_list: list of components wait to resolve dependencies. it's used to detect circular dependencies + when resolving dependencies like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. + + """ + if waiting_list is None: + waiting_list = [] + waiting_list.append(id) + com = self.components[id] + dep_ids = com.get_dependent_ids() + # if current component has dependency already in the waiting list, that's circular dependencies + for d in dep_ids: + if d in waiting_list: + raise ValueError(f"detected circular dependencies for id='{d}' in the config content.") + + deps = {} + if len(dep_ids) > 0: + # # check whether the component has any unresolved deps + for comp_id in dep_ids: + if comp_id not in self.resolved_components: + # this dependent component is not resolved + if comp_id not in self.components: + raise RuntimeError(f"the dependent component `{comp_id}` is not in config.") + # resolve the dependency first + self._resolve_one_component(id=comp_id, instantiate=True, waiting_list=waiting_list) + deps[comp_id] = self.resolved_components[comp_id] + # all dependent components are resolved already + updated_config = com.get_updated_config(deps) + resolved_com = None + + if instantiate: + resolved_com = com.build(updated_config) + self.resolved_configs[id] = updated_config + self.resolved_components[id] = resolved_com + + return updated_config, resolved_com + + def resolve_all(self): + """ + Resolve all the components and build instances. + + """ + for k in self.components.keys(): + self._resolve_one_component(id=k, instantiate=True) + + def get_resolved_component(self, id: str): + """ + Get the resolved instance component with specified id name. + If not resolved, try to resolve it first. + + Args: + id: id name of the expected component. 
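A minimal sketch of the resolution flow that the resolver automates, using only the `ConfigComponent` APIs shown above (the ids, configs and values are illustrative):

.. code-block:: python

    locator = ComponentLocator()
    # "dataset" has no references, so it is resolved on construction and can be instantiated directly
    ds = ConfigComponent({"<name>": "Dataset", "<args>": {"data": [1, 2, 3]}}, id="dataset", locator=locator)
    dataset = ds.instantiate()

    # "dataloader" refers to "dataset" via "@dataset", so that reference must be resolved first
    dl = ConfigComponent(
        {"<name>": "DataLoader", "<args>": {"dataset": "@dataset", "batch_size": 2}},
        id="dataloader",
        locator=locator,
    )
    dl.get_id_of_refs()  # -> ["dataset"]
    dl.resolve(refs={"dataset": dataset})
    dataloader = dl.instantiate()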
+ + """ + if id not in self.resolved_components: + self._resolve_one_component(id=id, instantiate=True) + return self.resolved_components[id] + + def get_resolved_config(self, id: str): + """ + Get the resolved config component with specified id name, then can be used for lazy instantiation. + If not resolved, try to resolve it with `instantiation=False` first. + + Args: + id: id name of the expected config component. + + """ + if id not in self.resolved_configs: + config, _ = self._resolve_one_component(id=id, instantiate=False) + else: + config = self.resolved_configs[id] + return config class ConfigParser: @@ -23,9 +132,7 @@ class ConfigParser: For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. Args: - pkgs: the expected packages to scan modules and parse class names in the config. - modules: the expected modules in the packages to scan for all the classes. - for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + excludes: if any string of the `excludes` exists in the full module name, don't import this module. global_imports: pre-import packages as global variables to execute the python `eval` commands. for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` @@ -36,15 +143,14 @@ class ConfigParser: def __init__( self, - pkgs: Sequence[str], - modules: Sequence[str], + excludes: Optional[Union[Sequence[str], str]] = None, global_imports: Optional[Dict[str, Any]] = None, config: Optional[Any] = None, ): self.config = None if config is not None: self.set_config(config=config) - self.class_scanner = ClassScanner(pkgs=pkgs, modules=modules) + self.locator = ComponentLocator(excludes=excludes) self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} if global_imports is not None: for k, v in global_imports.items(): @@ -124,9 +230,12 @@ def _do_parse(self, config, id: Optional[str] = None): sub_id = i if id is None else f"{id}#{i}" self._do_parse(config=v, id=sub_id) if id is not None: - self.config_resolver.add( - ConfigComponent(id=id, config=config, class_scanner=self.class_scanner, globals=self.global_imports) - ) + if is_instantiable(config): + self.config_resolver.add( + ConfigComponent(id=id, config=config, locator=self.locator, globals=self.global_imports) + ) + else: + self.config_resolver.add(ConfigItem(id=id, config=config, globals=self.global_imports)) def parse_config(self, resolve_all: bool = False): """ @@ -169,15 +278,3 @@ def get_resolved_component(self, id: str): if self.config_resolver is None or not self.resolved: self.parse_config() return self.config_resolver.get_resolved_component(id=id) - - def build(self, config: Dict): - """ - Build a config to instance if no dependencies, usually used for lazy instantiation or ad-hoc build. - - Args: - config: dictionary config content to build. - - """ - return ConfigComponent( - id="", config=config, class_scanner=self.class_scanner, globals=self.global_imports - ).build() diff --git a/monai/apps/mmars/config_resolver.py b/monai/apps/mmars/config_resolver.py deleted file mode 100644 index 438ef5f55b..0000000000 --- a/monai/apps/mmars/config_resolver.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
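A small sketch of the dispatch rule used in `_do_parse` above: instantiable dicts are wrapped as `ConfigComponent`, everything else as a plain `ConfigItem` (the configs are illustrative):

.. code-block:: python

    cfg_component = {"<name>": "LoadImage", "<args>": {"keys": "image"}}
    cfg_plain = {"keys": ["image", "label"]}

    is_instantiable(cfg_component)  # True  -> added to the resolver as a ConfigComponent
    is_instantiable(cfg_plain)      # False -> added to the resolver as a plain ConfigItem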
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import Any, Dict, List, Optional - -from monai.apps.mmars.utils import search_configs_with_deps, update_configs_with_deps -from monai.utils.module import ClassScanner, instantiate_class - - -class ConfigComponent: - """ - Utility class to manage every component in the config with a unique `id` name. - When recursively parsing a complicated config dictioanry, every item should be treated as a `ConfigComponent`. - For example: - - `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` - - `{"": "LoadImage", "": {"keys": "image"}}` - - `"": "LoadImage"` - - `"keys": "image"` - - It can search the config content and find out all the dependencies, then build the config to instance - when all the dependencies are resolved. - - Here we predefined several special marks to parse the config content: - - "": like "" is the name of a class, to distinguish it with regular key "name" in the config content. - now we have 4 keys: ``, ``, ``, ``. - - "XXX#YYY": join nested config ids, like "transforms#5" is id name of the 6th transform in the transforms list. - - "@XXX": use an instance as config item, like `"dataset": "@dataset"` uses `dataset` instance as the parameter. - - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". - - Args: - id: id name of current config component, for nested config items, use `#` to join ids. - for list component, use index from `0` as id. - for example: `transform`, `transform#5`, `transform#5##keys`, etc. - config: config content of current component, can be a `dict`, `list`, `string`, `float`, `int`, etc. - class_scanner: ClassScanner to help get the class name or path in the config and build instance. - globals: to support executable string in the config, sometimes we need to provide the global variables - which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `"collate_fn": "$monai.data.list_data_collate"`. - - """ - - def __init__(self, id: str, config: Any, class_scanner: ClassScanner, globals: Optional[Dict] = None) -> None: - self.id = id - self.config = config - self.class_scanner = class_scanner - self.globals = globals - - def get_id(self) -> str: - """ - Get the id name of current config component. - - """ - return self.id - - def get_config(self): - """ - Get the raw config content of current config component. - - """ - return self.config - - def get_dependent_ids(self) -> List[str]: - """ - Recursively search all the content of current config compoent to get the ids of dependencies. - It's used to build all the dependencies before build current config component. - For `dict` and `list`, treat every item as a dependency. - For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: - `["", "", "#dataset", "dataset"]`. - - """ - return search_configs_with_deps(config=self.config, id=self.id) - - def get_updated_config(self, deps: dict): - """ - If all the dependencies are ready in `deps`, update the config content with them and return new config. 
- It can be used for lazy instantiation. - - Args: - deps: all the dependent components with ids. - - """ - return update_configs_with_deps(config=self.config, deps=deps, id=self.id, globals=self.globals) - - def _check_dependency(self, config): - """ - Check whether current config still has unresolved dependencies or executable string code. - - Args: - config: config content to check. - - """ - if isinstance(config, list): - for i in config: - if self._check_dependency(i): - return True - if isinstance(config, dict): - for v in config.values(): - if self._check_dependency(v): - return True - if isinstance(config, str): - if config.startswith("&") or "@" in config: - return True - return False - - def build(self, config: Optional[Dict] = None) -> object: - """ - Build component instance based on the provided dictonary config. - Supported special keys for the config: - - '' - class name in the modules of packages. - - '' - directly specify the class path, based on PYTHONPATH, ignore '' if specified. - - '' - arguments to initialize the component instance. - - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. - - Args: - config: dictionary config that defines a component. - - Raises: - ValueError: must provide `` or `` of class to build component. - ValueError: can not find component class. - - """ - config = self.config if config is None else config - if self._check_dependency(config=config): - warnings.warn("config content has other dependencies or executable string, skip `build`.") - return config - - if ( - not isinstance(config, dict) - or ("" not in config and "" not in config) - or config.get("") is True - ): - # if marked as `disabled`, skip parsing - return config - - class_args = config.get("", {}) - class_path = self._get_class_path(config) - return instantiate_class(class_path, **class_args) - - def _get_class_path(self, config): - """ - Get the path of class specified in the config content. - - Args: - config: dictionary config that defines a component. - - """ - class_path = config.get("", None) - if class_path is None: - class_name = config.get("", None) - if class_name is None: - raise ValueError("must provide `` or `` of class to build component.") - module_name = self.class_scanner.get_class_module_name(class_name) - if module_name is None: - raise ValueError(f"can not find component class '{class_name}'.") - class_path = f"{module_name}.{class_name}" - - return class_path - - -class ConfigResolver: - """ - Utility class to resolve the dependencies between config components and build instance for specified `id`. - - Args: - components: config components to resolve, if None, can also `add()` component in runtime. - - """ - - def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): - self.resolved_configs: Dict[str, str] = {} - self.resolved_components: Dict[str, Any] = {} - self.components = {} if components is None else components - - def add(self, component: ConfigComponent): - """ - Add a component to the resolution graph. - - Args: - component: a config component to resolve. - - """ - id = component.get_id() - if id in self.components: - raise ValueError(f"id '{id}' is already added.") - self.components[id] = component - - def _resolve_one_component(self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None): - """ - Resolve one component with specified id name. - If has unresolved dependencies, recursively resolve the dependencies first. 
- - Args: - id: id name of expected component to resolve. - instantiate: after resolving all the dependencies, whether to build instance. - if False, can support lazy instantiation with the resolved config later. - default to `True`. - waiting_list: list of components wait to resolve dependencies. it's used to detect circular dependencies - when resolving dependencies like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. - - """ - if waiting_list is None: - waiting_list = [] - waiting_list.append(id) - com = self.components[id] - dep_ids = com.get_dependent_ids() - # if current component has dependency already in the waiting list, that's circular dependencies - for d in dep_ids: - if d in waiting_list: - raise ValueError(f"detected circular dependencies for id='{d}' in the config content.") - - deps = {} - if len(dep_ids) > 0: - # # check whether the component has any unresolved deps - for comp_id in dep_ids: - if comp_id not in self.resolved_components: - # this dependent component is not resolved - if comp_id not in self.components: - raise RuntimeError(f"the dependent component `{comp_id}` is not in config.") - # resolve the dependency first - self._resolve_one_component(id=comp_id, instantiate=True, waiting_list=waiting_list) - deps[comp_id] = self.resolved_components[comp_id] - # all dependent components are resolved already - updated_config = com.get_updated_config(deps) - resolved_com = None - - if instantiate: - resolved_com = com.build(updated_config) - self.resolved_configs[id] = updated_config - self.resolved_components[id] = resolved_com - - return updated_config, resolved_com - - def resolve_all(self): - """ - Resolve all the components and build instances. - - """ - for k in self.components.keys(): - self._resolve_one_component(id=k, instantiate=True) - - def get_resolved_component(self, id: str): - """ - Get the resolved instance component with specified id name. - If not resolved, try to resolve it first. - - Args: - id: id name of the expected component. - - """ - if id not in self.resolved_components: - self._resolve_one_component(id=id, instantiate=True) - return self.resolved_components[id] - - def get_resolved_config(self, id: str): - """ - Get the resolved config component with specified id name, then can be used for lazy instantiation. - If not resolved, try to resolve it with `instantiation=False` first. - - Args: - id: id name of the expected config component. 
- - """ - if id not in self.resolved_configs: - config, _ = self._resolve_one_component(id=id, instantiate=False) - else: - config = self.resolved_configs[id] - return config diff --git a/monai/apps/mmars/schema/metadata.json b/monai/apps/mmars/schema/metadata.json deleted file mode 100644 index babee8b30e..0000000000 --- a/monai/apps/mmars/schema/metadata.json +++ /dev/null @@ -1,71 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://monai.io/mmar_metadata_schema.json", - "title": "metadata", - "description": "metadata that defines the context information for MMAR.", - "type": "object", - "properties": { - "version": { - "description": "version number of this MMAR.", - "type": "string" - }, - "monai_version": { - "description": "version number of MONAI used in this MMAR.", - "type": "string" - }, - "pytorch_version": { - "description": "version number of PyTorch used in this MMAR.", - "type": "string" - }, - "numpy_version": { - "description": "version number of MONAI used in this MMAR.", - "type": "string" - }, - "network_data_format": { - "description": "define the input and output data format for network.", - "type": "object", - "properties": { - "inputs": { - "type": "object", - "properties": { - "image": { - "type": "object", - "properties": { - "type": { - "type": "string", - "description": "the data format for `image`." - }, - "format": { - "type": "string" - }, - "num_channels": { - "type": "integer", - "minimum": 1 - }, - "spatial_shape": { - "type": "array", - "items": { - "type": "integer", - "minumum": 1 - } - }, - "dtype": { - "type": "string" - }, - "value_range": { - "type": "array", - "items": { - "type": "number", - "unuqueItems": true - } - }, - "required": ["num_channels", "spatial_shape", "value_range"] - } - } - } - } - } - }, - "required": ["monai_version", "pytorch_version", "network_data_format"] - } -} diff --git a/monai/apps/mmars/scripts(pseudo code)/__init__.py b/monai/apps/mmars/scripts(pseudo code)/__init__.py deleted file mode 100644 index 14ae193634..0000000000 --- a/monai/apps/mmars/scripts(pseudo code)/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/apps/mmars/scripts(pseudo code)/export.py b/monai/apps/mmars/scripts(pseudo code)/export.py deleted file mode 100644 index 59f404da29..0000000000 --- a/monai/apps/mmars/scripts(pseudo code)/export.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json - -import torch -from monai.apps import ConfigParser -from ignite.handlers import Checkpoint -from monai.data import save_net_with_metadata -from monai.networks import convert_to_torchscript - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--weights', '-w', type=str, help='file path of the trained model weights', required=True) - parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True) - parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') - args = parser.parse_args() - - # load config file - with open(args.config) as f: - config_dict = json.load(f) - # load meta data - with open(args.meta) as f: - meta_dict = json.load(f) - - net: torch.nn.Module = None - # TODO: parse network definiftion from config file and construct network instance - config_parser = ConfigParser(config_dict) - net = config_parser.get_instance("network") - - checkpoint = torch.load(args.weights) - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={"model": net}, checkpoint=checkpoint) - - # convert to TorchScript model and save with meta data - net = convert_to_torchscript(model=net) - - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream="model.ts", - include_config_vals=False, - append_timestamp=False, - meta_values=meta_dict, - more_extra_files={args.config: json.dumps(config_dict).encode()}, - ) - - -if __name__ == '__main__': - main() diff --git a/monai/apps/mmars/scripts(pseudo code)/inference.py b/monai/apps/mmars/scripts(pseudo code)/inference.py deleted file mode 100644 index bbe64740e1..0000000000 --- a/monai/apps/mmars/scripts(pseudo code)/inference.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import json - -import torch -from monai.apps import ConfigParser -from monai.data import decollate_batch -from monai.inferers import Inferer -from monai.transforms import Transform -from monai.utils.enums import CommonKeys - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) - parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') - parser.add_argument('--override', '-o', type=str, help='config file that override components', required=False) - args = parser.parse_args() - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - configs = {} - - # load meta data - with open(args.meta) as f: - configs.update(json.load(f)) - # load config file, can override meta data in config - with open(args.config) as f: - configs.update(json.load(f)) - - model: torch.nn.Module = None - dataloader: torch.utils.data.DataLoader = None - inferer: Inferer = None - postprocessing: Transform = None - # TODO: parse inference config file and construct instances - config_parser = ConfigParser(configs) - - # change JSON config content in python code, lazy instantiation - model_conf = config_parser.get_config("model") - model_conf["disabled"] = False - model = config_parser.build(model_conf).to(device) - - # instantialize the components immediately - dataloader = config_parser.get_instance("dataloader") - inferer = config_parser.get_instance("inferer") - postprocessing = config_parser.get_instance("postprocessing") - - model.eval() - with torch.no_grad(): - for d in dataloader: - images = d[CommonKeys.IMAGE].to(device) - # define sliding window size and batch size for windows inference - d[CommonKeys.PRED] = inferer(inputs=images, predictor=model) - # decollate the batch data into a list of dictionaries, then execute postprocessing transforms - [postprocessing(i) for i in decollate_batch(d)] - - -if __name__ == '__main__': - main() diff --git a/monai/apps/mmars/scripts(pseudo code)/verify_network.py b/monai/apps/mmars/scripts(pseudo code)/verify_network.py deleted file mode 100644 index db9e247994..0000000000 --- a/monai/apps/mmars/scripts(pseudo code)/verify_network.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import json - -import torch -from monai.apps import ConfigParser -from monai.utils.type_conversion import get_equivalent_dtype - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) - parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') - args = parser.parse_args() - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - configs = {} - - # load meta data - with open(args.meta) as f: - configs.update(json.load(f)) - # load config file, can override meta data in config - with open(args.config) as f: - configs.update(json.load(f)) - - model: torch.nn.Module = None - # TODO: parse inference config file and construct instances - config_parser = ConfigParser(configs) - - model = config_parser.get_instance("model") - input_channels = config_parser.get_config("network_data_format#inputs#image#num_channels") - input_spatial_shape = tuple(config_parser.get_config("network_data_format#inputs#image#spatial_shape")) - dtype = config_parser.get_config("network_data_format#inputs#image#dtype") - dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) - - output_channels = config_parser.get_config("network_data_format#outputs#pred#num_channels") - output_spatial_shape = tuple(config_parser.get_config("network_data_format#outputs#pred#spatial_shape")) - - model.eval() - with torch.no_grad(): - test_data = torch.rand(*(input_channels, *input_spatial_shape), dtype=dtype, device=device) - output = model(test_data) - if output.shape[0] != output_channels: - raise ValueError(f"channel number of output data doesn't match expection: {output_channels}.") - if output.shape[1:] != output_spatial_shape: - raise ValueError(f"spatial shape of output data doesn't match expection: {output_spatial_shape}.") - - -if __name__ == '__main__': - main() diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py index f532ff8cc1..bc7b516886 100644 --- a/monai/apps/mmars/utils.py +++ b/monai/apps/mmars/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,79 +10,160 @@ # limitations under the License. import re -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union -def search_configs_with_deps(config: Union[Dict, List, str], id: str, deps: Optional[List[str]] = None) -> List[str]: +def match_refs_pattern(value: str) -> List[str]: """ - Recursively search all the content of input config compoent to get the ids of dependencies. - It's used to build all the dependencies before build current config component. - For `dict` and `list`, treat every item as a dependency. - For example, for `{"": "DataLoader", "": {"dataset": "@dataset"}}`, the dependency ids: - `["", "", "#dataset", "dataset"]`. + Match regular expression for the input string to find the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". + + Args: + value: input value to match regular expression. 
+ + """ + refs: List[str] = [] + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + if is_expression(value) or value == item: + # only check when string starts with "$" or the whole content is "@XXX" + ref_obj_id = item[1:] + if ref_obj_id not in refs: + refs.append(ref_obj_id) + return refs + + +def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None) -> str: + """ + Match regular expression for the input string to update content with the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". + References dictionary must contain the referring IDs as keys. + + Args: + value: input value to match regular expression. + refs: all the referring components with ids as keys, default to `None`. + globals: predefined global variables to execute code string with `eval()`. + + """ + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + ref_id = item[1:] + if is_expression(value): + # replace with local code and execute later + value = value.replace(item, f"refs['{ref_id}']") + elif value == item: + if ref_id not in refs: + raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + value = refs[ref_id] + if is_expression(value): + # execute the code string with python `eval()` + value = eval(value[1:], globals, {"refs": refs}) + return value + + +def find_refs_in_config( + config: Union[Dict, List, str], + id: Optional[str] = None, + refs: Optional[List[str]] = None, + match_fn: Callable = match_refs_pattern, +) -> List[str]: + """ + Recursively search all the content of input config item to get the ids of references. + References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: + `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config + is instantiable, treat it as reference because must instantiate it before resolving current config. + For `dict` and `list`, recursively check the sub-items. Args: config: input config content to search. - id: id name for the input config. - deps: list of the id name of existing dependencies, default to None. + id: ID name for the input config, default to `None`. + refs: list of the ID name of existing references, default to `None`. + match_fn: callable function to match config item for references, take `config` as parameter. 
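A few illustrative calls showing the reference matching and resolution utilities in this module (ids and values are made up for this sketch):

.. code-block:: python

    find_refs_in_config({"lr": "$@epoch / 1000"}, id="train")
    # -> ["epoch"]
    find_refs_in_config(["@transform1", "@transform2"], id="preprocessing")
    # -> ["transform1", "transform2"]

    refs = {"epoch": 10, "dataset": ["fake", "data"]}
    resolve_config_with_refs("@dataset", refs=refs)
    # -> ["fake", "data"]  (a plain "@" reference is replaced by the referenced object)
    resolve_config_with_refs({"lr": "$@epoch / 1000"}, id="train", refs=refs)
    # -> {"lr": 0.01}  (the "$" expression is evaluated after substituting "@epoch")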
""" - deps_: List[str] = [] if deps is None else deps - pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + refs_: List[str] = [] if refs is None else refs + if isinstance(config, str): + refs_ += match_fn(value=config) + if isinstance(config, list): for i, v in enumerate(config): - sub_id = f"{id}#{i}" - # all the items in the list should be marked as dependent reference - deps_.append(sub_id) - deps_ = search_configs_with_deps(v, sub_id, deps_) + sub_id = f"{id}#{i}" if id is not None else f"{i}" + if is_instantiable(v): + refs_.append(sub_id) + refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) if isinstance(config, dict): for k, v in config.items(): - sub_id = f"{id}#{k}" - # all the items in the dict should be marked as dependent reference - deps_.append(sub_id) - deps_ = search_configs_with_deps(v, sub_id, deps_) - if isinstance(config, str): - result = pattern.findall(config) - for item in result: - if config.startswith("$") or config == item: - ref_obj_id = item[1:] - if ref_obj_id not in deps_: - deps_.append(ref_obj_id) - return deps_ + sub_id = f"{id}#{k}" if id is not None else f"{k}" + if is_instantiable(v): + refs_.append(sub_id) + refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) + return refs_ -def update_configs_with_deps(config: Union[Dict, List, str], deps: dict, id: str, globals: Optional[Dict] = None): +def resolve_config_with_refs( + config: Union[Dict, List, str], + id: Optional[str] = None, + refs: Optional[Dict] = None, + globals: Optional[Dict] = None, + match_fn: Callable = resolve_refs_pattern, +): """ - With all the dependencies in `deps`, update the config content with them and return new config. - It can be used for lazy instantiation. + With all the references in `refs`, resolve the config content with them and return new config. Args: - config: input config content to update. - deps: all the dependent components with ids. - id: id name for the input config. + config: input config content to resolve. + id: ID name for the input config, default to `None`. + refs: all the referring components with ids, default to `None`. globals: predefined global variables to execute code string with `eval()`. + match_fn: callable function to match config item for references, take `config`, + `refs` and `globals` as parameters. 
""" - pattern = re.compile(r"@\w*[\#\w]*") # match ref as args: "@XXX#YYY#ZZZ" + refs_: Dict = {} if refs is None else refs + if isinstance(config, str): + config = match_fn(config, refs, globals) if isinstance(config, list): - # all the items in the list should be replaced with the reference - return [deps[f"{id}#{i}"] for i in range(len(config))] + # all the items in the list should be replaced with the references + ret_list: List = [] + for i, v in enumerate(config): + sub = f"{id}#{i}" if id is not None else f"{i}" + ret_list.append( + refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn) + ) + return ret_list if isinstance(config, dict): - # all the items in the dict should be replaced with the reference - return {k: deps[f"{id}#{k}"] for k, _ in config.items()} - if isinstance(config, str): - result = pattern.findall(config) - config_: str = config - for item in result: - ref_obj_id = item[1:] - if config_.startswith("$"): - # replace with local code and execute soon - config_ = config_.replace(item, f"deps['{ref_obj_id}']") - elif config_ == item: - config_ = deps[ref_obj_id] - - if isinstance(config_, str): - if config_.startswith("$"): - config_ = eval(config_[1:], globals, {"deps": deps}) - return config_ + # all the items in the dict should be replaced with the references + ret_dict: Dict = {} + for k, v in config.items(): + sub = f"{id}#{k}" if id is not None else f"{k}" + ret_dict.update( + {k: refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn)} + ) + return ret_dict return config + + +def is_instantiable(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config represents a `class` or `function` to instantiate + with specified "" or "". + + Args: + config: input config content to check. + + """ + return isinstance(config, dict) and ("" in config or "" in config) + + +def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the content of the config is executable expression string. + If True, the string should start with "$" mark. + + Args: + config: input config content to check. 
+ + """ + return isinstance(config, str) and config.startswith("$") diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index fe1383c08d..b271a3331f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -190,8 +190,7 @@ def randomize(self, img_size: Sequence[int]) -> None: self.random_idxs = np.array((0,)) def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: - img_np: np.ndarray - img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(image, np.ndarray) # add random offset self.randomize(img_size=img_np.shape) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 9db62f336d..209dc796cf 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -183,24 +183,26 @@ def download_url( ) logger.info(f"File exists: {filepath}, skipped downloading.") return - - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_name = Path(tmp_dir, _basename(filepath)) - if urlparse(url).netloc == "drive.google.com": - if not has_gdown: - raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") - gdown.download(url, f"{tmp_name}", quiet=not progress) - else: - _download_with_progress(url, tmp_name, progress=progress) - if not tmp_name.exists(): - raise RuntimeError( - f"Download of file from {url} to {filepath} failed due to network issue or denied permission." - ) - file_dir = filepath.parent - if file_dir: - os.makedirs(file_dir, exist_ok=True) - shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache. - logger.info(f"Downloaded: {filepath}") + try: + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_name = Path(tmp_dir, _basename(filepath)) + if urlparse(url).netloc == "drive.google.com": + if not has_gdown: + raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") + gdown.download(url, f"{tmp_name}", quiet=not progress) + else: + _download_with_progress(url, tmp_name, progress=progress) + if not tmp_name.exists(): + raise RuntimeError( + f"Download of file from {url} to {filepath} failed due to network issue or denied permission." + ) + file_dir = filepath.parent + if file_dir: + os.makedirs(file_dir, exist_ok=True) + shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache. 
+ except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows + pass + logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of downloaded file failed: URL={url}, " diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 91b944bde5..fd7ca572e6 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -161,9 +161,9 @@ def get_system_info() -> OrderedDict: ), ) mem = psutil.virtual_memory() - _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024 ** 3, 1)) - _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024 ** 3, 1)) - _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024 ** 3, 1)) + _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024**3, 1)) + _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024**3, 1)) + _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024**3, 1)) return output @@ -209,7 +209,7 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, f"GPU {gpu} Is integrated", lambda: bool(gpu_info.is_integrated)) _dict_append(output, f"GPU {gpu} Is multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board)) _dict_append(output, f"GPU {gpu} Multi processor count", lambda: gpu_info.multi_processor_count) - _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1)) + _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024**3, 1)) _dict_append(output, f"GPU {gpu} CUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}") return output diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 686befb2eb..16919c2ec4 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -63,15 +63,14 @@ #: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/v1.21.4/numpy/typing/_dtype_like.py#L121 DtypeLike = Union[np.dtype, type, str, None] +#: NdarrayOrTensor: Union of numpy.ndarray and torch.Tensor to be used for typing +NdarrayOrTensor = Union[np.ndarray, torch.Tensor] + #: NdarrayTensor # # Generic type which can represent either a numpy.ndarray or a torch.Tensor # Unlike Union can create a dependence between parameter(s) / return(s) -NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor) - - -#: NdarrayOrTensor: Union of numpy.ndarray and torch.Tensor to be used for typing -NdarrayOrTensor = Union[np.ndarray, torch.Tensor] +NdarrayTensor = TypeVar("NdarrayTensor", bound=NdarrayOrTensor) #: TensorOrList: The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`. 
TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]] diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index a2fa8bfc56..ac43e6fd3e 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -31,6 +31,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::enum_(m, "BoundType") .value("replicate", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("nearest", monai::BoundType::Replicate, "a a a | a b c d | d d d") + .value("border", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("dct1", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("mirror", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("dct2", monai::BoundType::DCT2, "c b a | a b c d | d c b") @@ -43,6 +44,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("wrap", monai::BoundType::DFT, "b c d | a b c d | a b c") // .value("sliding", monai::BoundType::Sliding) .value("zero", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") + .value("zeros", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") .export_values(); // resample interpolation mode diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index d83557c6c3..c638958a47 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -1527,10 +1527,10 @@ MONAI_NAMESPACE_DEVICE { // cpu iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1539,18 +1539,12 @@ MONAI_NAMESPACE_DEVICE { // cpu o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else if (!(do_push || do_count)) { + o000 = o100 = o010 = o001 = o110 = o011 = o101 = o111 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1657,16 +1651,19 @@ MONAI_NAMESPACE_DEVICE { // cpu grad_ptr_NXYZ[grad_sC] = gy; grad_ptr_NXYZ[grad_sC * 2] = gz; } + if (do_push || do_count) { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = 
ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1678,14 +1675,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1758,16 +1747,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1822,21 +1801,19 @@ MONAI_NAMESPACE_DEVICE { // cpu ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else if (!(do_push || do_count)) { + o00 = o10 = o01 = o11 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1893,12 +1870,15 @@ MONAI_NAMESPACE_DEVICE { // cpu (*grad_ptr_NXY) = gx; grad_ptr_NXY[grad_sC] = gy; } + if (do_push || do_count) { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1908,10 +1888,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1926,11 +1902,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1960,12 +1931,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1996,20 +1961,19 @@ MONAI_NAMESPACE_DEVICE { // cpu ix1 = bound::index(bound0, ix0 + 1, src_X); ix0 = bound::index(bound0, ix0, src_X); - // Offsets into source volume offset_t o0, o1; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else if (!(do_push || do_count)) { + o0 = o1 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2035,10 +1999,13 @@ MONAI_NAMESPACE_DEVICE { // cpu // -> zero (make sure this is done at initialization) } } + if (do_push || do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2047,8 +2014,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2058,9 +2023,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2081,10 +2043,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 4a2d6c27ef..461962cb80 100644 
--- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1491,10 +1491,10 @@ MONAI_NAMESPACE_DEVICE { // cuda iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1503,18 +1503,12 @@ MONAI_NAMESPACE_DEVICE { // cuda o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else if (!(do_push || do_count)) { + o000 = o100 = o010 = o001 = o110 = o011 = o101 = o111 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1621,16 +1615,19 @@ MONAI_NAMESPACE_DEVICE { // cuda grad_ptr_NXYZ[grad_sC] = gy; grad_ptr_NXYZ[grad_sC * 2] = gz; } + if (do_push || do_count) { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1642,14 +1639,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1672,15 +1661,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1722,16 +1702,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1786,21 +1756,19 @@ MONAI_NAMESPACE_DEVICE { // cuda ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else if (!(do_push || do_count)) { + o00 = o10 = o01 = o11 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1857,12 +1825,15 @@ MONAI_NAMESPACE_DEVICE { // cuda (*grad_ptr_NXY) = gx; grad_ptr_NXY[grad_sC] = gy; } + if (do_push || do_count) { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1872,10 +1843,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1890,11 +1857,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1924,12 +1886,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1965,15 +1921,14 @@ MONAI_NAMESPACE_DEVICE { // cuda if (do_pull || do_grad || do_sgrad) { o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else if (!(do_push || do_count)) { + o0 = o1 = 0; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1999,10 +1954,13 @@ MONAI_NAMESPACE_DEVICE { // cuda // -> zero (make sure this is done at initialization) } } + if (do_push || do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2011,8 +1969,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2022,9 +1978,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2045,10 +1998,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); diff --git a/monai/data/__init__.py b/monai/data/__init__.py index bd49f40273..bed194d2f4 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,6 +35,16 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_writer import ( + SUPPORTED_WRITERS, + ImageWriter, + ITKWriter, + NibabelWriter, + PILWriter, + logger, + register_writer, + resolve_writer, +) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti @@ -60,11 
+70,13 @@ iter_patch_slices, json_hashing, list_data_collate, + orientation_ras_lps, pad_list_data_collate, partition_dataset, partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, + reorient_spatial_axes, resample_datalist, select_cross_validation_folds, set_rnd, diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 3117a27c02..d1f5bd4fe1 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -81,4 +81,4 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None: if "worker_init_fn" not in kwargs: kwargs.update({"worker_init_fn": worker_init_fn}) - super().__init__(dataset=dataset, num_workers=num_workers, **kwargs) # type: ignore[call-overload] + super().__init__(dataset=dataset, num_workers=num_workers, **kwargs) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 8a42ed5181..156708c7dd 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -269,7 +269,7 @@ def _pre_transform(self, item_transformed): random transform object """ - for _transform in self.transform.transforms: # type:ignore + for _transform in self.transform.transforms: # execute all the deterministic transforms if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform): break @@ -513,7 +513,7 @@ def __init__( self.db_file = self.cache_dir / f"{db_name}.lmdb" self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): - self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size + self.lmdb_kwargs["map_size"] = 1024**4 # default map_size # lmdb is single-writer multi-reader by default # the cache is created without multi-threading self._read_env = None @@ -676,6 +676,8 @@ def __init__( progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, + hash_as_key: bool = False, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: """ Args: @@ -695,19 +697,29 @@ def __init__( may set `copy=False` for better performance. as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. it may help improve the performance of following logic. + hash_as_key: whether to compute hash value of input data as the key to save cache, + if key exists, avoid saving duplicated content. it can help save memory when + the dataset has duplicated items or augmented dataset. + hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) + self.set_num = cache_num # tracking the user-provided `cache_num` option + self.set_rate = cache_rate # tracking the user-provided `cache_rate` option self.progress = progress self.copy_cache = copy_cache self.as_contiguous = as_contiguous - self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) + self.hash_as_key = hash_as_key + self.hash_func = hash_func self.num_workers = num_workers if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) - self._cache: List = self._fill_cache() + self.cache_num = 0 + self._cache: Union[List, Dict] = [] + self.set_data(data) def set_data(self, data: Sequence): """ @@ -718,8 +730,21 @@ def set_data(self, data: Sequence): generated cache content. 
""" - self.data = data - self._cache = self._fill_cache() + + def _compute_cache(): + self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) + return self._fill_cache() + + if self.hash_as_key: + # only compute cache for the unique items of dataset + mapping = {self.hash_func(v): v for v in data} + self.data = list(mapping.values()) + cache_ = _compute_cache() + self._cache = dict(zip(list(mapping)[: self.cache_num], cache_)) + self.data = data + else: + self.data = data + self._cache = _compute_cache() def _fill_cache(self) -> List: if self.cache_num <= 0: @@ -754,14 +779,21 @@ def _load_cache_item(self, idx: int): return item def _transform(self, index: int): - if index % len(self) >= self.cache_num: # support negative index + index_: Any = index + if self.hash_as_key: + key = self.hash_func(self.data[index]) + if key in self._cache: + # if existing in cache, get the index + index_ = key # if using hash as cache keys, set the key + + if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index # no cache for this index, execute all the transforms directly - return super()._transform(index) + return super()._transform(index_) # load data from cache and execute from the first random transform start_run = False if self._cache is None: self._cache = self._fill_cache() - data = self._cache[index] + data = self._cache[index_] if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: @@ -862,10 +894,14 @@ def __init__( ) -> None: if shuffle: self.set_random_state(seed=seed) - data = copy(data) - self.randomize(data) self.shuffle = shuffle + self._start_pos: int = 0 + self._update_lock: threading.Lock = threading.Lock() + self._round: int = 1 + self._replace_done: bool = False + self._replace_mgr: Optional[threading.Thread] = None + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous) if self._cache is None: self._cache = self._fill_cache() @@ -884,13 +920,6 @@ def __init__( self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) self._replacements: List[Any] = [None for _ in range(self._replace_num)] self._replace_data_idx: List[int] = list(range(self._replace_num)) - - self._start_pos: int = 0 - self._update_lock: threading.Lock = threading.Lock() - self._round: int = 1 - self._replace_done: bool = False - self._replace_mgr: Optional[threading.Thread] = None - self._compute_data_idx() def set_data(self, data: Sequence): diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 2b4df4ebbf..956e038569 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -150,7 +150,7 @@ def calculate_statistics(self, foreground_threshold: int = 0): self.data_max, self.data_min = max(voxel_max), min(voxel_min) self.data_mean = (voxel_sum / voxel_ct).item() - self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item() + self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean**2)).item() def calculate_percentiles( self, diff --git a/monai/data/folder_layout.py b/monai/data/folder_layout.py index d8ce162c27..b2f41b0651 100644 --- a/monai/data/folder_layout.py +++ b/monai/data/folder_layout.py @@ -29,7 +29,7 @@ class FolderLayout: layout = FolderLayout( output_dir="/test_run_1/", postfix="seg", - extension=".nii", + extension="nii", 
makedirs=False) layout.filename(subject="Sub-A", idx="00", modality="T1") # return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii" @@ -95,5 +95,6 @@ def filename(self, subject: PathLike = "subject", idx=None, **kwargs): for k, v in kwargs.items(): full_name += f"_{k}-{v}" if self.ext is not None: - full_name += f"{self.ext}" + ext = f"{self.ext}" + full_name += f".{ext}" if ext and not ext.startswith(".") else f"{ext}" return full_name diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index df574adf00..0be7feb1e5 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -18,14 +18,12 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection, PathLike -from monai.data.utils import correct_nifti_header_if_necessary +from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg -from .utils import is_supported_format - if TYPE_CHECKING: - import itk # type: ignore + import itk import nibabel as nib from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..e9f753fb34 --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,804 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union + +import numpy as np + +from monai.apps.utils import get_logger +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data.utils import affine_to_spacing, ensure_tuple, ensure_tuple_rep, orientation_ras_lps, to_affine_nd +from monai.transforms.spatial.array import Resize, SpatialResample +from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis +from monai.utils import ( + GridSampleMode, + GridSamplePadMode, + InterpolateMode, + OptionalImportError, + convert_data_type, + look_up_option, + optional_import, + require_pkg, +) + +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" +EXT_WILDCARD = "*" +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) + +if TYPE_CHECKING: + import itk + import nibabel as nib + from PIL import Image as PILImage +else: + itk, _ = optional_import("itk", allow_namespace_pkg=True) + nib, _ = optional_import("nibabel") + PILImage, _ = optional_import("PIL.Image") + + +__all__ = [ + "ImageWriter", + "ITKWriter", + "NibabelWriter", + "PILWriter", + "SUPPORTED_WRITERS", + "register_writer", + "resolve_writer", + "logger", +] + +SUPPORTED_WRITERS: Dict = {} + + +def register_writer(ext_name, *im_writers): + """ + Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` + could be resolved to a tuple of potentially appropriate ``ImageWriter``. 
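The `FolderLayout.filename` change above makes `extension="nii"` and `extension=".nii"` equivalent; a small sketch based on the class docstring example (the expected output string is taken from that docstring, not re-verified here):

.. code-block:: python

    from monai.data.folder_layout import FolderLayout

    layout_a = FolderLayout(output_dir="/test_run_1/", postfix="seg", extension="nii", makedirs=False)
    layout_b = FolderLayout(output_dir="/test_run_1/", postfix="seg", extension=".nii", makedirs=False)

    name_a = layout_a.filename(subject="Sub-A", idx="00", modality="T1")
    name_b = layout_b.filename(subject="Sub-A", idx="00", modality="T1")
    # a leading "." is added when the user omits it, so both layouts agree
    assert name_a == name_b == "/test_run_1/Sub-A_seg_00_modality-T1.nii"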
+ The customised writers could be registered by: + + .. code-block:: python + + from monai.data import register_writer + # `MyWriter` must implement `ImageWriter` interface + register_writer("nii", MyWriter) + + Args: + ext_name: the filename extension of the image. + As an indexing key, it will be converted to a lower case string. + im_writers: one or multiple ImageWriter classes with high priority ones first. + """ + fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] + existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) + all_writers = im_writers + existing + SUPPORTED_WRITERS[fmt] = all_writers + + +def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: + """ + Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS`` + according to the filename extension key ``ext_name``. + + Args: + ext_name: the filename extension of the image. + As an indexing key it will be converted to a lower case string. + error_if_not_found: whether to raise an error if no suitable image writer is found. + if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``. + """ + if not SUPPORTED_WRITERS: + init() + fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] + avail_writers = [] + default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ()) + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers): + try: + _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability + avail_writers.append(_writer) + except OptionalImportError: + continue + except Exception: # other writer init errors indicating it exists + avail_writers.append(_writer) + if not avail_writers and error_if_not_found: + raise OptionalImportError(f"No ImageWriter backend found for {fmt}.") + writer_tuple = ensure_tuple(avail_writers) + SUPPORTED_WRITERS[fmt] = writer_tuple + return writer_tuple + + +class ImageWriter: + """ + The class is a collection of utilities to write images to disk. + + Main aspects to be considered are: + + - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions + - ``convert_to_channel_last()`` + - metadata of the current affine and output affine, the data array should be converted accordingly + - ``get_meta_info()`` + - ``resample_if_needed()`` + - data type handling of the output image (as part of ``resample_if_needed()``) + + Subclasses of this class should implement the backend-specific functions: + + - ``set_data_array()`` to set the data array (input must be numpy array or torch tensor) + - this method sets the backend object's data part + - ``set_metadata()`` to set the metadata and output affine + - this method sets the metadata including affine handling and image resampling + - backend-specific data object ``create_backend_obj()`` + - backend-specific writing function ``write()`` + + The primary usage of subclasses of ``ImageWriter`` is: + + .. code-block:: python + + writer = MyWriter() # subclass of ImageWriter + writer.set_data_array(data_array) + writer.set_metadata(meta_dict) + writer.write(filename) + + This creates an image writer object based on ``data_array`` and ``meta_dict`` and write to ``filename``. + + It supports up to three spatial dimensions (with the resampling step supports for both 2D and 3D). + When saving multiple time steps or multiple channels `data_array`, time + and/or modality axes should be the at the `channel_dim`. 
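A short sketch of how the registry introduced above is intended to be used; `MyNiftiWriter` is a hypothetical subclass defined only to illustrate registration and resolution by extension:

.. code-block:: python

    from monai.data.image_writer import ImageWriter, register_writer, resolve_writer

    class MyNiftiWriter(ImageWriter):
        """hypothetical writer, used only to illustrate the registry."""

        def set_data_array(self, data_array, **kwargs):
            self.data_obj = data_array

        def set_metadata(self, meta_dict=None, **options):
            self.meta = meta_dict

        def write(self, filename, verbose=True, **kwargs):
            super().write(filename, verbose=verbose)
            # backend-specific saving would go here

    # the newly registered writer takes priority for the "nii" extension
    register_writer("nii", MyNiftiWriter)
    writers = resolve_writer("nii")
    print(writers[0])  # MyNiftiWriter is the first candidate tried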
For example, + the shape of a 2D eight-class and ``channel_dim=0``, the segmentation + probabilities to be saved could be `(8, 64, 64)`; in this case + ``data_array`` will be converted to `(64, 64, 1, 8)` (the third + dimension is reserved as a spatial dimension). + + The ``metadata`` could optionally have the following keys: + + - ``'original_affine'``: for data original affine, it will be the + affine of the output object, defaulting to an identity matrix. + - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix. + - ``'spatial_shape'``: for data output spatial shape. + + When ``metadata`` is specified, the saver will may resample data from the space defined by + `"affine"` to the space defined by `"original_affine"`, for more details, please refer to the + ``resample_if_needed`` method. + """ + + def __init__(self, **kwargs): + """ + The constructor supports adding new instance members. + The current member in the base class is ``self.data_obj``, the subclasses can add more members, + so that necessary meta information can be stored in the object and shared among the class methods. + """ + self.data_obj = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def set_data_array(self, data_array, **kwargs): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def set_metadata(self, meta_dict: Optional[Mapping], **options): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def write(self, filename: PathLike, verbose: bool = True, **kwargs): + """subclass should implement this method to call the backend-specific writing APIs.""" + if verbose: + logger.info(f"writing: {filename}") + + @classmethod + def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: + """ + Subclass should implement this method to return a backend-specific data representation object. + This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'. + """ + return convert_data_type(data_array, np.ndarray)[0] + + @classmethod + def resample_if_needed( + cls, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, + target_affine: Optional[NdarrayOrTensor] = None, + output_spatial_shape: Union[Sequence[int], int, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Convert the ``data_array`` into the coordinate system specified by + ``target_affine``, from the current coordinate definition of ``affine``. + + If the transform between ``affine`` and ``target_affine`` could be + achieved by simply transposing and flipping ``data_array``, no resampling + will happen. Otherwise, this function resamples ``data_array`` using the + transformation computed from ``affine`` and ``target_affine``. + + This function assumes the NIfTI dimension notations. Spatially it + supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D + respectively. When saving multiple time steps or multiple channels, + time and/or modality axes should be appended after the first three + dimensions. For example, shape of 2D eight-class segmentation + probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in + shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a + single-channel 3D image. 
The ``convert_to_channel_last`` method can be + used to convert the data to the format described here. + + Note that the shape of the resampled ``data_array`` may subject to some + rounding errors. For example, resampling a 20x20 pixel image from pixel + size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel + image. However, resampling a 20x20-pixel image from pixel size (2.0, + 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where + the image shape is rounded from 13.333x13.333 pixels. In this case + ``output_spatial_shape`` could be specified so that this function + writes image data to a designated shape. + + Args: + data_array: input data array to be converted. + affine: the current affine of ``data_array``. Defaults to identity + target_affine: the designated affine of ``data_array``. + The actual output affine might be different from this value due to precision changes. + output_spatial_shape: spatial shape of the output image. + This option is used when resampling is needed. + mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + This option is used when resampling is needed. + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}. + This option is used when resampling is needed. + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: boolean option of ``grid_sample`` to handle the corner convention. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to + ``np.float64`` for best precision. If ``None``, use the data type of input data. + The output data type of this method is always ``np.float32``. + """ + resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + output_array, target_affine = resampler( + data_array[None], src_affine=affine, dst_affine=target_affine, spatial_size=output_spatial_shape + ) + return output_array[0], target_affine + + @classmethod + def convert_to_channel_last( + cls, + data: NdarrayOrTensor, + channel_dim: Union[None, int, Sequence[int]] = 0, + squeeze_end_dims: bool = True, + spatial_ndim: Optional[int] = 3, + contiguous: bool = False, + ): + """ + Rearrange the data array axes to make the `channel_dim`-th dim the last + dimension and ensure there are ``spatial_ndim`` number of spatial + dimensions. + + When ``squeeze_end_dims`` is ``True``, a postprocessing step will be + applied to remove any trailing singleton dimensions. + + Args: + data: input data to be converted to "channel-last" format. + channel_dim: specifies the channel axes of the data array to move to the last. + ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension. + a sequence of integers indicates multiple non-spatial dimensions. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. + If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. + spatial_ndim: modifying the spatial dims if needed, so that output to have at least + this number of spatial dims. 
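One way to reproduce the shape figures quoted in the ``resample_if_needed`` docstring above is plain spacing arithmetic; this is only a back-of-the-envelope check of those numbers, not a call into the actual resampler:

.. code-block:: python

    import numpy as np

    def resampled_size(n_pixels: int, src_spacing: float, dst_spacing: float) -> int:
        # the physical extent stays fixed, so the new matrix size is extent / new spacing
        return int(np.ceil(n_pixels * src_spacing / dst_spacing))

    print(resampled_size(20, 1.5, 3.0))  # 10, exact
    print(resampled_size(20, 2.0, 3.0))  # 14, rounded up from 13.333...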
If ``None``, the output will have the same number of + spatial dimensions as the input. + contiguous: if ``True``, the output will be contiguous. + """ + # change data to "channel last" format + if channel_dim is not None: + _chns = ensure_tuple(channel_dim) + data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) + else: # adds a channel dimension + data = data[..., None] + # To ensure at least ``spatial_ndim`` number of spatial dims + if spatial_ndim: + while len(data.shape) < spatial_ndim + 1: # assuming the data has spatial + channel dims + data = data[..., None, :] + while len(data.shape) > spatial_ndim + 1: + data = data[..., 0, :] + # if desired, remove trailing singleton dimensions + while squeeze_end_dims and data.shape[-1] == 1: + data = np.squeeze(data, -1) + if contiguous: + data = ascontiguousarray(data) + return data + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + """ + Extracts relevant meta information from the metadata object (using ``.get``). + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + """ + if not metadata: + metadata = {"original_affine": None, "affine": None, "spatial_shape": None} + original_affine = metadata.get("original_affine") + affine = metadata.get("affine") + spatial_shape = metadata.get("spatial_shape") + return original_affine, affine, spatial_shape + + +@require_pkg(pkg_name="itk") +class ITKWriter(ImageWriter): + """ + Write data and metadata into files on disk using ITK-python. + + .. code-block:: python + + import numpy as np + from monai.data import ITKWriter + + np_data = np.arange(48).reshape(3, 4, 4) + + # write as 3d spatial image no channel + writer = ITKWriter(output_dtype=np.float32) + writer.set_data_array(np_data, channel_dim=None) + # optionally set metadata affine + writer.set_metadata({"affine": np.eye(4), "original_affine": -1 * np.eye(4)}) + writer.write("test1.nii.gz") + + # write as 2d image, channel-first + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write("test1.png") + + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): + """ + Args: + output_dtype: output data type. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype`` internally. + ``affine`` and ``channel_dim`` are initialized as instance members (default ``None``, ``0``): + + - user-specified ``affine`` should be set in ``set_metadata``, + - user-specified ``channel_dim`` should be set in ``set_data_array``. + """ + super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) + + def set_data_array( + self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively. 
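A minimal shape check for ``convert_to_channel_last``, mirroring the `(8, 64, 64)` to `(64, 64, 1, 8)` example in the ``ImageWriter`` docstring; the classmethod defined above is called directly on NumPy arrays:

.. code-block:: python

    import numpy as np
    from monai.data.image_writer import ImageWriter

    seg_probs = np.zeros((8, 64, 64))  # eight-class 2D probabilities, channel-first
    out = ImageWriter.convert_to_channel_last(seg_probs, channel_dim=0, spatial_ndim=3)
    print(out.shape)  # (64, 64, 1, 8): channel moved last, a singleton spatial dim added

    vol = np.zeros((1, 32, 32, 16))    # single-channel 3D volume
    out = ImageWriter.convert_to_channel_last(vol, channel_dim=0, squeeze_end_dims=True)
    print(out.shape)  # (32, 32, 16): the trailing singleton channel is squeezed away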
+ """ + _r = len(data_array.shape) + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + contiguous=kwargs.pop("contiguous", True), + ) + self.channel_dim = channel_dim if len(self.data_obj.shape) >= _r else None # channel dim is at the end + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. + """ + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + + def write(self, filename: PathLike, verbose: bool = False, **kwargs): + """ + Create an ITK object from ``self.create_backend_obj(self.obj, ...)`` and call ``itk.imwrite``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: keyword arguments passed to ``itk.imwrite``, + currently support ``compression`` and ``imageio``. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore + ) + itk.imwrite( + self.data_obj, filename, compression=kwargs.pop("compression", False), imageio=kwargs.pop("imageio", None) + ) + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + affine: Optional[NdarrayOrTensor] = None, + dtype: DtypeLike = np.float32, + **kwargs, + ): + """ + Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``. + affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`. + dtype: output data type. + kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary. 
+ + see also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389 + """ + data_array = super().create_backend_obj(data_array) + _is_vec = channel_dim is not None + if _is_vec: + data_array = np.moveaxis(data_array, -1, 0) # from channel last to channel first + data_array = data_array.T.astype(dtype, copy=True, order="C") + itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop("ttype", None)) + + d = len(itk.size(itk_obj)) + if affine is None: + affine = np.eye(d + 1, dtype=np.float64) + _affine = convert_data_type(affine, np.ndarray)[0] + _affine = orientation_ras_lps(to_affine_nd(d, _affine)) + spacing = affine_to_spacing(_affine, r=d) + _direction: np.ndarray = np.diag(1 / spacing) + _direction = _affine[:d, :d] @ _direction + itk_obj.SetSpacing(spacing.tolist()) + itk_obj.SetOrigin(_affine[:d, -1].tolist()) + itk_obj.SetDirection(itk.GetMatrixFromArray(_direction)) + return itk_obj + + +@require_pkg(pkg_name="nibabel") +class NibabelWriter(ImageWriter): + """ + Write data and metadata into files on disk using Nibabel. + + .. code-block:: python + + import numpy as np + from monai.data import NibabelWriter + + np_data = np.arange(48).reshape(3, 4, 4) + writer = NibabelWriter() + writer.set_data_array(np_data, channel_dim=None) + writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) + writer.write("test1.nii.gz", verbose=True) + + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): + """ + Args: + output_dtype: output data type. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype`` internally. + ``affine`` is initialized as instance members (default ``None``), + user-specified ``affine`` should be set in ``set_metadata``. + """ + super().__init__(output_dtype=output_dtype, affine=None, **kwargs) + + def set_data_array( + self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim``, defauting to ``3``. + """ + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + ) + + def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. 
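The spacing/direction/origin decomposition used by ``ITKWriter.create_backend_obj`` above is ordinary linear algebra; a standalone sketch with a toy 2D affine (the RAS-to-LPS conversion step is omitted here):

.. code-block:: python

    import numpy as np

    affine = np.array(
        [
            [0.0, -2.0, 0.0, 10.0],  # toy affine: 1.5 mm / 2 mm spacing,
            [1.5,  0.0, 0.0, -5.0],  # 90-degree in-plane rotation, origin (10, -5)
            [0.0,  0.0, 1.0,  0.0],
            [0.0,  0.0, 0.0,  1.0],
        ]
    )
    d = 2
    spacing = np.sqrt((affine[:d, :d] ** 2).sum(axis=0))   # column norms -> voxel spacing
    direction = affine[:d, :d] @ np.diag(1.0 / spacing)    # unit-length direction cosines
    origin = affine[:d, -1]                                 # translation part

    print(spacing)    # [1.5 2. ]
    print(direction)  # [[ 0. -1.] [ 1.  0.]]
    print(origin)     # [10. -5.]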
+ """ + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + + def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): + """ + Create a Nibabel object from ``self.create_backend_obj(self.obj, ...)`` and call ``nib.save``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + obj_kwargs: keyword arguments passed to ``self.create_backend_obj``, + + See also: + + - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, affine=self.affine, dtype=self.output_dtype, **obj_kwargs # type: ignore + ) + nib.save(self.data_obj, filename) + + @classmethod + def create_backend_obj( + cls, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = None, **kwargs + ): + """ + Create an Nifti1Image object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + affine: affine matrix of the data array. + dtype: output data type. + kwargs: keyword arguments. Current ``nib.nifti1.Nifti1Image`` will read + ``header``, ``extra``, ``file_map`` from this dictionary. + + See also: + + - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.Nifti1Image + """ + data_array = super().create_backend_obj(data_array) + if dtype is not None: + data_array = data_array.astype(dtype, copy=False) + affine = convert_data_type(affine, np.ndarray)[0] + affine = to_affine_nd(r=3, affine=affine) + return nib.nifti1.Nifti1Image( + data_array, + affine, + header=kwargs.pop("header", None), + extra=kwargs.pop("extra", None), + file_map=kwargs.pop("file_map", None), + ) + + +@require_pkg(pkg_name="PIL") +class PILWriter(ImageWriter): + """ + Write image data into files on disk using pillow. + + It's based on the Image module in PIL library: + https://pillow.readthedocs.io/en/stable/reference/Image.html + + .. code-block:: python + + import numpy as np + from monai.data import PILWriter + + np_data = np.arange(48).reshape(3, 4, 4) + writer = PILWriter(np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.write("test1.png", verbose=True) + """ + + def __init__( + self, output_dtype: DtypeLike = np.float32, channel_dim: Optional[int] = 0, scale: Optional[int] = 255, **kwargs + ): + """ + Args: + output_dtype: output data type. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + kwargs: keyword arguments passed to ``ImageWriter``. 
+ """ + super().__init__(output_dtype=output_dtype, channel_dim=channel_dim, scale=scale, **kwargs) + + def set_data_array( + self, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + squeeze_end_dims: bool = True, + contiguous: bool = False, + **kwargs, + ): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + contiguous: if ``True``, the data array will be converted to a contiguous array. Default is ``False``. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim``, defauting to ``2``. + """ + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 2), + contiguous=contiguous, + ) + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional key is ``"spatial_shape"``. + resample: if ``True``, the data will be resampled to the spatial shape specified in ``meta_dict``. + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, defaulting to ``bicubic``. + """ + spatial_shape = self.get_meta_info(meta_dict) + self.data_obj = self.resample_and_clip( + data_array=self.data_obj, + output_spatial_shape=spatial_shape if resample else None, + mode=options.pop("mode", InterpolateMode.BICUBIC), + ) + + def write(self, filename: PathLike, verbose: bool = False, **kwargs): + """ + Create a PIL image object from ``self.create_backend_obj(self.obj, ...)`` and call ``save``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: optional keyword arguments passed to ``self.create_backend_obj`` + currently support ``reverse_indexing``, ``image_mode``, defaulting to ``True``, ``None`` respectively. + + See also: + + - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + data_array=self.data_obj, + dtype=self.output_dtype, # type: ignore + reverse_indexing=kwargs.pop("reverse_indexing", True), + image_mode=kwargs.pop("image_mode", None), + scale=self.scale, # type: ignore + **kwargs, + ) + self.data_obj.save(filename, **kwargs) + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + return None if not metadata else metadata.get("spatial_shape") + + @classmethod + def resample_and_clip( + cls, + data_array: NdarrayOrTensor, + output_spatial_shape: Optional[Sequence[int]] = None, + mode: Union[InterpolateMode, str] = InterpolateMode.BICUBIC, + ): + """ + Resample ``data_array`` to ``output_spatial_shape`` if needed. + Args: + data_array: input data array. This method assumes the 'channel-last' format. + output_spatial_shape: output spatial shape. + mode: interpolation mode, defautl is ``InterpolateMode.BICUBIC``. 
+ """ + + data: np.ndarray = convert_data_type(data_array, np.ndarray)[0] + if output_spatial_shape is not None: + output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) + mode = look_up_option(mode, InterpolateMode) + align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False + xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) + _min, _max = np.min(data), np.max(data) + if len(data.shape) == 3: + data = np.moveaxis(data, -1, 0) # to channel first + data = xform(data) # type: ignore + data = np.moveaxis(data, 0, -1) + else: # (H, W) + data = np.expand_dims(data, 0) # make a channel + data = xform(data)[0] # type: ignore + if mode != InterpolateMode.NEAREST: + data = np.clip(data, _min, _max) + return data + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + dtype: DtypeLike = None, + scale: Optional[int] = 255, + reverse_indexing: bool = True, + **kwargs, + ): + """ + Create a PIL object from ``data_array``. + + Args: + data_array: input data array. + dtype: output data type. + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + reverse_indexing: if ``True``, the data array's first two dimensions will be swapped. + kwargs: keyword arguments. Currently ``PILImage.fromarray`` will read + ``image_mode`` from this dictionary, defaults to ``None``. + + See also: + + - https://pillow.readthedocs.io/en/stable/reference/Image.html + """ + data: np.ndarray = super().create_backend_obj(data_array) + if scale: + # scale the data to be in an integer range + data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + if scale == np.iinfo(np.uint8).max: + data = (scale * data).astype(np.uint8, copy=False) + elif scale == np.iinfo(np.uint16).max: + data = (scale * data).astype(np.uint16, copy=False) + else: + raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535].") + if dtype is not None: + data = data.astype(dtype, copy=False) + if reverse_indexing: + data = np.moveaxis(data, 0, 1) + + return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None)) + + +def init(): + """ + Initialize the image writer modules according to the filename extension. + """ + for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"): + register_writer(ext, PILWriter) # TODO: test 16-bit + for ext in ("nii.gz", "nii"): + register_writer(ext, NibabelWriter, ITKWriter) + register_writer("nrrd", ITKWriter, NibabelWriter) + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index a5acdd032e..3fdc0aa3e8 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -19,8 +19,10 @@ from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import deprecated +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class NiftiSaver: """ Save the data as NIfTI file, it can support single data content or a batch of data. @@ -32,6 +34,9 @@ class NiftiSaver: Note: image should include channel dimension: [B],C,H,W,[D]. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. 
+ """ def __init__( diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 4e7a99f557..8a6172955f 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -18,12 +18,14 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform -from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.transforms.utils_pytorch_numpy_unification import allclose +from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") +@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.") def write_nifti( data: NdarrayOrTensor, file_name: str, @@ -97,24 +99,26 @@ def write_nifti( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. + + .. deprecated:: 0.8 + Use :py:meth:`monai.data.NibabelWriter` instead. + """ - if isinstance(data, torch.Tensor): - data, *_ = convert_data_type(data, np.ndarray) - if isinstance(affine, torch.Tensor): - affine, *_ = convert_data_type(affine, np.ndarray) + data, *_ = convert_data_type(data, np.ndarray) + affine, *_ = convert_data_type(affine, np.ndarray) if not isinstance(data, np.ndarray): raise AssertionError("input data must be numpy array or torch tensor.") dtype = dtype or data.dtype sr = min(data.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) # type: ignore + affine = to_affine_nd(sr, affine) if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine, *_ = convert_data_type(to_affine_nd(sr, target_affine), np.ndarray) - if np.allclose(affine, target_affine, atol=1e-3): + if allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) @@ -127,7 +131,7 @@ def write_nifti( data_shape = data.shape data = nib.orientations.apply_orientation(data, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) - if np.allclose(_affine, target_affine, atol=1e-3) or not resample: + if allclose(_affine, target_affine, atol=1e-3) or not resample: results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, _affine)) # type: ignore nib.save(results_img, file_name) return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index a83a560e9f..9a1ade0efa 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -18,9 +18,10 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, look_up_option +from monai.utils import InterpolateMode, deprecated, look_up_option +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class PNGSaver: """ Save the data as png file, it can support single data content or a batch of data. @@ -30,6 +31,9 @@ class PNGSaver: where the input image name is extracted from the provided meta data dictionary. If no meta data provided, use index from 0 as the filename prefix. + .. 
deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index f1aa5fc5c8..5d05536923 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -14,11 +14,12 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import Image, _ = optional_import("PIL", name="Image") +@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.") def write_png( data: np.ndarray, file_name: str, @@ -46,6 +47,9 @@ def write_png( Raises: ValueError: When ``scale`` is not one of [255, 65535]. + .. deprecated:: 0.8 + Use :py:meth:`monai.data.PILWriter` instead. + """ if not isinstance(data, np.ndarray): raise ValueError("input data must be numpy array.") @@ -65,10 +69,10 @@ def write_png( data = np.expand_dims(data, 0) # make a channel data = xform(data)[0] # type: ignore if mode != InterpolateMode.NEAREST: - data = np.clip(data, _min, _max) # type: ignore + data = np.clip(data, _min, _max) if scale is not None: - data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: data = (scale * data).astype(np.uint8, copy=False) elif scale == np.iinfo(np.uint16).max: @@ -77,7 +81,7 @@ def write_png( raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): # type: ignore + if data.dtype not in (np.uint8, np.uint16): data = data.astype(np.uint8, copy=False) data = np.moveaxis(data, 0, 1) diff --git a/monai/data/samplers.py b/monai/data/samplers.py index 40eed03187..f5175266d8 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -50,7 +50,7 @@ def __init__( super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs) if not even_divisible: - data_len = len(dataset) # type: ignore + data_len = len(dataset) extra_size = self.total_size - data_len if self.rank + extra_size >= self.num_replicas: self.num_samples -= 1 diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 997e96a1b3..0b97c9febf 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -16,17 +16,17 @@ import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import list_data_collate, pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.compose import Compose +from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform -from monai.transforms.inverse_batch_transform import BatchInverseTransform +from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, PostFix, TraceKeys -from monai.utils.module import optional_import -from monai.utils.type_conversion import convert_data_type +from 
monai.transforms.utils_pytorch_numpy_unification import mode, stack +from monai.utils import CommonKeys, PostFix, optional_import if TYPE_CHECKING: from tqdm import tqdm @@ -75,24 +75,29 @@ class TestTimeAugmentation: orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc. the meta data is a dictionary object which contains: filename, original_shape, etc. if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. this arg only works when `meta_keys=None`. - return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the - full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended - equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. + output_device: if converted the inverted data to Tensor, move the inverted results to target device + before `post_func`, default to "cpu". + post_func: post processing for the inverted data, should be a callable function. + return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` + will return the full data. Dimensions will be same size as when passing a single image through + `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. progress: whether to display a progress bar. Example: .. code-block:: python - transform = RandAffined(keys, ...) 
- post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + model = UNet(...).to(device) + transform = Compose([RandAffined(keys, ...), ...]) + transform.set_random_state(seed=123) # ensure deterministic evaluation tt_aug = TestTimeAugmentation( - transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device + transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device ) mode, mean, std, vvc = tt_aug(test_data) """ @@ -109,6 +114,9 @@ def __init__( nearest_interp: bool = True, orig_meta_keys: Optional[str] = None, meta_key_postfix=DEFAULT_POST_FIX, + to_tensor: bool = True, + output_device: Union[str, torch.device] = "cpu", + post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, ) -> None: @@ -118,12 +126,20 @@ def __init__( self.inferrer_fn = inferrer_fn self.device = device self.image_key = image_key - self.orig_key = orig_key - self.nearest_interp = nearest_interp - self.orig_meta_keys = orig_meta_keys - self.meta_key_postfix = meta_key_postfix self.return_full_data = return_full_data self.progress = progress + self._pred_key = CommonKeys.PRED + self.inverter = Invertd( + keys=self._pred_key, + transform=transform, + orig_keys=orig_key, + orig_meta_keys=orig_meta_keys, + meta_key_postfix=meta_key_postfix, + nearest_interp=nearest_interp, + to_tensor=to_tensor, + device=output_device, + post_func=post_func, + ) # check that the transform has at least one random component, and that all random transforms are invertible self._check_transforms() @@ -135,8 +151,8 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - raise RuntimeError( - "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." + warnings.warn( + "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): @@ -147,18 +163,19 @@ def _check_transforms(self): def __call__( self, data: Dict[str, Any], num_examples: int = 10 - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: + ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]: """ Args: data: dictionary data to be processed. num_examples: number of realisations to be processed and results combined. Returns: - - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across - `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, - including `num_examples`. See original paper for clarification. - - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then concatenating across - the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. + - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are + calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) + is `std/mean` across the whole output, including `num_examples`. See original paper for clarification. + - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then + concatenating across the first dimension containing `num_examples`. 
This allows the user to perform + their own analysis if desired. """ d = dict(data) @@ -171,56 +188,22 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = InvertibleTransform.trace_key(self.orig_key) - - # create inverter - inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) - - outputs: List[np.ndarray] = [] - - for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: - - batch_images = batch_data[self.image_key].to(self.device) + outs: List = [] + for b in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass - batch_output = self.inferrer_fn(batch_images) - if isinstance(batch_output, torch.Tensor): - batch_output = batch_output.detach().cpu() - if isinstance(batch_output, np.ndarray): - batch_output = torch.Tensor(batch_output) - transform_info = batch_data.get(transform_key, None) - if transform_info is None: - # no invertible transforms, adding dummy info for identity invertible - transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)] - if self.nearest_interp: - transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), mode="nearest", align_corners=None - ) + b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device)) + outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) - # create a dictionary containing the inferred batch and their transforms - inferred_dict = {self.orig_key: batch_output, transform_key: transform_info} - # if meta dict is present, add that too (required for some inverse transforms) - meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}" - if meta_dict_key in batch_data: - inferred_dict[meta_dict_key] = batch_data[meta_dict_key] - - # do inverse transformation (allow missing keys as only inverting the orig_key) - with allow_missing_keys_mode(self.transform): # type: ignore - inv_batch = inverter(inferred_dict) - - # append - outputs.append(inv_batch[self.orig_key]) - - # output - output: np.ndarray = np.concatenate(outputs) + output: NdarrayOrTensor = stack(outs, 0) if self.return_full_data: return output # calculate metrics - output_t, *_ = convert_data_type(output, output_type=torch.Tensor, dtype=np.int64) - mode: np.ndarray = np.asarray(torch.mode(output_t, dim=0).values) # type: ignore - mean: np.ndarray = np.mean(output, axis=0) # type: ignore - std: np.ndarray = np.std(output, axis=0) # type: ignore - vvc: float = (np.std(output) / np.mean(output)).item() - return mode, mean, std, vvc + _mode = mode(output, dim=0) + mean = output.mean(0) + std = output.std(0) + vvc = (output.std() / output.mean()).item() + + return _mode, mean, std, vvc diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index cdd7c05f31..e21af69813 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -89,6 +89,12 @@ class ThreadDataLoader(DataLoader): on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch generation process more time to produce a result. + Another typical usage is to accelerate light-weight preprocessing (usually cached all the deterministic transforms + and no IO operations), because it leverages the separate thread to execute preprocessing to avoid unnecessary IPC + between multiple workers of DataLoader. 
And as CUDA may not work well with the multi-processing of DataLoader, + `ThreadDataLoader` can be useful for GPU transforms. For more details: + https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md. + See: * Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353 * Dami et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550 @@ -99,20 +105,15 @@ class ThreadDataLoader(DataLoader): dataset: input dataset. buffer_size: number of items to buffer from the data source. buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items. - num_workers: number of the multi-processing workers in PyTorch DataLoader. - repeats: number of times to yield the same batch + repeats: number of times to yield the same batch. + kwargs: other arguments for `DataLoader` except for `dataset`. + """ def __init__( - self, - dataset: Dataset, - buffer_size: int = 1, - buffer_timeout: float = 0.01, - num_workers: int = 0, - repeats: int = 1, - **kwargs, + self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs ): - super().__init__(dataset, num_workers, **kwargs) + super().__init__(dataset, **kwargs) self.buffer_size = buffer_size self.buffer_timeout = buffer_timeout self.repeats = repeats diff --git a/monai/data/utils.py b/monai/data/utils.py index 79ef9bd7fb..495daf15e2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -27,12 +27,15 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai.config.type_definitions import PathLike +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, BlendMode, + Method, NumpyPadMode, + convert_data_type, + convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -42,7 +45,6 @@ look_up_option, optional_import, ) -from monai.utils.enums import Method pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -50,39 +52,46 @@ __all__ = [ - "get_random_patch", - "iter_patch_slices", - "dense_patch_slices", - "iter_patch", - "get_valid_patch_size", - "list_data_collate", - "worker_init_fn", - "set_rnd", - "correct_nifti_header_if_necessary", - "rectify_header_sform_qform", - "zoom_affine", + "AFFINE_TOL", + "SUPPORTED_PICKLE_MOD", + "affine_to_spacing", + "compute_importance_map", "compute_shape_offset", - "to_affine_nd", + "convert_tables_to_dicts", + "correct_nifti_header_if_necessary", "create_file_basename", - "compute_importance_map", + "decollate_batch", + "dense_patch_slices", + "get_random_patch", + "get_valid_patch_size", "is_supported_format", + "iter_patch", + "iter_patch_slices", + "json_hashing", + "list_data_collate", + "no_collation", + "orientation_ras_lps", + "pad_list_data_collate", "partition_dataset", "partition_dataset_classes", + "pickle_hashing", + "rectify_header_sform_qform", + "reorient_spatial_axes", "resample_datalist", "select_cross_validation_folds", - "json_hashing", - "pickle_hashing", + "set_rnd", "sorted_dict", - "decollate_batch", - "pad_list_data_collate", - "no_collation", - "convert_tables_to_dicts", - "SUPPORTED_PICKLE_MOD", + "to_affine_nd", + "worker_init_fn", + "zoom_affine", ] # module to be used by `torch.save` SUPPORTED_PICKLE_MOD = {"pickle": pickle} +# tolerance for affine matrix 
computation +AFFINE_TOL = 1e-3 + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -489,10 +498,10 @@ def pad_list_data_collate( tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of different sizes. - This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added - to the list of invertible transforms. - - The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`. + This can be used on both list and dictionary data. + Note that in the case of the dictionary data, this decollate function may add the transform information of + `PadListDataCollate` to the list of invertible transforms if input batch have different spatial shape, so need to + call static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse` before inverting other transforms. Args: batch: batch of data to pad-collate @@ -544,6 +553,30 @@ def set_rnd(obj, seed: int) -> int: return seed +def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor: + """ + Computing the current spacing from the affine matrix. + + Args: + affine: a d x d affine matrix. + r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`. + dtype: data type of the output. + suppress_zeros: whether to surpress the zeros with ones. + + Returns: + an `r` dimensional vector of spacing. + """ + _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) + if isinstance(_affine, torch.Tensor): + spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) + else: + spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) + if suppress_zeros: + spacing[spacing == 0] = 1.0 + spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype) + return spacing_ + + def correct_nifti_header_if_necessary(img_nii): """ Check nifti object header's format, update the header if needed. 
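The ``affine_to_spacing`` helper added above centralizes the ``np.sqrt(np.sum(np.square(...)))`` pattern that the hunks below replace. A minimal usage sketch (hypothetical values, not part of this patch):

.. code-block:: python

    import numpy as np
    from monai.data.utils import affine_to_spacing

    # a toy 3D affine with 2.0 x 0.5 x 1.0 spacing (hypothetical values)
    affine = np.diag([2.0, 0.5, 1.0, 1.0])
    print(affine_to_spacing(affine, r=3))  # expected: [2.  0.5 1. ]
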
@@ -559,7 +592,7 @@ def correct_nifti_header_if_necessary(img_nii): return img_nii # do nothing for high-dimensional array # check that affine matches zooms pixdim = np.asarray(img_nii.header.get_zooms())[:dim] - norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0)) + norm_affine = affine_to_spacing(img_nii.affine, r=dim) if np.allclose(pixdim, norm_affine): return img_nii if hasattr(img_nii, "get_sform"): @@ -580,8 +613,8 @@ def rectify_header_sform_qform(img_nii): d = img_nii.header["dim"][0] pixdim = np.asarray(img_nii.header.get_zooms())[:d] sform, qform = img_nii.get_sform(), img_nii.get_qform() - norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0)) - norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0)) + norm_sform = affine_to_spacing(sform, r=d) + norm_qform = affine_to_spacing(qform, r=d) sform_mismatch = not np.allclose(norm_sform, pixdim) qform_mismatch = not np.allclose(norm_qform, pixdim) @@ -598,7 +631,7 @@ def rectify_header_sform_qform(img_nii): img_nii.set_qform(img_nii.get_sform()) return img_nii - norm = np.sqrt(np.sum(np.square(img_nii.affine[:d, :d]), 0)) + norm = affine_to_spacing(img_nii.affine, r=d) warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}") img_nii.header.set_zooms(norm) @@ -638,7 +671,7 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d d = len(affine) - 1 # compute original pixdim - norm = np.sqrt(np.sum(np.square(affine), 0))[:-1] + norm = affine_to_spacing(affine, r=d) if len(scale_np) < d: # defaults based on affine scale_np = np.append(scale_np, norm[len(scale_np) :]) scale_np = scale_np[:d] @@ -658,7 +691,7 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d def compute_shape_offset( - spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: np.ndarray, out_affine: np.ndarray + spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: NdarrayOrTensor, out_affine: NdarrayOrTensor ) -> Tuple[np.ndarray, np.ndarray]: """ Given input and output affine, compute appropriate shapes @@ -673,67 +706,98 @@ def compute_shape_offset( """ shape = np.array(spatial_shape, copy=True, dtype=float) sr = len(shape) - in_affine = to_affine_nd(sr, in_affine) - out_affine = to_affine_nd(sr, out_affine) + in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0] + out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0] in_coords = [(0.0, dim - 1.0) for dim in shape] - corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) + corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) - corners = in_affine @ corners - corners_out = np.linalg.inv(out_affine) @ corners + corners = in_affine_ @ corners + try: + inv_mat = np.linalg.inv(out_affine_) + except np.linalg.LinAlgError as e: + raise ValueError(f"Affine {out_affine_} is not invertible") from e + corners_out = inv_mat @ corners corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1) + 1.0) - if np.allclose(nib.io_orientation(in_affine), nib.io_orientation(out_affine)): - # same orientation, get translate from the origin - offset = in_affine @ ([0] * sr + [1]) - offset = offset[:-1] / offset[-1] - else: - # different orientation, the min is the origin - corners = corners[:-1] / corners[-1] - offset = np.min(corners, 1) + mat = inv_mat[:-1, :-1] + k = 0 + for i in range(corners.shape[1]): + min_corner = 
np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) + if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL): + k = i + break + offset = corners[:-1, k] return out_shape.astype(int, copy=False), offset -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: +def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. - when ``r`` is an integer, output is an (r+1)x(r+1) matrix, + When ``r`` is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(r, len(affine) - 1)`. - when ``r`` is an affine matrix, the output has the same as ``r``, - the top left kxk elements are copied from ``affine``, + When ``r`` is an affine matrix, the output has the same shape as ``r``, + and the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(len(r) - 1, len(affine) - 1)`. Args: r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix + dtype: data type of the output array. Raises: ValueError: When ``affine`` dimensions is not 2. ValueError: When ``r`` is nonpositive. Returns: - an (r+1) x (r+1) matrix + an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ - affine_np = np.array(affine, dtype=np.float64) + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] + affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + new_affine = np.array(r, dtype=dtype, copy=True) if new_affine.ndim == 0: sr: int = int(new_affine.astype(np.uint)) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") - new_affine = np.eye(sr + 1, dtype=np.float64) + new_affine = np.eye(sr + 1, dtype=dtype) d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) new_affine[:d, :d] = affine_np[:d, :d] if d > 1: new_affine[:d, -1] = affine_np[:d, -1] - return new_affine + output, *_ = convert_to_dst_type(new_affine, affine, dtype=dtype) + return output + + +def reorient_spatial_axes( + data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor +) -> Tuple[np.ndarray, NdarrayOrTensor]: + """ + Given the input ``init_affine``, compute the orientation transform between + it and ``target_affine`` by rearranging/flipping the axes. + + Returns the orientation transform and the updated affine (tensor or ndarray + depends on the input ``affine`` data type). + Note that this function requires external module ``nibabel.orientations``. 
+ """ + init_affine_, *_ = convert_data_type(init_affine, np.ndarray) + target_affine_, *_ = convert_data_type(target_affine, np.ndarray) + start_ornt = nib.orientations.io_orientation(init_affine_) + target_ornt = nib.orientations.io_orientation(target_affine_) + try: + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + except ValueError as e: + raise ValueError(f"The input affine {init_affine} and target affine {target_affine} are not compatible.") from e + new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + new_affine, *_ = convert_to_dst_type(new_affine, init_affine) + return ornt_transform, new_affine def create_file_basename( @@ -839,7 +903,7 @@ def compute_importance_map( """ mode = look_up_option(mode, BlendMode) - device = torch.device(device) # type: ignore[arg-type] + device = torch.device(device) if mode == BlendMode.CONSTANT: importance_map = torch.ones(patch_size, device=device).float() elif mode == BlendMode.GAUSSIAN: @@ -1225,3 +1289,19 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def orientation_ras_lps(affine: NdarrayTensor) -> NdarrayTensor: + """ + Convert the ``affine`` between the `RAS` and `LPS` orientation + by flipping the first two spatial dimensions. + + Args: + affine: a 2D affine matrix. + """ + sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 + flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] + flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) + if isinstance(affine, torch.Tensor): + return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine # type: ignore + return np.diag(flip_diag).astype(affine.dtype) @ affine # type: ignore diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e3e70d8b32..c3e8c456b7 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -14,7 +14,7 @@ import torch from torch.utils.data import DataLoader -from monai.config import IgniteInfo +from monai.config import IgniteInfo, KeysCollection from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -281,6 +281,7 @@ class EnsembleEvaluator(Evaluator): networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. pred_keys: the keys to store every prediction data. the length must exactly match the number of networks. + if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. 
prepare_batch: function to parse expected data (usually `image`, `label` and other network args) @@ -321,7 +322,7 @@ def __init__( device: torch.device, val_data_loader: Union[Iterable, DataLoader], networks: Sequence[torch.nn.Module], - pred_keys: Sequence[str], + pred_keys: Optional[KeysCollection] = None, epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, @@ -358,7 +359,11 @@ def __init__( ) self.networks = ensure_tuple(networks) - self.pred_keys = ensure_tuple(pred_keys) + self.pred_keys = ( + [f"{Keys.PRED}_{i}" for i in range(len(self.networks))] if pred_keys is None else ensure_tuple(pred_keys) + ) + if len(self.pred_keys) != len(self.networks): + raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 7c59b670b7..0433617649 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -74,8 +74,8 @@ def create_multigpu_supervised_trainer( tuple of tensors `(batch_x, batch_y)`. output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Returns: Engine: a trainer engine with supervised update function. @@ -87,6 +87,8 @@ def create_multigpu_supervised_trainer( devices_ = get_devices_spec(devices) if distributed: + if len(devices_) > 1: + raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.") net = DistributedDataParallel(net, device_ids=devices_) elif len(devices_) > 1: net = DataParallel(net) @@ -122,8 +124,8 @@ def create_multigpu_supervised_evaluator( output_transform: function that receives 'x', 'y', 'y_pred' and returns value to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits output expected by metrics. If you change it you should use `output_transform` in metrics. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is @@ -137,6 +139,10 @@ def create_multigpu_supervised_evaluator( if distributed: net = DistributedDataParallel(net, device_ids=devices_) + if len(devices_) > 1: + raise ValueError( + f"for distributed evaluation, `devices` must contain only 1 GPU or CPU, but got {devices_}." 
+ ) elif len(devices_) > 1: net = DataParallel(net) diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index 207af306d4..db186bd73d 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -36,7 +36,6 @@ def __init__( name: Optional[str] = None, epoch_level: bool = True, step_transform: Callable[[Engine], Any] = lambda engine: (), - logger_handler: Optional[logging.Handler] = None, ) -> None: """ Args: @@ -48,9 +47,6 @@ def __init__( `True` is epoch level, `False` is iteration level. step_transform: a callable that is used to transform the information from `engine` to expected input data of lr_scheduler.step() function if necessary. - logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc. - all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. - the handler should have a logging level of at least `INFO`. Raises: TypeError: When ``step_transform`` is not ``callable``. @@ -63,8 +59,6 @@ def __init__( if not callable(step_transform): raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.") self.step_transform = step_transform - if logger_handler is not None: - self.logger.addHandler(logger_handler) self._name = name diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index c0e18edcd0..67c51fd351 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -134,7 +134,7 @@ def _exponential(initial_value: float, gamma: float, current_step: int) -> float Returns: float: new parameter value """ - return initial_value * gamma ** current_step + return initial_value * gamma**current_step @staticmethod def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float: diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 31b046064e..68cf2e655e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -16,7 +16,7 @@ from monai.utils import Average -class ROCAUC(IgniteMetric): # type: ignore[valid-type, misc] # due to optional_import +class ROCAUC(IgniteMetric): """ Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 8410aaec87..a51a5a1382 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -34,13 +34,28 @@ class StatsHandler: It can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. + Note that if `name` arg is None, will leverage `engine.logger` as default logger directly, otherwise, + get logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`. + As the default log level of `RootLogger` is `WARNING`, may need to call + `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` before running this handler to enable + the stats logging. + Default behaviors: - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. - When ITERATION_COMPLETED, logs ``self.output_transform(engine.state.output)`` using ``self.logger``. - Usage example is available in the tutorial: - https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb. 
+ Usage example:: + + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + trainer = SupervisedTrainer(...) + StatsHandler(name="train_stats").attach(trainer) + + trainer.run() + + More details of example is available in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/engines/unet_training_dict.py. """ @@ -54,7 +69,6 @@ def __init__( name: Optional[str] = None, tag_name: str = DEFAULT_TAG, key_var_format: str = DEFAULT_KEY_VAL_FORMAT, - logger_handler: Optional[logging.Handler] = None, ) -> None: """ @@ -77,13 +91,11 @@ def __init__( with the trainer engine. state_attributes: expected attributes from `engine.state`, if provided, will extract them when epoch completed. - name: identifier of logging.logger to use, defaulting to ``engine.logger``. + name: identifier of `logging.logger` to use, if None, defaulting to ``engine.logger``. tag_name: when iteration output is a scalar, tag_name is used to print tag_name: scalar_value to logger. Defaults to ``'Loss'``. key_var_format: a formatting string to control the output string format of key: value. - logger_handler: add additional handler to handle the stats data: save to file, etc. - all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. - the handler should have a logging level of at least `INFO`. + """ self.epoch_print_logger = epoch_print_logger @@ -91,13 +103,10 @@ def __init__( self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform self.state_attributes = state_attributes - self.logger = logging.getLogger(name) - self._name = name - self.tag_name = tag_name self.key_var_format = key_var_format - if logger_handler is not None: - self.logger.addHandler(logger_handler) + self.logger = logging.getLogger(name) # if `name` is None, will default to `engine.logger` when attached + self.name = name def attach(self, engine: Engine) -> None: """ @@ -107,8 +116,13 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self._name is None: + if self.name is None: self.logger = engine.logger + if self.logger.getEffectiveLevel() > logging.INFO or logging.root.getEffectiveLevel() > logging.INFO: + warnings.warn( + "the effective log level of engine logger or RootLogger is higher than INFO, may not record log," + " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it." 
+ ) if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index b527522cd7..a06f6fb5cd 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -126,7 +126,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - t2, p2, tp = target ** 2, pred ** 2, target * pred + t2, p2, tp = target**2, pred**2, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim) @@ -217,7 +217,7 @@ def __init__( self.num_bins = num_bins self.kernel_type = kernel_type if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma ** 2) + self.preterm = 1 / (2 * sigma**2) self.bin_centers = bin_centers[None, None, ...] self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) @@ -280,7 +280,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torc weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: weight = ( - weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 ) weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index dd1f81a4b4..eca4663ee7 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -102,7 +102,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction for the confusion matrix values. diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py index 090d65a44c..768841f6c7 100644 --- a/monai/metrics/cumulative_average.py +++ b/monai/metrics/cumulative_average.py @@ -51,7 +51,7 @@ def reset(self): self.sum = None self.not_nans = None - def aggregate(self): # type: ignore + def aggregate(self): """ Sync data from all the ranks and compute the average value with previous sum value. 
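For context on the ``aggregate`` signature cleanups in the metrics below, a minimal sketch of how a cumulative metric such as ``CumulativeAverage`` is typically driven; the ``append``/``aggregate`` calls are assumptions based on the ``Cumulative`` buffer API and are not part of this patch:

.. code-block:: python

    import torch
    from monai.metrics import CumulativeAverage

    avg = CumulativeAverage()
    for loss in (torch.tensor(0.5), torch.tensor(0.3)):
        avg.append(loss)    # buffer per-iteration values (synced across ranks if distributed)
    print(avg.aggregate())  # average of the buffered values, expected 0.4
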
diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 082311aa67..51748c43c1 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -85,7 +85,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") @@ -99,7 +99,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor directed=self.directed, ) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_hausdorff_distance`. diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 1e6065b59c..15635ff2d9 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -72,14 +72,14 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute dice (BxC) for each channel for each batch return compute_meandice(y_pred=y_pred, y=y, include_background=self.include_background) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_meandice`. 
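With the change above, a ground truth that is not strictly binarized now triggers a warning instead of raising, so evaluation can proceed. A small sketch of the new behaviour (hypothetical tensors, not part of this patch):

.. code-block:: python

    import torch
    from monai.metrics import DiceMetric

    metric = DiceMetric(include_background=True, reduction="mean")
    y_pred = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()  # binarized prediction
    y = torch.rand(2, 1, 8, 8, 8)                           # not binarized: warns, no longer raises
    metric(y_pred=y_pred, y=y)
    print(metric.aggregate())
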
diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index 60dcd0b52d..7782c4c468 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -201,8 +201,7 @@ def extend(self, *data) -> None: self._buffers = [[] for _ in data] for b, d in zip(self._buffers, data): # converting to pytorch tensors so that we can use the distributed API - d_t: torch.Tensor - d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) # type: ignore + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) try: b.extend([x[0] for x in torch.split(d_t, 1, dim=0)]) except (AttributeError, IndexError, RuntimeError) as e: @@ -228,8 +227,7 @@ def append(self, *data) -> None: self._buffers = [[] for _ in data] for b, d in zip(self._buffers, data): # converting to pytorch tensors so that we can use the distributed API - d_t: torch.Tensor - d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) # type: ignore + d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True) b.append(d_t) self._synced = False diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index bd63134d6c..d5733eee97 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -45,7 +45,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def aggregate(self): # type: ignore + def aggregate(self): data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("the data to aggregate must be PyTorch Tensor.") diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 73f15534a9..e65d5ae4cb 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -49,7 +49,7 @@ def __init__(self, average: Union[Average, str] = Average.MACRO) -> None: def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore return y_pred, y - def aggregate(self): # type: ignore + def aggregate(self): """ As AUC metric needs to execute on the overall data, so usually users accumulate `y_pred` and `y` of every iteration, then execute real computation and reduction on the accumulated data. diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 04eed97a5d..fce4b735e5 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -78,7 +78,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if not torch.all(y_pred.byte() == y_pred): warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): - raise ValueError("y should be a binarized tensor.") + warnings.warn("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") @@ -91,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor distance_metric=self.distance_metric, ) - def aggregate(self): # type: ignore + def aggregate(self): """ Execute reduction logic for the output of `compute_average_surface_distance`. 
@@ -158,7 +158,7 @@ def compute_average_surface_distance( if surface_distance.shape == (0,): avg_surface_distance = np.nan else: - avg_surface_distance = surface_distance.mean() # type: ignore + avg_surface_distance = surface_distance.mean() if not symmetric: asd[b, c] = avg_surface_distance else: @@ -166,7 +166,7 @@ def compute_average_surface_distance( if surface_distance_2.shape == (0,): avg_surface_distance_2 = np.nan else: - avg_surface_distance_2 = surface_distance_2.mean() # type: ignore + avg_surface_distance_2 = surface_distance_2.mean() asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) return torch.from_numpy(asd) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index ccb6d93862..cf62cf9960 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -116,7 +116,7 @@ def get_mask_edges( The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using `label_idx`. - `scipy`'s binary erosion is used to to calculate the edges of the binary + `scipy`'s binary erosion is used to calculate the edges of the binary labelfield. In order to improve the computing efficiency, before getting the edges, diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 4e607dd298..76223dfaef 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -13,12 +13,14 @@ convert_to_torchscript, copy_model_state, eval_mode, + get_state_dict, icnr_init, normal_init, normalize_transform, one_hot, pixelshuffle, predict_segmentation, + save_state, slice_channels, to_norm_affine, train_mode, diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index d07d78f1ad..65b662ac32 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -98,4 +98,4 @@ def __init__( if item not in op_dict: raise ValueError(f"ordering must be a string of {op_dict}, got {item} in it.") if op_dict[item] is not None: - self.add_module(item, op_dict[item]) # type: ignore + self.add_module(item, op_dict[item]) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 063a1fded1..4c7263c6d5 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -80,7 +80,7 @@ def __init__( if self.pos_embed == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) - self.patch_dim = in_channels * np.prod(patch_size) + self.patch_dim = int(in_channels * np.prod(patch_size)) self.patch_embeddings: nn.Module if self.pos_embed == "conv": diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 4a86cd84bc..db92111d14 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -14,7 +14,7 @@ from monai.utils import optional_import -einops, _ = optional_import("einops") +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") class SABlock(nn.Module): @@ -43,17 +43,20 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False) + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = 
nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 def forward(self, x): - q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = einops.rearrange(x, "b h l d -> b l (h d)") + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index c72d1bc518..fa3929df20 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -219,7 +219,7 @@ def __init__( out_channels = out_channels or in_channels if not out_channels: raise ValueError("in_channels need to be specified.") - conv_out_channels = out_channels * (scale_factor ** self.dimensions) + conv_out_channels = out_channels * (scale_factor**self.dimensions) self.conv_block = Conv[Conv.CONV, self.dimensions]( in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias ) @@ -247,7 +247,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). """ x = self.conv_block(x) - if x.shape[1] % (self.scale_factor ** self.dimensions) != 0: + if x.shape[1] % (self.scale_factor**self.dimensions) != 0: raise ValueError( f"Number of channels after `conv_block` ({x.shape[1]}) must be evenly " "divisible by scale_factor ** dimensions " diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 9fdaab0a48..5b925258b6 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -150,7 +150,7 @@ def forward(self, dvf): Returns: a dense displacement field """ - ddf: torch.Tensor = dvf / (2 ** self.num_steps) + ddf: torch.Tensor = dvf / (2**self.num_steps) for _ in range(self.num_steps): ddf = ddf + self.warp_layer(image=ddf, ddf=ddf) return ddf diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index 5efb6e792f..1e9ce954e8 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -115,7 +115,7 @@ def gaussian_1d( out = out.clamp(min=0) elif approx.lower() == "sampled": x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device) - out = torch.exp(-0.5 / (sigma * sigma) * x ** 2) + out = torch.exp(-0.5 / (sigma * sigma) * x**2) if not normalize: # compute the normalizer out = out / (2.5066282 * sigma) elif approx.lower() == "scalespace": diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 7aa3e110fc..c1bb951c4d 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -70,13 +70,13 @@ def grid_pull( `bound` can be an int, a string or a BoundType. 
Possible values are:: - - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 0 or 'replicate' or 'nearest' or BoundType.replicate or 'border' - 1 or 'dct1' or 'mirror' or BoundType.dct1 - 2 or 'dct2' or 'reflect' or BoundType.dct2 - 3 or 'dst1' or 'antimirror' or BoundType.dst1 - 4 or 'dst2' or 'antireflect' or BoundType.dst2 - 5 or 'dft' or 'wrap' or BoundType.dft - - 7 or 'zero' or BoundType.zero + - 7 or 'zero' or 'zeros' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. @@ -115,7 +115,7 @@ def grid_pull( for i in ensure_tuple(interpolation) ] out: torch.Tensor - out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) # type: ignore + out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) return out diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 8fb2c269ab..52bd2fa994 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -148,6 +148,9 @@ class DenseNet(nn.Module): """ Densenet based on: `Densely Connected Convolutional Networks `_. Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16. + This network is non-determistic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below + for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms Args: spatial_dims: number of spatial dimensions of the input image. diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index c024d6e0f1..a4aaf32eed 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -124,7 +124,7 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. 
# in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) - self.ram_cost = in_channel / out_channel * 2 ** self._spatial_dims + 3 + self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3 class MixedOp(nn.Module): @@ -330,7 +330,7 @@ def __init__( # define downsample stems before DiNTS search if use_downsample: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -373,7 +373,7 @@ def __init__( else: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -789,7 +789,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False): image_size = np.array(in_size[-self._spatial_dims :]) sizes = [] for res_idx in range(self.num_depths): - sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2 ** res_idx)).prod()) + sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) @@ -807,7 +807,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False): * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) * sizes[self.arch_code2out[path_idx]] ) - return usage * 32 / 8 / 1024 ** 2 + return usage * 32 / 8 / 1024**2 def get_topology_entropy(self, probs): """ diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index 95c0c758af..891a65e67b 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -168,7 +168,7 @@ def __init__( # residual blocks for (idx, params) in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers. 
_in_chns, _out_chns = _out_chns, params["n_features"] - _dilation = 2 ** idx + _dilation = 2**idx for _ in range(params["repeat"]): blocks.append( HighResBlock( diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index 8524563faa..6776c7ce9e 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -92,7 +92,7 @@ def __init__( raise AssertionError self.encode_kernel_sizes: List[int] = encode_kernel_sizes - self.num_channels = [self.num_channel_initial * (2 ** d) for d in range(self.depth + 1)] + self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)] self.min_extract_level = min(self.extract_levels) # init layers @@ -310,14 +310,14 @@ def __init__( encode_kernel_sizes: Union[int, List[int]] = 3, ): for size in image_size: - if size % (2 ** depth) != 0: + if size % (2**depth) != 0: raise ValueError( f"given depth {depth}, " f"all input spatial dimension must be divisible by {2 ** depth}, " f"got input of size {image_size}" ) self.image_size = image_size - self.decode_size = [size // (2 ** depth) for size in image_size] + self.decode_size = [size // (2**depth) for size in image_size] super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index d2c45dd3a3..299f1ca811 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -102,7 +102,7 @@ def _make_down_layers(self): down_layers = nn.ModuleList() blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm) for i in range(len(blocks_down)): - layer_in_channels = filters * 2 ** i + layer_in_channels = filters * 2**i pre_conv = ( get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2) if i > 0 @@ -299,12 +299,12 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): if self.vae_estimate_std: z_sigma = self.vae_fc2(x_vae) z_sigma = F.softplus(z_sigma) - vae_reg_loss = 0.5 * torch.mean(z_mean ** 2 + z_sigma ** 2 - torch.log(1e-8 + z_sigma ** 2) - 1) + vae_reg_loss = 0.5 * torch.mean(z_mean**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1) x_vae = z_mean + z_sigma * z_mean_rand else: z_sigma = self.vae_default_std - vae_reg_loss = torch.mean(z_mean ** 2) + vae_reg_loss = torch.mean(z_mean**2) x_vae = z_mean + z_sigma * z_mean_rand diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index 7b3afe3e5e..cb9baa5e65 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -217,7 +217,7 @@ class MultiModal(BertPreTrainedModel): """ def __init__( - self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict # type: ignore + self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict ) -> None: """ Args: @@ -254,8 +254,8 @@ class Transchex(torch.nn.Module): def __init__( self, in_channels: int, - img_size: Union[Sequence[int], int], # type: ignore - patch_size: Union[int, Tuple[int, int]], # type: ignore + img_size: Union[Sequence[int], int], + patch_size: Union[int, Tuple[int, int]], num_classes: int, num_language_layers: int, num_vision_layers: int, @@ -352,10 +352,7 @@ def __init__( self.patch_size = patch_size self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1]) # type: ignore self.vision_proj = nn.Conv2d( - in_channels=in_channels, - out_channels=hidden_size, - 
kernel_size=self.patch_size, # type: ignore - stride=self.patch_size, # type: ignore + in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size ) self.norm_vision_pos = nn.LayerNorm(hidden_size) self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size)) diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 62e92603ab..a5f7963eca 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -25,6 +25,8 @@ class ViT(nn.Module): """ Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + ViT supports Torchscript but only works for Pytorch after 1.8. """ def __init__( @@ -99,7 +101,7 @@ def __init__( def forward(self, x): x = self.patch_embedding(x) - if self.classification: + if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] @@ -107,6 +109,6 @@ def forward(self, x): x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.classification: + if hasattr(self, "classification_head"): x = self.classification_head(x[:, 0]) return x, hidden_states_out diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a4ca0a6fd5..a6b0699107 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -20,8 +20,9 @@ import torch import torch.nn as nn +from monai.config import PathLike from monai.utils.deprecate_utils import deprecated, deprecated_arg -from monai.utils.misc import ensure_tuple, set_determinism +from monai.utils.misc import ensure_tuple, save_obj, set_determinism from monai.utils.module import pytorch_after __all__ = [ @@ -35,7 +36,9 @@ "pixelshuffle", "eval_mode", "train_mode", + "get_state_dict", "copy_model_state", + "save_state", "convert_to_torchscript", "meshgrid_ij", ] @@ -270,7 +273,7 @@ def pixelshuffle( dim, factor = spatial_dims, scale_factor input_size = list(x.size()) batch_size, channels = input_size[:2] - scale_divisor = factor ** dim + scale_divisor = factor**dim if channels % scale_divisor != 0: raise ValueError( @@ -357,6 +360,20 @@ def train_mode(*nets: nn.Module): n.eval() +def get_state_dict(obj: Union[torch.nn.Module, Mapping]): + """ + Get the state dict of input object if has `state_dict`, otherwise, return object directly. + For data parallel model, automatically convert it to regular model first. + + Args: + obj: input object to check and get the state_dict. + + """ + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + return obj.state_dict() if hasattr(obj, "state_dict") else obj # type: ignore + + def copy_model_state( dst: Union[torch.nn.Module, Mapping], src: Union[torch.nn.Module, Mapping], @@ -401,15 +418,10 @@ def copy_model_state( # Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys. 
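The `get_state_dict` helper introduced above unwraps `nn.DataParallel`/`DistributedDataParallel` before reading the state dict, and passes plain mappings through unchanged. A minimal sketch of the intended behaviour (the module below is only an illustration)::

    import torch
    import torch.nn as nn

    from monai.networks.utils import get_state_dict

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)  # construction also works without CUDA devices

    sd = get_state_dict(wrapped)        # same keys as net.state_dict(), no "module." prefix
    passthrough = get_state_dict({"weight": torch.zeros(2, 4)})  # mappings are returned as-is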
- """ - if isinstance(src, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - src = src.module - if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - dst = dst.module - src_dict = src.state_dict() if isinstance(src, torch.nn.Module) else src - dst_dict = dst.state_dict() if isinstance(dst, torch.nn.Module) else dst - dst_dict = OrderedDict(dst_dict) + """ + src_dict = get_state_dict(src) + dst_dict = OrderedDict(get_state_dict(dst)) to_skip = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)} @@ -436,6 +448,40 @@ def copy_model_state( return dst_dict, updated_keys, unchanged_keys +def save_state(src: Union[torch.nn.Module, Dict], path: PathLike, **kwargs): + """ + Save the state dict of input source data with PyTorch `save`. + It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`. + And automatically convert the data parallel module to regular module. + For example:: + + save_state(net, path) + save_state(net.state_dict(), path) + save_state({"net": net, "opt": opt}, path) + net_dp = torch.nn.DataParallel(net) + save_state(net_dp, path) + + Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html. + + Args: + src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`. + path: target file path to save the input object. + kwargs: other args for the `save_obj` except for the `obj` and `path`. + default `func` is `torch.save()`, details of the args of it: + https://pytorch.org/docs/stable/generated/torch.save.html. + + """ + + ckpt: Dict = {} + if isinstance(src, dict): + for k, v in src.items(): + ckpt[k] = get_state_dict(v) + else: + ckpt = get_state_dict(src) + + save_obj(obj=ckpt, path=path, **kwargs) + + def convert_to_torchscript( model: nn.Module, filename_or_obj: Optional[Any] = None, diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9b779ed18e..53f1009a76 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -201,7 +201,7 @@ ThresholdIntensityDict, ) from .inverse import InvertibleTransform, TraceableTransform -from .inverse_batch_transform import BatchInverseTransform, Decollated +from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .nvtx import ( @@ -249,6 +249,8 @@ AsDiscreted, AsDiscreteDict, Ensembled, + EnsembleD, + EnsembleDict, FillHolesD, FillHolesd, FillHolesDict, @@ -283,7 +285,17 @@ RandSmoothFieldAdjustIntensity, SmoothField, ) -from .smooth_field.dictionary import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from .smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothDeformD, + RandSmoothDeformDict, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustContrastD, + RandSmoothFieldAdjustContrastDict, + RandSmoothFieldAdjustIntensityd, + RandSmoothFieldAdjustIntensityD, + RandSmoothFieldAdjustIntensityDict, +) from .spatial.array import ( Affine, AffineGrid, @@ -306,6 +318,7 @@ Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from .spatial.dictionary import ( @@ -360,6 +373,9 @@ Spacingd, SpacingD, SpacingDict, + SpatialResampled, + SpatialResampleD, + SpatialResampleDict, Zoomd, ZoomD, ZoomDict, @@ -552,6 +568,7 @@ zero_margins, ) from .utils_pytorch_numpy_unification import ( + 
allclose, any_np_pt, ascontiguousarray, clip, @@ -562,11 +579,13 @@ isfinite, isnan, maximum, + mode, moveaxis, nonzero, percentile, ravel, repeat, + stack, unravel_index, where, ) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b4b0e2af43..bc55af0b15 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -107,6 +107,9 @@ class Compose(Randomizable, InvertibleTransform): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. """ @@ -115,12 +118,14 @@ def __init__( transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True, unpack_items: bool = False, + log_stats: bool = False, ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) self.map_items = map_items self.unpack_items = unpack_items + self.log_stats = log_stats self.set_random_state(seed=get_seed()) def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": @@ -165,7 +170,7 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) return input_ def inverse(self, data): @@ -175,7 +180,7 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) + data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats) return data @@ -192,6 +197,9 @@ class OneOf(Compose): defaults to `True`. unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. defaults to `False`. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. 
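The new `log_stats` flag is only forwarded to `apply_transform`, so that when a transform in the chain raises, the shape and value range of the offending data are logged before the error propagates. A small usage sketch (the transforms chosen here are arbitrary examples)::

    import numpy as np

    from monai.transforms import Compose, RandFlip, ScaleIntensity

    pipeline = Compose(
        [ScaleIntensity(), RandFlip(prob=0.5, spatial_axis=0)],
        log_stats=True,  # log data details if any transform fails
    )
    out = pipeline(np.zeros((1, 32, 32), dtype=np.float32))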
""" @@ -201,8 +209,9 @@ def __init__( weights: Optional[Union[Sequence[float], float]] = None, map_items: bool = True, unpack_items: bool = False, + log_stats: bool = False, ) -> None: - super().__init__(transforms, map_items, unpack_items) + super().__init__(transforms, map_items, unpack_items, log_stats) if len(self.transforms) == 0: weights = [] elif weights is None or isinstance(weights, float): @@ -243,7 +252,7 @@ def __call__(self, data): return data index = self.R.multinomial(1, self.weights).argmax() _transform = self.transforms[index] - data = apply_transform(_transform, data, self.map_items, self.unpack_items) + data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats) # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, Mapping): for key in data.keys(): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index faf5306ce0..199f185500 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -394,13 +394,13 @@ def __init__( data=roi_center, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True ) roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True) - _zeros = torch.zeros_like(roi_center) # type: ignore + _zeros = torch.zeros_like(roi_center) roi_start_torch = maximum(roi_center - floor_divide(roi_size, 2), _zeros) # type: ignore roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_torch, *_ = convert_data_type( # type: ignore + roi_start_torch, *_ = convert_data_type( data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True ) roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore @@ -1162,7 +1162,7 @@ def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str] If None, defaults to the ``mode`` in construction. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ - return self.padder(self.cropper(img), mode=mode) # type: ignore + return self.padder(self.cropper(img), mode=mode) class BoundingRect(Transform): diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 6edaf4622d..52d0c7be3b 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -45,8 +45,9 @@ class PadListDataCollate(InvertibleTransform): tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of different sizes. - This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added - to the list of invertible transforms. + This can be used on both list and dictionary data. + Note that in the case of the dictionary data, it may add the transform information to the list of invertible transforms + if input batch have different spatial shape, so need to call static method: `inverse` before inverting other transforms. Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`. This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. 
However, the diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d62d6fbdcb..c2033e3443 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -735,7 +735,7 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. @@ -933,7 +933,7 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. @@ -1075,7 +1075,7 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. allow_smaller: if `False`, an exception will be raised if the image is smaller than @@ -1279,7 +1279,7 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. 
allow_smaller: if `False`, an exception will be raised if the image is smaller than diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6657950eae..b5cf61ce0c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform @@ -106,7 +106,7 @@ def randomize(self, img: NdarrayOrTensor, mean: Optional[float] = None) -> None: rand_std = self.R.uniform(0, self.std) noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape) # noise is float64 array, convert to the output dtype to save memory - self.noise, *_ = convert_data_type(noise, dtype=self.dtype) # type: ignore + self.noise, *_ = convert_data_type(noise, dtype=self.dtype) def __call__(self, img: NdarrayOrTensor, mean: Optional[float] = None, randomize: bool = True) -> NdarrayOrTensor: """ @@ -182,9 +182,9 @@ def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float): if isinstance(img, torch.Tensor): n1 = torch.tensor(self._noise1, device=img.device) n2 = torch.tensor(self._noise2, device=img.device) - return torch.sqrt((img + n1) ** 2 + n2 ** 2) + return torch.sqrt((img + n1) ** 2 + n2**2) - return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) + return np.sqrt((img + self._noise1) ** 2 + self._noise2**2) def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ @@ -694,7 +694,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: img = self._normalize(img, self.subtrahend, self.divisor) - out, *_ = convert_data_type(img, dtype=dtype) + out = convert_to_dst_type(img, img, dtype=dtype)[0] return out @@ -779,7 +779,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = img * (self.b_max - self.b_min) + self.b_min if self.clip: img = clip(img, self.b_min, self.b_max) - ret, *_ = convert_data_type(img, dtype=dtype) + ret: NdarrayOrTensor = convert_data_type(img, dtype=dtype)[0] return ret @@ -1115,8 +1115,7 @@ def __call__(self, img: NdarrayOrTensor): np.ndarray containing envelope of data in img along the specified axis. 
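Several hunks above drop `# type: ignore` around `convert_data_type`, and `NormalizeIntensity` now restores the input container type via `convert_to_dst_type`. A rough sketch of the round-trip pattern these helpers implement (see `monai.utils.type_conversion` for the exact signatures)::

    import numpy as np
    import torch

    from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

    img = np.random.rand(1, 8, 8).astype(np.float32)

    # convert to torch.Tensor for computation; the original type/device are also returned
    img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)

    # convert the result back to match the input container (numpy here)
    out, *_ = convert_to_dst_type(src=img_t, dst=img)
    assert isinstance(out, np.ndarray)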
""" - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor) # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform @@ -1146,9 +1145,8 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er self.sigma = sigma self.approx = approx - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) sigma: Union[Sequence[torch.Tensor], torch.Tensor] if isinstance(self.sigma, Sequence): sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma] @@ -1255,9 +1253,8 @@ def __init__( self.alpha = alpha self.approx = approx - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) gf1, gf2 = ( GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) @@ -1402,8 +1399,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(img, np.ndarray) img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min @@ -1626,13 +1622,13 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # FT k = self.shift_fourier(img, n_dims) lib = np if isinstance(k, np.ndarray) else torch - log_abs = lib.log(lib.abs(k) + 1e-10) # type: ignore - phase = lib.angle(k) # type: ignore + log_abs = lib.log(lib.abs(k) + 1e-10) + phase = lib.angle(k) k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # type: ignore + k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # highlight if isinstance(self.loc[0], Sequence): @@ -1641,7 +1637,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore + k = lib.exp(log_abs) * lib.exp(1j * phase) img, *_ = convert_to_dst_type(self.inv_shift_fourier(k, n_dims), dst=img) return img @@ -1816,8 +1812,8 @@ def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]: k = self.shift_fourier(img, n_dims) mod = torch if isinstance(k, torch.Tensor) else np - log_abs = mod.log(mod.absolute(k) + 1e-10) # type: ignore - shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 # type: ignore + log_abs = mod.log(mod.absolute(k) + 1e-10) + shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 return tuple((i * 0.95, i * 1.1) for i in shifted_means) @@ -1892,8 +1888,7 @@ def 
__call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if not self._do_transform: return img - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(img, np.ndarray) out = self._transform_holes(img=img_np) ret, *_ = convert_to_dst_type(src=out, dst=img) return ret @@ -2048,12 +2043,11 @@ def __init__( self.dtype = dtype def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(img, np.ndarray) mask = mask if mask is not None else self.mask mask_np: Optional[np.ndarray] = None if mask is not None: - mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore + mask_np, *_ = convert_data_type(mask, np.ndarray) ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max) out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index c66f768738..b0f5149456 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -307,7 +307,7 @@ def __init__( the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to extract the factor value is `factor_key` is not None. allow_missing_keys: don't raise exception if key is missing. @@ -366,7 +366,7 @@ def __init__( the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to extract the factor value is `factor_key` is not None. prob: probability of rotating. 
@@ -1533,7 +1533,7 @@ def __call__(self, data): if first_key == []: return d - self.dropper.randomize(d[first_key].shape[1:]) # type: ignore + self.dropper.randomize(d[first_key].shape[1:]) for key in self.key_iterator(d): d[key] = self.dropper(img=d[key], randomize=False) @@ -1602,7 +1602,7 @@ def __call__(self, data): if first_key == []: return d - self.shuffle.randomize(d[first_key].shape[1:]) # type: ignore + self.shuffle.randomize(d[first_key].shape[1:]) for key in self.key_iterator(d): d[key] = self.shuffle(img=d[key], randomize=False) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index ae0317cea8..cc77a199dd 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -23,7 +23,7 @@ from monai.transforms.transform import MapTransform, Transform from monai.utils import first -__all__ = ["BatchInverseTransform", "Decollated"] +__all__ = ["BatchInverseTransform", "Decollated", "DecollateD", "DecollateDict"] class _BatchInverseDataset(Dataset): @@ -151,3 +151,6 @@ def __call__(self, data: Union[Dict, List]): d[key] = data[key] return decollate_batch(d, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) + + +DecollateD = DecollateDict = Decollated diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f8aa838439..f3715b2712 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -16,6 +16,7 @@ import inspect import logging import sys +import traceback import warnings from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -23,11 +24,12 @@ import numpy as np import torch -from monai.config import DtypeLike, PathLike +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data import image_writer +from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.nifti_saver import NiftiSaver -from monai.data.png_saver import PNGSaver from monai.transforms.transform import Transform +from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import @@ -82,7 +84,7 @@ class LoadImage(Transform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + (npz, npy -> NumpyReader), (DICOM file -> ITKReader). See also: @@ -90,7 +92,15 @@ class LoadImage(Transform): """ - def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.float32, *args, **kwargs) -> None: + def __init__( + self, + reader=None, + image_only: bool = False, + dtype: DtypeLike = np.float32, + ensure_channel_first: bool = False, + *args, + **kwargs, + ) -> None: """ Args: reader: reader to load image file and meta data @@ -103,6 +113,8 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. 
+ ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -112,7 +124,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. """ @@ -120,6 +132,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. self.auto_select = reader is None self.image_only = image_only self.dtype = dtype + self.ensure_channel_first = ensure_channel_first self.readers: List[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -184,7 +197,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option """ filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects - img = None + img, err = None, [] if reader is not None: img = reader.read(filename) # runtime specified reader else: @@ -197,99 +210,100 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option try: img = reader.read(filename) except Exception as e: - logging.getLogger(self.__class__.__name__).debug( - f"{reader.__class__.__name__}: unable to load {filename}.\n" f"Error: {e}" + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{reader.__class__.__name__}: unable to load {filename}.\n" ) else: + err = [] break if img is None or reader is None: if isinstance(filename, tuple) and len(filename) == 1: filename = filename[0] + msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( - f"cannot find a suitable reader for file: {filename}.\n" + f"{self.__class__.__name__} cannot find a suitable reader for file: {filename}.\n" " Please install the reader libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" - f" The current registered: {self.readers}.\n" + f" The current registered: {self.readers}.\n{msg}" ) + img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) img_array = img_array.astype(self.dtype, copy=False) + if not isinstance(meta_data, dict): + raise ValueError("`meta_data` must be a dict.") + # make sure all elements in metadata are little endian + meta_data = switch_endianness(meta_data, "<") + if self.ensure_channel_first: + img_array = EnsureChannelFirst()(img_array, meta_data) if self.image_only: return img_array meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - # make sure all elements in metadata are little endian - meta_data = switch_endianness(meta_data, "<") return img_array, meta_data class SaveImage(Transform): """ - Save transformed data into files, support NIfTI and PNG formats. 
- It can work for both numpy array and PyTorch Tensor in both preprocessing transform - chain and postprocessing transform chain. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided meta data dictionary. - If no meta data provided, use index from 0 as the filename prefix. - It can also save a list of PyTorch Tensor or numpy array without `batch dim`. + Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files. - Note: image should be channel-first shape: [C,H,W,[D]]. + The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, + where the `input_image_name` is extracted from the provided metadata dictionary. + If no metadata provided, a running index starting from 0 will be used as the filename prefix. Args: output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. - output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_ext: output file extension name. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. 
- it's used for NIfTI format only. - output_dtype: data type for saving data. Defaults to ``np.float32``. - it's used for NIfTI format only. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. - + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string of filename extension to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. + channel_dim: the index of the channel dimension. Default to `0`. + `None` to indicate no channel dimension. 
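With this rewrite, `SaveImage` resolves a writer from `monai.data.image_writer` (or uses a user-supplied one) instead of dispatching to `NiftiSaver`/`PNGSaver`. A minimal usage sketch of the new interface; the file name and metadata below are made up, and NIfTI output assumes `nibabel` is installed::

    import numpy as np

    from monai.transforms import SaveImage

    saver = SaveImage(
        output_dir="./out",
        output_postfix="seg",
        output_ext=".nii.gz",
        output_dtype=np.float32,
        separate_folder=True,
    )
    saver.set_options(data_kwargs={"squeeze_end_dims": True})  # optional per-writer tweaks

    img = np.zeros((1, 64, 64, 32), dtype=np.float32)  # channel-first [C,H,W,D]
    meta = {"filename_or_obj": "case01.nii.gz", "affine": np.eye(4)}
    saver(img, meta_data=meta)  # writes roughly ./out/case01/case01_seg.nii.gz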
""" def __init__( @@ -297,55 +311,98 @@ def __init__( output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", + output_dtype: DtypeLike = np.float32, resample: bool = True, mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, + output_format: str = "", + writer: Optional[image_writer.ImageWriter] = None, + channel_dim: Optional[int] = 0, ) -> None: - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in {".nii.gz", ".nii"}: - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - squeeze_end_dims=squeeze_end_dims, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - else: - raise ValueError(f"unsupported output extension: {output_ext}.") + self.folder_layout = FolderLayout( + output_dir=output_dir, + postfix=output_postfix, + extension=output_ext, + parent=separate_folder, + makedirs=True, + data_root_dir=data_root_dir, + ) + + self.output_ext = output_ext.lower() or output_format.lower() + self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,) + self.writer_obj = None + + _output_dtype = output_dtype + if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale} + self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": channel_dim} + self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype} + self.write_kwargs = {"verbose": print_log} + self._data_index = 0 + + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + """ + Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries. + + The arguments correspond to the following usage: + + - `writer = ImageWriter(**init_kwargs)` + - `writer.set_data_array(array, **data_kwargs)` + - `writer.set_metadata(meta_data, **meta_kwargs)` + - `writer.write(filename, **write_kwargs)` + + """ + if init_kwargs is not None: + self.init_kwargs.update(init_kwargs) + if data_kwargs is not None: + self.data_kwargs.update(data_kwargs) + if meta_kwargs is not None: + self.meta_kwargs.update(meta_kwargs) + if write_kwargs is not None: + self.write_kwargs.update(write_kwargs) def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: - img: target data content that save into file. - meta_data: key-value pairs of meta_data corresponding to the data. - + img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. 
+ meta_data: key-value pairs of metadata corresponding to the data. """ - self.saver.save(img, meta_data) - - return img + subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index) + if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape): + self.data_kwargs["channel_dim"] = None + + err = [] + for writer_cls in self.writers: + try: + writer_obj = writer_cls(**self.init_kwargs) + writer_obj.set_data_array(data_array=img, **self.data_kwargs) + writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) + writer_obj.write(filename, **self.write_kwargs) + self.writer_obj = writer_obj + except Exception as e: + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{writer_cls.__class__.__name__}: unable to write {filename}.\n" + ) + else: + self._data_index += 1 + return img + msg = "\n".join([f"{e}" for e in err]) + raise RuntimeError( + f"{self.__class__.__name__} cannot find a suitable writer for {filename}.\n" + " Please install the writer libraries, see also the installation instructions:\n" + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" + f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" + ) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index cc6a67593f..ebe06898a6 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -21,6 +21,7 @@ import numpy as np from monai.config import DtypeLike, KeysCollection +from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform @@ -48,13 +49,13 @@ class LoadImaged(MapTransform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (others -> ITKReader). + (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). Note: - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. See also: @@ -72,6 +73,7 @@ def __init__( meta_key_postfix: str = DEFAULT_POST_FIX, overwriting: bool = False, image_only: bool = False, + ensure_channel_first: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -84,7 +86,7 @@ def __init__( at runtime or use the default readers. If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", "PILReader", "ITKReader", "NumpyReader". - dtype: if not None convert the loaded image data to this data type. + dtype: if not None, convert the loaded image data to this data type. meta_keys: explicitly indicate the key to store the corresponding meta data dictionary. 
the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. @@ -92,16 +94,18 @@ def __init__( meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image, default is `meta_dict`. The meta data is a dictionary object. For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. + overwriting: whether allow overwriting existing meta data of same key. default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. + ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) @@ -150,68 +154,61 @@ class SaveImaged(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - meta_keys: explicitly indicate the key of the corresponding meta data dictionary. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None and `key_{postfix}` was used to store the metadata in `LoadImaged`. - need the key to extract metadata to save images, default is `meta_dict`. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, affine, original_shape, etc. - if no corresponding metadata, set to `None`. + meta_keys: explicitly indicate the key of the corresponding metadata dictionary. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + The metadata is a dictionary contains values such as filename, original_shape. + This argument can be a sequence of string, map to the `keys`. + If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict. output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. 
- mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are: - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. allow_missing_keys: don't raise exception if key is missing. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. 
for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. """ @@ -234,11 +231,13 @@ def __init__( data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, + output_format: str = "", + writer: Optional[image_writer.ImageWriter] = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self._saver = SaveImage( + self.saver = SaveImage( output_dir=output_dir, output_postfix=output_postfix, output_ext=output_ext, @@ -252,15 +251,20 @@ def __init__( data_root_dir=data_root_dir, separate_folder=separate_folder, print_log=print_log, + output_format=output_format, + writer=writer, ) + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs) + def __call__(self, data): d = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None - self._saver(img=d[key], meta_data=meta_data) + self.saver(img=d[key], meta_data=meta_data) return d diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 6a189d0bb4..dcd32c5d55 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -95,8 +95,7 @@ def __call__( raise TypeError(f"other must be None or callable but is {type(other).__name__}.") # convert to float as activation must operate on float tensor - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if sigmoid or self.sigmoid: img_t = torch.sigmoid(img_t) if softmax or self.softmax: @@ -232,8 +231,7 @@ def __call__( warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") threshold = 
logit_thresh if threshold else None - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor) if argmax or self.argmax: img_t = torch.argmax(img_t, dim=0, keepdim=True) @@ -496,8 +494,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ if not isinstance(img, (np.ndarray, torch.Tensor)): raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(img, np.ndarray) out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity) out, *_ = convert_to_dst_type(out_np, img) return out @@ -539,7 +536,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: ideally the edge should be thin enough, but now it has a thickness. """ - img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0] # type: ignore + img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0] spatial_dims = len(img_.shape) - 1 img_ = img_.unsqueeze(0) # adds a batch dim if spatial_dims == 2: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 96df434397..b89196207b 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -49,6 +49,8 @@ "AsDiscreteDict", "AsDiscreted", "Ensembled", + "EnsembleD", + "EnsembleDict", "FillHolesD", "FillHolesDict", "FillHolesd", @@ -515,28 +517,25 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]): class Invertd(MapTransform): """ Utility transform to automatically invert the previously applied transforms. - When applying preprocessing transforms on a orig_key(like: `image`, `label`, etc.), we record the context - information of applied transforms in a dictionary in the input data dictionary with the key - "{orig_key}_transforms". This transform will extract the transform context information of `orig_keys` - then invert the transforms(got from this context information) on the `keys` data. - Typical usage is to invert the preprocessing transforms(applied on input `image`) on the model `pred` data. - The output of the inverted data and metadata will be stored at `keys` and `meta_keys` respectively. - To correctly invert the transforms, the information of the previously applied transforms should be - available at `orig_keys`, and the original metadata at `orig_meta_keys`. - (`meta_key_postfix` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it + to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata + dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``. + + A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``. A detailed usage example is available in the tutorial: https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py Note: - According to the `collate_fn`, this transform may return a list of Tensor without batch dim, - thus some following transforms may not support a list of Tensor, and users can leverage the - `post_func` arg for basic processing logic. 
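A condensed sketch of the typical usage mentioned above, assuming standard MONAI preprocessing transforms; the key names such as ``pred`` are illustrative:

    from monai.transforms import Compose, EnsureChannelFirstd, Invertd, LoadImaged, Spacingd

    preproc = Compose([
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        Spacingd(keys="image", pixdim=(1.0, 1.0, 1.0)),
    ])
    # after inference, map the prediction back to the original image space
    post = Invertd(
        keys="pred",                     # data to invert
        transform=preproc,               # the transform whose inverse is applied
        orig_keys="image",               # trace info is read from "image_transforms"
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        nearest_interp=False,
    )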
- This transform needs to extract the context information of applied transforms and the meta data - dictionary from the input data dictionary, then use some numpy arrays in them to computes the inverse - logic, so please don't move `data["{orig_key}_transforms"]` and `data["{orig_meta_key}"]` to GPU device. + - The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively. + - To correctly invert the transforms, the information of the previously applied transforms should be + available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``. + (``meta_key_postfix`` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + see also: :py:class:`monai.transforms.TraceableTransform`. + - The transform will not change the content in ``orig_keys`` and ``orig_meta_key``. + These keys are only used to represent the data status of ``key`` before inverting. """ @@ -556,37 +555,32 @@ def __init__( ) -> None: """ Args: - keys: the key of expected data in the dict, invert transforms on it, in-place operation. - it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"]. - transform: the previous callable transform that applied on input data. - orig_keys: the key of the original input data in the dict. will get the applied transform information - for this input data, then invert them for the expected data with `keys`. - It can also be a list of keys, each matches to the `keys` data. - meta_keys: explicitly indicate the key for the inverted meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`. - orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta data will also be inverted and stored in `meta_keys`. - meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to to fetch the - meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. - default is `meta_dict`, the meta data is a dictionary object. - For example, to handle orig_key `image`, read/write `affine` matrices from the - metadata `image_meta_dict` dictionary's `affine` field. - the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". + keys: the key of the expected data in the dict; the inverse of ``transform`` will be applied to it in-place. + It can also be a list of keys, in which case the inverse transform is applied to each of them respectively. + transform: the transform applied to ``orig_key``; its inverse will be applied to ``key``. + orig_keys: the key of the original input data in the dict. + the transform trace information of ``transform`` should be stored at ``{orig_keys}_transforms``. + It can also be a list of keys, each matching an element of ``keys``. + meta_keys: the key under which to store the inverted metadata dictionary. + The metadata is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, mapping to ``keys``. + If None, will try to create a metadata dict with the default key: `{key}_{meta_key_postfix}`. + orig_meta_keys: the key of the metadata of the original input data.
+ The meta data is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, maps to the `keys`. + If None, will try to create a meta data dict with the default key: `{orig_key}_{meta_key_postfix}`. + This meta data dict will also be included in the inverted dict, stored in `meta_keys`. + meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the + meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``"meta_dict"``. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, - each matches to the `keys` data. + default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data. post_func: post processing for the inverted data, should be a callable function. - it also can be a list of callable, each matches to the `keys` data. + It also can be a list of callable, each matches to the `keys` data. allow_missing_keys: don't raise exception if key is missing. """ @@ -643,10 +637,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: input = d[key] if isinstance(input, torch.Tensor): input = input.detach() - # construct the input dict data for BatchInverseTransform + + # construct the input dict data input_dict = {orig_key: input, transform_key: transform_info} orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" - meta_key = meta_key or f"{key}_{meta_key_postfix}" if orig_meta_key in d: input_dict[orig_meta_key] = d[orig_meta_key] @@ -655,8 +649,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # save the inverted data d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key]) + # save the inverted meta dict if orig_meta_key in d: + meta_key = meta_key or f"{key}_{meta_key_postfix}" d[meta_key] = inverted.get(orig_meta_key) return d @@ -752,3 +748,4 @@ def get_saver(self): ProbNMSD = ProbNMSDict = ProbNMSd SaveClassificationD = SaveClassificationDict = SaveClassificationd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled +EnsembleD = EnsembleDict = Ensembled diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py index 356b0d167f..f581687ea5 100644 --- a/monai/transforms/smooth_field/array.py +++ b/monai/transforms/smooth_field/array.py @@ -120,7 +120,7 @@ def __call__(self, randomize=False) -> torch.Tensor: if self.spatial_zoom is not None: resized_field = interpolate( # type: ignore - input=field, # type: ignore + input=field, scale_factor=self.spatial_zoom, mode=look_up_option(self.mode, InterpolateMode).value, align_corners=self.align_corners, @@ -232,7 +232,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen # everything below here is to be computed using the destination type (numpy, tensor, etc.) 
img = (img - img_min) / (img_rng + 1e-10) # rescale to unit values - img = img ** rfield # contrast is changed by raising image data to a power, in this case the field + img = img**rfield # contrast is changed by raising image data to a power, in this case the field out = (img * img_rng) + img_min # rescale back to the original image value range diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py index c129d14f32..24890140cc 100644 --- a/monai/transforms/smooth_field/dictionary.py +++ b/monai/transforms/smooth_field/dictionary.py @@ -26,7 +26,17 @@ from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple_rep from monai.utils.enums import TransformBackends -__all__ = ["RandSmoothFieldAdjustContrastd", "RandSmoothFieldAdjustIntensityd", "RandSmoothDeformd"] +__all__ = [ + "RandSmoothFieldAdjustContrastd", + "RandSmoothFieldAdjustIntensityd", + "RandSmoothDeformd", + "RandSmoothFieldAdjustContrastD", + "RandSmoothFieldAdjustIntensityD", + "RandSmoothDeformD", + "RandSmoothFieldAdjustContrastDict", + "RandSmoothFieldAdjustIntensityDict", + "RandSmoothDeformDict", +] InterpolateModeType = Union[InterpolateMode, str] @@ -276,3 +286,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable d[key] = self.trans(d[key], False, self.trans.device) return d + + +RandSmoothDeformD = RandSmoothDeformDict = RandSmoothDeformd +RandSmoothFieldAdjustIntensityD = RandSmoothFieldAdjustIntensityDict = RandSmoothFieldAdjustIntensityd +RandSmoothFieldAdjustContrastD = RandSmoothFieldAdjustContrastDict = RandSmoothFieldAdjustContrastd diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 60e4f564c8..2ed2ae42c7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,9 +20,9 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine +from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.networks.utils import meshgrid_ij +from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( @@ -34,6 +34,7 @@ create_translate, map_spatial_axes, ) +from monai.transforms.utils_pytorch_numpy_unification import allclose, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -46,15 +47,17 @@ fall_back_tuple, issequenceiterable, optional_import, + pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -nib, _ = optional_import("nibabel") +nib, has_nib = optional_import("nibabel") __all__ = [ + "SpatialResample", "Spacing", "Orientation", "Flip", @@ -82,16 +85,198 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] +class SpatialResample(Transform): + """ + Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into + the ones specified by ``dst_affine`` affine matrix. 
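A rough sketch of driving the new transform, assuming it is re-exported from ``monai.transforms`` like the other spatial transforms; the affines below are made-up 4x4 matrices, and the returned pair is the resampled array plus the affine it now corresponds to:

    import numpy as np
    import torch
    from monai.transforms import SpatialResample

    img = torch.rand(1, 32, 32, 16)                  # channel-first volume
    src = np.diag([1.0, 1.0, 2.0, 1.0])              # current spacing: 2 mm along the last axis
    dst = np.eye(4)                                  # target: 1 mm isotropic
    out, applied_affine = SpatialResample(mode="bilinear")(img, src_affine=src, dst_affine=dst)
    print(out.shape)                                 # spatial size auto-computed from the two affines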
+ + Internally this transform computes the affine transform matrix from ``src_affine`` to ``dst_affine``, + by ``xform = linalg.solve(src_affine, dst_affine)``, and calls ``monai.transforms.Affine`` with ``xform``. + """ + + backend = [TransformBackends.TORCH] + + def __init__( + self, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Args: + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + """ + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + + def __call__( + self, + img: NdarrayOrTensor, + src_affine: Optional[NdarrayOrTensor] = None, + dst_affine: Optional[NdarrayOrTensor] = None, + spatial_size: Optional[Union[Sequence[int], np.ndarray, int]] = None, + mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, + dtype: DtypeLike = None, + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + """ + Args: + img: input image to be resampled. It currently supports channel-first arrays with + at most three spatial dimensions. + src_affine: source affine matrix. Defaults to ``None``, which means the identity matrix. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + dst_affine: destination affine matrix. Defaults to ``None``, which means the same as `src_affine`. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + when `dst_affine` and `spatial_size` are None, the input will be returned without resampling, + but the data type will be `float32`. + spatial_size: output image spatial size. + if `spatial_size` is not defined, + the transform will compute a spatial size automatically containing the previous field of view. + if `spatial_size` is ``-1``, the transform will use the corresponding input img size. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
+ See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype`` or + ``np.float64`` (for best precision). If ``None``, use the data type of input data. + To be compatible with other modules, the output data type is always `float32`. + + The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``. + + When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, + MONAI's resampling implementation will be used. + Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step. + """ + if src_affine is None: + src_affine = np.eye(4, dtype=np.float64) + spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: + spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size + src_affine = to_affine_nd(spatial_rank, src_affine) + dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine + dst_affine, *_ = convert_to_dst_type(dst_affine, dst_affine, dtype=torch.float32) + + in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) + if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size + spatial_size = in_spatial_size + elif spatial_size is None and spatial_rank > 1: # auto spatial size + spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore + spatial_size = np.asarray(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + + if ( + allclose(src_affine, dst_affine, atol=AFFINE_TOL) + and allclose(spatial_size, in_spatial_size) + or spatial_rank == 1 + ): + # no significant change, return original image + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst_affine + + if has_nib and isinstance(img, np.ndarray): + spatial_ornt, dst_r = reorient_spatial_axes(img.shape[1 : spatial_rank + 1], src_affine, dst_affine) + if allclose(dst_r, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): + # simple reorientation achieves the desired affine + spatial_ornt[:, 0] += 1 + spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + img_ = nib.orientations.apply_orientation(img, spatial_ornt) + output_data, *_ = convert_to_dst_type(img_, img, dtype=torch.float32) + return output_data, dst_affine + + try: + src_affine, *_ = convert_to_dst_type(src_affine, dst_affine) + if isinstance(src_affine, np.ndarray): + xform = np.linalg.solve(src_affine, dst_affine) + else: + xform = ( + torch.linalg.solve(src_affine, dst_affine) + if pytorch_after(1, 8, 0) + else torch.solve(dst_affine, src_affine).solution # type: ignore + ) + except (np.linalg.LinAlgError, RuntimeError) as e: + raise ValueError(f"src affine is not invertible: {src_affine}") from e + xform = to_affine_nd(spatial_rank, xform) + # no resampling if it's identity transform + if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL) and 
allclose(spatial_size, in_spatial_size): + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst_affine + + _dtype = dtype or self.dtype or img.dtype + in_spatial_size = in_spatial_size.tolist() + chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims + # resample + img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] + xform = convert_to_dst_type(xform, img_)[0] + align_corners = self.align_corners if align_corners is None else align_corners + mode = mode or self.mode + padding_mode = padding_mode or self.padding_mode + if additional_dims: + xform_shape = [-1] + in_spatial_size + img_ = img_.reshape(xform_shape) + if align_corners: + _t_r = torch.diag(torch.ones(len(xform), dtype=xform.dtype, device=xform.device)) # type: ignore + for idx, d_dst in enumerate(spatial_size[:spatial_rank]): + _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 + xform = xform @ _t_r + if not USE_COMPILED: + _t_l = normalize_transform( + in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore + ) + xform = _t_l @ xform # type: ignore + affine_xform = Affine( + affine=xform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype + ) + output_data = affine_xform(img_, mode=mode, padding_mode=padding_mode) + else: + affine_xform = AffineTransform( + normalized=False, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + reverse_indexing=True, + ) + output_data = affine_xform(img_.unsqueeze(0), theta=xform, spatial_size=spatial_size).squeeze(0) + if additional_dims: + full_shape = (chns, *spatial_size, *additional_dims) + output_data = output_data.reshape(full_shape) + # output dtype float + output_data, *_ = convert_to_dst_type(output_data, img, dtype=torch.float32) + return output_data, dst_affine + + class Spacing(Transform): """ Resample input image into the specified `pixdim`. """ - backend = [TransformBackends.TORCH] + backend = SpatialResample.backend def __init__( self, - pixdim: Union[Sequence[float], float], + pixdim: Union[Sequence[float], float, np.ndarray], diagonal: bool = False, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -122,6 +307,9 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. 
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html @@ -135,12 +323,15 @@ def __init__( """ self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.diagonal = diagonal - self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) - self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.align_corners = align_corners - self.dtype = dtype self.image_only = image_only + self.sp_resample = SpatialResample( + mode=look_up_option(mode, GridSampleMode), + padding_mode=look_up_option(padding_mode, GridSamplePadMode), + align_corners=align_corners, + dtype=dtype, + ) + def __call__( self, data_array: NdarrayOrTensor, @@ -149,7 +340,7 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - output_spatial_shape: Optional[np.ndarray] = None, + output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ Args: @@ -158,6 +349,9 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html @@ -178,7 +372,6 @@ def __call__( data_array (resampled into `self.pixdim`), original affine, current affine. 
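For reference, a small sketch of the refactored ``Spacing`` call, which now delegates to ``SpatialResample`` internally; the input affine is illustrative:

    import numpy as np
    import torch
    from monai.transforms import Spacing

    spacing = Spacing(pixdim=(1.0, 1.0, 1.0), mode="bilinear")
    img = torch.rand(1, 64, 64, 32)
    affine = np.diag([1.0, 1.0, 2.5, 1.0])           # illustrative 2.5 mm slice spacing
    new_img, old_affine, new_affine = spacing(img, affine=affine)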
""" - _dtype = dtype or self.dtype or data_array.dtype sr = int(data_array.ndim - 1) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") @@ -187,7 +380,7 @@ def __call__( affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_np, *_ = convert_data_type(affine, np.ndarray) # type: ignore + affine_np, *_ = convert_data_type(affine, np.ndarray) affine_ = to_affine_nd(sr, affine_np) out_d = self.pixdim[:sr] @@ -198,33 +391,17 @@ def __call__( new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] - transform = np.linalg.inv(affine_) @ new_affine - # adapt to the actual rank - transform = to_affine_nd(sr, transform) - - # no resampling if it's identity transform - if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array - else: - # resample - affine_xform = AffineTransform( - normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, - ) - data_array_t: torch.Tensor - data_array_t, *_ = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore - output_data = affine_xform( - # AffineTransform requires a batch dim - data_array_t.unsqueeze(0), - convert_data_type(transform, torch.Tensor, data_array_t.device, dtype=_dtype)[0], - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ).squeeze(0) - - output_data, *_ = convert_to_dst_type(output_data, data_array, dtype=torch.float32) - new_affine = to_affine_nd(affine_np, new_affine) # type: ignore + output_data, new_affine = self.sp_resample( + data_array, + src_affine=affine, + dst_affine=new_affine, + spatial_size=list(output_shape) if output_spatial_shape is None else output_spatial_shape, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + new_affine = to_affine_nd(affine_np, new_affine) new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) if self.image_only: @@ -297,12 +474,13 @@ def __call__( sr = len(spatial_shape) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") + affine_: np.ndarray if affine is None: # default to identity affine_np = affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_np, *_ = convert_data_type(affine, np.ndarray) # type: ignore + affine_np, *_ = convert_data_type(affine, np.ndarray) affine_ = to_affine_nd(sr, affine_np) src = nib.io_orientation(affine_) @@ -311,6 +489,12 @@ def __call__( else: if self.axcodes is None: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") + if sr < len(self.axcodes): + warnings.warn( + f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" + f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," + "please make sure the input is in the channel-first format." 
+ ) dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) if len(dst) < sr: raise ValueError( @@ -333,7 +517,7 @@ def __call__( data_array = data_array.permute(full_transpose.tolist()) # type: ignore else: data_array = data_array.transpose(full_transpose) # type: ignore - out, *_ = convert_to_dst_type(src=data_array, dst=data_array) # type: ignore + out, *_ = convert_to_dst_type(src=data_array, dst=data_array) new_affine = to_affine_nd(affine_np, new_affine) new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) @@ -429,7 +613,7 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore + img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) if self.size_mode == "all": input_ndim = img_.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) @@ -448,8 +632,8 @@ def __call__( raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - resized = torch.nn.functional.interpolate( # type: ignore - input=img_.unsqueeze(0), # type: ignore + resized = torch.nn.functional.interpolate( + input=img_.unsqueeze(0), size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, @@ -475,7 +659,7 @@ class Rotate(Transform, ThreadUnsafe): See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. 
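A brief sketch of the array transform with the new ``np.float32`` default shown above; the input is a random 2D image and the angle is illustrative:

    import numpy as np
    import torch
    from monai.transforms import Rotate

    rotate = Rotate(angle=np.pi / 6, keep_size=True)   # resampling is now computed in float32 by default
    rotated = rotate(torch.rand(1, 64, 64))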
""" @@ -489,7 +673,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> None: self.angle = angle self.keep_size = keep_size @@ -530,8 +714,7 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) @@ -551,8 +734,7 @@ def __call__( shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 - transform_t: torch.Tensor - transform_t, *_ = convert_to_dst_type(transform, img_t) # type: ignore + transform_t, *_ = convert_to_dst_type(transform, img_t) xform = AffineTransform( normalized=False, @@ -649,8 +831,7 @@ def __call__( See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ - img_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( # type: ignore @@ -784,7 +965,7 @@ class RandRotate(RandomizableTransform): See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. """ @@ -801,7 +982,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: Union[DtypeLike, torch.dtype] = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -1085,6 +1266,9 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). + device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. 
@@ -1105,6 +1289,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params @@ -1112,6 +1297,7 @@ def __init__( self.translate_params = translate_params self.scale_params = scale_params self.device = device + self.dtype = dtype self.affine = affine def __call__( @@ -1130,12 +1316,10 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if grid is None: - if spatial_size is not None: - grid = create_grid(spatial_size, device=self.device, backend="torch") - else: + if grid is None: # create grid from spatial_size + if spatial_size is None: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - + grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY _device = grid.device if isinstance(grid, torch.Tensor) else self.device affine: NdarrayOrTensor @@ -1157,7 +1341,7 @@ def __call__( else: affine = self.affine - grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=float) + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=self.dtype or grid.dtype) affine, *_ = convert_to_dst_type(affine, grid) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) @@ -1339,7 +1523,9 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, as_tensor_output: bool = True, + norm_coords: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float64, ) -> None: """ computes output image using values from `img`, locations from `grid` using pytorch. @@ -1352,7 +1538,17 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` (for ``monai/csrc`` implementation) or + `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying + resampling API. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. .. deprecated:: 0.6.0 ``as_tensor_output`` is deprecated. 
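A minimal sketch of feeding ``Resample`` a grid built by ``monai.transforms.utils.create_grid``, which uses the `[-(size-1)/2, (size-1)/2]` coordinate convention that ``norm_coords=True`` expects:

    import torch
    from monai.transforms import Resample
    from monai.transforms.utils import create_grid

    img = torch.rand(1, 32, 32)
    grid = create_grid((32, 32), backend="torch")      # identity grid with centred coordinates
    out = Resample(norm_coords=True, padding_mode="border")(img, grid=grid)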
@@ -1360,7 +1556,9 @@ def __init__( """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) + self.norm_coords = norm_coords self.device = device + self.dtype = dtype def __call__( self, @@ -1368,55 +1566,65 @@ def __call__( grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + dtype: DtypeLike = None, ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]). grid: shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + if ``norm_coords`` is True, the grid values must be in `[-(size-1)/2, (size-1)/2]`. + if ``USE_COMPILED=True`` and ``norm_coords=False``, grid values must be in `[0, size-1]`. + if ``USE_COMPILED=False`` and ``norm_coords=False``, grid values must be in `[-1, 1]`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + To be compatible with other modules, the output data type is always `float32`. + + See also: + :py:const:`monai.config.USE_COMPILED` """ if grid is None: raise ValueError("Unknown grid.") _device = img.device if isinstance(img, torch.Tensor) else self.device - img_t: torch.Tensor - grid_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=torch.float32) # type: ignore - grid_t, *_ = convert_to_dst_type(grid, img_t) # type: ignore + _dtype = dtype or self.dtype or img.dtype + img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=_dtype) + grid_t = convert_to_dst_type(grid, img_t)[0] + if grid_t is grid: # copy if needed (convert_data_type converts to contiguous) + grid_t = grid_t.clone(memory_format=torch.contiguous_format) + sr = min(len(img_t.shape[1:]), 3) if USE_COMPILED: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] += (dim - 1.0) / 2.0 - grid_t = grid_t[:-1] / grid_t[-1:] - grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) - _padding_mode = look_up_option( - self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode - ).value - if _padding_mode == "zeros": - bound = 7 - elif _padding_mode == "border": - bound = 0 + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] + grid_t = moveaxis(grid_t[:sr], 0, -1) # type: ignore + _padding_mode = self.padding_mode if padding_mode is None else padding_mode + _padding_mode = _padding_mode.value if isinstance(_padding_mode, GridSamplePadMode) else _padding_mode + bound = 1 if _padding_mode == "reflection" else _padding_mode + _interp_mode = self.mode if mode is None else mode + _interp_mode = _interp_mode.value if isinstance(_interp_mode, GridSampleMode) else _interp_mode + if _interp_mode == "bicubic": + interp = 3 + elif _interp_mode == "bilinear": + interp = 1 else: - 
bound = 1 - _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value + interp = _interp_mode # type: ignore out = grid_pull( - img_t.unsqueeze(0), - grid_t.unsqueeze(0), - bound=bound, - extrapolate=True, - interpolation=1 if _interp_mode == "bilinear" else _interp_mode, + img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=interp )[0] else: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] = 2.0 * grid_t[i] / (dim - 1.0) - grid_t = grid_t[:-1] / grid_t[-1:] - index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) - grid_t = grid_t[index_ordering] - grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = 2.0 / (max(2, dim) - 1.0) * grid_t[i] / grid_t[-1:] + index_ordering: List[int] = list(range(sr - 1, -1, -1)) + grid_t = moveaxis(grid_t[index_ordering], 0, -1) # type: ignore out = torch.nn.functional.grid_sample( img_t.unsqueeze(0), grid_t.unsqueeze(0), @@ -1424,8 +1632,7 @@ def __call__( padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - out_val: NdarrayOrTensor - out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype) + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32) return out_val @@ -1449,8 +1656,10 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, + norm_coords: bool = True, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, image_only: bool = False, ) -> None: """ @@ -1485,10 +1694,21 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` or `[-1, 1]` to be compatible with the underlying resampling API. + If the coordinates are generated by ``monai.transforms.utils.create_grid`` + and the ``affine`` doesn't include the normalization, this argument should be set to ``True``. + If the output `self.affine_grid` is already normalized, this argument should be set to ``False``. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. image_only: if True return only the image volume, otherwise return (image, affine). .. 
deprecated:: 0.6.0 @@ -1501,10 +1721,11 @@ def __init__( translate_params=translate_params, scale_params=scale_params, affine=affine, + dtype=dtype, device=device, ) self.image_only = image_only - self.resampler = Resample(device=device) + self.resampler = Resample(norm_coords=norm_coords, device=device, dtype=dtype) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) @@ -1527,6 +1748,9 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 963015f420..eaff3be35d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -24,6 +24,7 @@ from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data.utils import affine_to_spacing from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad @@ -46,6 +47,7 @@ Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from monai.transforms.transform import MapTransform, RandomizableTransform @@ -68,6 +70,7 @@ nib, _ = optional_import("nibabel") __all__ = [ + "SpatialResampled", "Spacingd", "Orientationd", "Rotate90d", @@ -86,6 +89,8 @@ "RandRotated", "Zoomd", "RandZoomd", + "SpatialResampleD", + "SpatialResampleDict", "SpacingD", "SpacingDict", "OrientationD", @@ -131,6 +136,160 @@ DEFAULT_POST_FIX = PostFix.meta() +class SpatialResampled(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. + + This transform assumes the ``data`` dictionary has a key for the input + data's metadata and contains ``src_affine`` and ``dst_affine`` required by + `SpatialResample`. The key is formed by ``key_{meta_key_postfix}``. The + transform will swap ``src_affine`` and ``dst_affine`` affine (with potential data type + changes) in the dictionary so that ``src_affine`` always refers to the current + status of affine. + + See also: + :py:class:`monai.transforms.SpatialResample` + """ + + backend = SpatialResample.backend + + def __init__( + self, + keys: KeysCollection, + mode: GridSampleModeSequence = GridSampleMode.BILINEAR, + padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, + align_corners: Union[Sequence[bool], bool] = False, + dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = DEFAULT_POST_FIX, + meta_src_keys: Optional[KeysCollection] = "src_affine", + meta_dst_keys: Optional[KeysCollection] = "dst_affine", + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. 
+ mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, affine, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys=None, use `key_{postfix}` to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + meta_src_keys: the key of the corresponding ``src_affine`` in the metadata dictionary. + meta_dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary. + allow_missing_keys: don't raise exception if key is missing. 
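A rough sketch of the dictionary-based wrapper under the metadata convention documented above, assuming the class is re-exported from ``monai.transforms``; the affines and key names are illustrative and, as noted, ``src_affine``/``dst_affine`` are swapped in the metadata after the call:

    import numpy as np
    import torch
    from monai.transforms import SpatialResampled

    data = {
        "image": torch.rand(1, 32, 32, 16),
        "image_meta_dict": {
            "src_affine": np.diag([1.0, 1.0, 2.0, 1.0]),   # current affine of the array
            "dst_affine": np.eye(4),                       # affine to resample into
        },
    }
    data = SpatialResampled(keys="image", mode="bilinear")(data)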
+ """ + super().__init__(keys, allow_missing_keys) + self.sp_transform = SpatialResample() + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.meta_src_keys = ensure_tuple_rep(meta_src_keys, len(self.keys)) + self.meta_dst_keys = ensure_tuple_rep(meta_dst_keys, len(self.keys)) + + def __call__( + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: + d: Dict = dict(data) + for (key, mode, padding_mode, align_corners, dtype, *metakeyinfo) in self.key_iterator( + d, + self.mode, + self.padding_mode, + self.align_corners, + self.dtype, + self.meta_keys, + self.meta_key_postfix, + self.meta_src_keys, + self.meta_dst_keys, + ): + meta_key, meta_key_postfix, meta_src_key, meta_dst_key = metakeyinfo + meta_key = meta_key or f"{key}_{meta_key_postfix}" + # create metadata if necessary + if meta_key not in d: + d[meta_key] = {meta_src_key: None, meta_dst_key: None} + meta_data = d[meta_key] + original_spatial_shape = d[key].shape[1:] + d[key], meta_data[meta_dst_key] = self.sp_transform( # write dst affine because the dtype might change + img=d[key], + src_affine=meta_data[meta_src_key], + dst_affine=meta_data[meta_dst_key], + spatial_size=None, # None means shape auto inferred + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + meta_data[meta_dst_key], meta_data[meta_src_key] = meta_data[meta_src_key], meta_data[meta_dst_key] + self.push_transform( + d, + key, + extra_info={ + "meta_key": meta_key, + "meta_src_key": meta_src_key, + "meta_dst_key": meta_dst_key, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + }, + orig_size=original_spatial_shape, + ) + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key, dtype in self.key_iterator(d, self.dtype): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + src_key = transform[TraceKeys.EXTRA_INFO]["meta_src_key"] + dst_key = transform[TraceKeys.EXTRA_INFO]["meta_dst_key"] + src_affine = meta_data[src_key] + dst_affine = meta_data[dst_key] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + inverse_transform = SpatialResample() + # Apply inverse + d[key], dst_affine = inverse_transform( + img=d[key], + src_affine=src_affine, + dst_affine=dst_affine, + mode=mode, + padding_mode=padding_mode, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, + dtype=dtype, + spatial_size=orig_size, + ) + meta_data[src_key], meta_data[dst_key] = dst_affine, meta_data[src_key] # type: ignore + # Remove the applied transform + 
self.pop_transform(d, key) + return d + + class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -155,7 +314,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64, + dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = DEFAULT_POST_FIX, allow_missing_keys: bool = False, @@ -201,7 +360,7 @@ def __init__( the meta data is a dictionary object which contains: filename, affine, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys=None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys=None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -277,7 +436,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] orig_size = transform[TraceKeys.ORIG_SIZE] - orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + orig_pixdim = affine_to_spacing(old_affine, -1) inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( @@ -305,6 +464,10 @@ class Orientationd(MapTransform, InvertibleTransform): After reorienting the input array, this transform will write the new affine to the `affine` field of metadata which is formed by ``key_{meta_key_postfix}``. + + This transform assumes the channel-first input format. + In the case of using this transform for normalizing the orientations of images, + it should be used before any anisotropic spatial transforms. """ backend = Orientation.backend @@ -335,7 +498,7 @@ def __init__( the meta data is a dictionary object which contains: filename, affine, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. @@ -607,6 +770,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: Union[DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -646,6 +810,9 @@ def __init__( See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. 
+ If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. allow_missing_keys: don't raise exception if key is missing. See also: @@ -665,6 +832,7 @@ def __init__( affine=affine, spatial_size=spatial_size, device=device, + dtype=dtype, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -699,7 +867,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # Apply inverse transform d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) @@ -833,7 +1001,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors grid = self.rand_affine.rand_affine_grid(grid=grid) - affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() # type: ignore[assignment] + affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): self.push_transform( @@ -866,7 +1034,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # Apply inverse transform d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) @@ -1320,7 +1488,7 @@ class Rotated(MapTransform, InvertibleTransform): align_corners: Defaults to False. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. - dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + dtype: data type for resampling computation. Defaults to ``np.float32``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``. 
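As a brief illustration of the dictionary-based spatial transforms touched above, a minimal usage sketch follows (the keys, shapes and identity affine are illustrative assumptions, not part of this patch):

    import numpy as np
    from monai.transforms import Compose, Orientationd, Spacingd

    keys = ["image"]
    pipeline = Compose([
        Orientationd(keys=keys, axcodes="RAS"),  # normalize orientation before anisotropic resampling
        Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode="bilinear", dtype=np.float32),
    ])
    data = {
        "image": np.random.rand(1, 16, 16, 16).astype(np.float32),  # channel-first image
        "image_meta_dict": {"affine": np.eye(4)},  # affine is read/written via `key_{meta_key_postfix}`
    }
    out = pipeline(data)  # resampled image; the updated affine is written back to `image_meta_dict`
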
@@ -1337,7 +1505,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1389,10 +1557,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - img_t: torch.Tensor - img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore - transform_t: torch.Tensor - transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) @@ -1451,7 +1617,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1523,10 +1689,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) - img_t: torch.Tensor - img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore - transform_t: torch.Tensor - transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) output: torch.Tensor out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) @@ -1622,7 +1786,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # Remove the applied transform self.pop_transform(d, key) @@ -1746,7 +1910,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # Remove the applied transform self.pop_transform(d, key) @@ -1870,6 +2034,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +SpatialResampleD = SpatialResampleDict = SpatialResampled SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = 
Orientationd Rotate90D = Rotate90Dict = Rotate90d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 430e659c95..8537f7eb89 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -54,7 +54,11 @@ def _apply_transform( def apply_transform( - transform: Callable[..., ReturnType], data: Any, map_items: bool = True, unpack_items: bool = False + transform: Callable[..., ReturnType], + data: Any, + map_items: bool = True, + unpack_items: bool = False, + log_stats: bool = False, ) -> Union[List[ReturnType], ReturnType]: """ Transform `data` with `transform`. @@ -69,6 +73,9 @@ def apply_transform( map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. unpack_items: whether to unpack parameters using `*`. Defaults to False. + log_stats: whether to log the detailed information of data and applied transform when error happened, + for NumPy array and PyTorch Tensor, log the data shape and value range, + for other meta data, log the values directly. default to `False`. Raises: Exception: When ``transform`` raises an exception. @@ -82,7 +89,7 @@ def apply_transform( return _apply_transform(transform, data, unpack_items) except Exception as e: - if not isinstance(transform, transforms.compose.Compose): + if log_stats and not isinstance(transform, transforms.compose.Compose): # log the input data information of exact transform in the transform chain datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) logger = logging.getLogger(datastats._logger_name) @@ -93,7 +100,7 @@ def apply_transform( def _log_stats(data, prefix: Optional[str] = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) # type: ignore + datastats(img=data, data_shape=True, value_range=True, prefix=prefix) else: # log data type and value for other meta data datastats(img=data, data_value=True, prefix=prefix) @@ -360,7 +367,10 @@ def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Ite if key in data: yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: - raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + raise KeyError( + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" + " and allow_missing_keys==False." + ) def first_key(self, data: Dict[Hashable, Any]): """ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 664433270b..0100c33719 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -325,7 +325,7 @@ def __init__(self, dtype=np.float32) -> None: """ self.dtype = dtype - def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, dtype: Union[DtypeLike, torch.dtype] = None) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -336,8 +336,7 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. 
""" - img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) - return img_out + return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0] # type: ignore class ToTensor(Transform): @@ -412,6 +411,7 @@ def __call__(self, data: NdarrayOrTensor): """ output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray + out: NdarrayOrTensor out, *_ = convert_data_type( data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence ) @@ -547,6 +547,11 @@ class DataStats(Transform): It can be inserted into any place of a transform chain and check results of previous transforms. It support both `numpy.ndarray` and `torch.tensor` as input data, so it can be used in pre-processing and post-processing. + + It gets logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`. + If the log level of `logging.RootLogger` is higher than `INFO`, will add a separate `StreamHandler` + log handler with `INFO` level and record to `stdout`. + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -559,7 +564,7 @@ def __init__( value_range: bool = True, data_value: bool = False, additional_info: Optional[Callable] = None, - logger_handler: Optional[logging.Handler] = None, + name: str = "DataStats", ) -> None: """ Args: @@ -570,9 +575,7 @@ def __init__( data_value: whether to show the raw value of input data. a typical example is to print some properties of Nifti image: affine, pixdim, etc. additional_info: user can define callable function to extract additional info from input data. - logger_handler: add additional handler to output data: save to file, etc. - all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. - the handler should have a logging level of at least `INFO`. + name: identifier of `logging.logger` to use, defaulting to "DataStats". Raises: TypeError: When ``additional_info`` is not an ``Optional[Callable]``. @@ -588,14 +591,14 @@ def __init__( if additional_info is not None and not callable(additional_info): raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.") self.additional_info = additional_info - self._logger_name = "DataStats" + self._logger_name = name _logger = logging.getLogger(self._logger_name) _logger.setLevel(logging.INFO) - console = logging.StreamHandler(sys.stdout) # always stdout - console.setLevel(logging.INFO) - _logger.addHandler(console) - if logger_handler is not None: - _logger.addHandler(logger_handler) + if logging.root.getEffectiveLevel() > logging.INFO: + # if the root log level is higher than INFO, set a separate stream handler to record + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + _logger.addHandler(console) def __call__( self, @@ -1021,7 +1024,7 @@ def __call__(self, img: NdarrayOrTensor): img: PyTorch Tensor data for the TorchVision transform. """ - img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor) out = self.trans(img_t) out, *_ = convert_to_dst_type(src=out, dst=img) return out @@ -1113,8 +1116,7 @@ def __call__( mask must have the same shape as input `img`. 
""" - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np, *_ = convert_data_type(img, np.ndarray) if meta_data is None: meta_data = {} diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e61aa5f512..ecf9aaffa4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -15,7 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import logging import re from copy import deepcopy from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union @@ -520,7 +519,7 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + dtype: Union[DtypeLike, torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True, allow_missing_keys: bool = False, @@ -775,7 +774,7 @@ def __init__( value_range: Union[Sequence[bool], bool] = True, data_value: Union[Sequence[bool], bool] = False, additional_info: Optional[Union[Sequence[Callable], Callable]] = None, - logger_handler: Optional[logging.Handler] = None, + name: str = "DataStats", allow_missing_keys: bool = False, ) -> None: """ @@ -796,9 +795,7 @@ def __init__( additional_info: user can define callable function to extract additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. - logger_handler: add additional handler to output data: save to file, etc. - all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. - the handler should have a logging level of at least `INFO`. + name: identifier of `logging.logger` to use, defaulting to "DataStats". allow_missing_keys: don't raise exception if key is missing. """ @@ -809,8 +806,7 @@ def __init__( self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) - self.logger_handler = logger_handler - self.printer = DataStats(logger_handler=logger_handler) + self.printer = DataStats(name=name) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1449,7 +1445,7 @@ class IntensityStatsd(MapTransform): the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to store the computed statistics to the meta dict. allow_missing_keys: don't raise exception if key is missing. 
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8265cd2a72..d222810bbe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -21,7 +21,7 @@ import monai from monai.config import DtypeLike, IndexSelection -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose, OneOf @@ -48,6 +48,7 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + get_equivalent_dtype, issequenceiterable, look_up_option, min_version, @@ -60,6 +61,7 @@ ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") +cucim, has_cucim = optional_import("cucim") exposure, has_skimage = optional_import("skimage.exposure") __all__ = [ @@ -155,7 +157,7 @@ def rescale_array( arr: NdarrayOrTensor, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, - dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, + dtype: Union[DtypeLike, torch.dtype] = np.float32, ) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. @@ -358,7 +360,7 @@ def map_classes_to_indices( label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0)) label_flat = img_flat & label_flat if img_flat is not None else label_flat # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices - cls_indices, *_ = convert_data_type(nonzero(label_flat), device=torch.device("cpu")) + cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), device=torch.device("cpu"))[0] indices.append(cls_indices) return indices @@ -402,7 +404,7 @@ def weighted_patch_samples( if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) else: - r, *_ = convert_to_dst_type(r_state.random(n_samples), v) # type: ignore + r, *_ = convert_to_dst_type(r_state.random(n_samples), v) idx = searchsorted(v, r * v[-1], right=True) # type: ignore idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore # compensate 'valid' mode @@ -569,7 +571,7 @@ def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype=float, + dtype: Union[DtypeLike, torch.dtype] = float, device: Optional[torch.device] = None, backend=TransformBackends.NUMPY, ): @@ -580,16 +582,17 @@ def create_grid( spatial_size: spatial size of the grid. spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid). homogeneous: whether to make homogeneous coordinates. - dtype: output grid data type. + dtype: output grid data type, defaults to `float`. device: device to compute and store the output (when the backend is "torch"). backend: APIs to use, ``numpy`` or ``torch``. 
""" _backend = look_up_option(backend, TransformBackends) + _dtype = dtype or float if _backend == TransformBackends.NUMPY: - return _create_grid_numpy(spatial_size, spacing, homogeneous, dtype) + return _create_grid_numpy(spatial_size, spacing, homogeneous, _dtype) if _backend == TransformBackends.TORCH: - return _create_grid_torch(spatial_size, spacing, homogeneous, dtype, device) + return _create_grid_torch(spatial_size, spacing, homogeneous, _dtype, device) raise ValueError(f"backend {backend} is not supported") @@ -597,14 +600,14 @@ def _create_grid_numpy( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, homogeneous: bool = True, - dtype: DtypeLike = float, + dtype: Union[DtypeLike, torch.dtype] = float, ): """ compute a `spatial_size` mesh with the numpy API. """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)] - coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=dtype) + coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=get_equivalent_dtype(dtype, np.ndarray)) if not homogeneous: return coords return np.concatenate([coords, np.ones_like(coords[:1])]) @@ -622,7 +625,13 @@ def _create_grid_torch( """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [ - torch.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d), device=device, dtype=dtype) + torch.linspace( + -(d - 1.0) / 2.0 * s, + (d - 1.0) / 2.0 * s, + int(d), + device=device, + dtype=get_equivalent_dtype(dtype, torch.Tensor), + ) for d, s in zip(spatial_size, spacing) ] coords = meshgrid_ij(*ranges) @@ -852,7 +861,7 @@ def create_translate( spatial_dims=spatial_dims, shift=shift, eye_func=lambda x: torch.eye(torch.as_tensor(x), device=device), # type: ignore - array_func=lambda x: torch.as_tensor(x, device=device), # type: ignore + array_func=lambda x: torch.as_tensor(x, device=device), ) raise ValueError(f"backend {backend} is not supported") @@ -874,7 +883,7 @@ def generate_spatial_bounding_box( margin: Union[Sequence[int], int] = 0, ) -> Tuple[List[int], List[int]]: """ - generate the spatial bounding box of foreground in the image with start-end positions (inclusive). + Generate the spatial bounding box of foreground in the image with start-end positions (inclusive). Users can define arbitrary function to select expected foreground from the whole image or specified channels. And it can also add margin to every dim of the bounding box. The output format of the coordinates is: @@ -886,7 +895,7 @@ def generate_spatial_bounding_box( This function returns [-1, -1, ...], [-1, -1, ...] if there's no positive intensity. Args: - img: source image to generate bounding box from. + img: a "channel-first" image of shape (C, spatial_dim1[, spatial_dim2, ...]) to generate bounding box from. select_fn: function to select expected foreground, default is to select values > 0. channel_indices: if defined, select foreground only on the specified channels of image. if None, select foreground on the whole image. @@ -922,7 +931,7 @@ def generate_spatial_bounding_box( return box_start, box_end -def get_largest_connected_component_mask(img: NdarrayOrTensor, connectivity: Optional[int] = None) -> NdarrayOrTensor: +def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optional[int] = None) -> NdarrayTensor: """ Gets the largest connected component mask of an image. 
@@ -933,13 +942,23 @@ def get_largest_connected_component_mask(img: NdarrayOrTensor, connectivity: Opt connectivity of ``input.ndim`` is used. for more details: https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label. """ - img_arr: np.ndarray = convert_data_type(img, np.ndarray)[0] # type: ignore + if isinstance(img, torch.Tensor) and has_cp and has_cucim: + x_cupy = monai.transforms.ToCupy()(img.short()) + x_label = cucim.skimage.measure.label(x_cupy, connectivity=connectivity) + vals, counts = cp.unique(x_label[cp.nonzero(x_label)], return_counts=True) + comp = x_label == vals[cp.ndarray.argmax(counts)] + out_tensor = monai.transforms.ToTensor(device=img.device)(comp) + out_tensor = out_tensor.bool() + + return out_tensor # type: ignore + + img_arr = convert_data_type(img, np.ndarray)[0] largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) img_arr = measure.label(img_arr, connectivity=connectivity) if img_arr.max() != 0: largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) - largest_cc = convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] # type: ignore - return largest_cc + + return convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] def fill_holes( diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index ab282d5332..b096e1b93d 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -244,7 +244,8 @@ def update_docstring(code_path, transform_name): contents.insert(image_line + 1, " :alt: example of " + transform_name + "\n") # check that we've only added two lines - assert len(contents) == len(contents_orig) + 2 + if len(contents) != len(contents_orig) + 2: + raise AssertionError # write the updated doc to overwrite the original with open(code_path, "w") as f: @@ -382,7 +383,7 @@ def get_images(data, is_label=False): # we might need to panel the images. this happens if a transform produces e.g. 4 output images. # In this case, we create a 2-by-2 grid from them. Output will be a list containing n_orthog_views, # each element being either the image (if num_samples is 1) or the panelled image. 
- nrows = int(np.floor(num_samples ** 0.5)) + nrows = int(np.floor(num_samples**0.5)) for view in range(num_orthog_views): result = np.asarray([d[view] for d in data]) nindex, height, width = result.shape diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 90932553c4..25cb1455dd 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -14,10 +14,12 @@ import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.utils.misc import ensure_tuple, is_module_ver_at_least +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ + "allclose", "moveaxis", "in1d", "clip", @@ -37,16 +39,26 @@ "repeat", "isnan", "ascontiguousarray", + "stack", + "mode", ] +def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool: + """`np.allclose` with equivalent implementation for torch.""" + b, *_ = convert_to_dst_type(b, a) + if isinstance(a, np.ndarray): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) # type: ignore + + def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor: """`moveaxis` for pytorch and numpy, using `permute` for pytorch version < 1.7""" if isinstance(x, torch.Tensor): if hasattr(torch, "movedim"): # `movedim` is new in torch 1.7.0 # torch.moveaxis is a recent alias since torch 1.8.0 return torch.movedim(x, src, dst) # type: ignore - return _moveaxis_with_permute(x, src, dst) # type: ignore + return _moveaxis_with_permute(x, src, dst) return np.moveaxis(x, src, dst) @@ -314,7 +326,7 @@ def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor: return torch.isfinite(x) -def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayOrTensor: +def searchsorted(a: NdarrayTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayTensor: """ `np.searchsorted` with equivalent implementation for torch. @@ -362,7 +374,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: return torch.isnan(x) -def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor: +def ascontiguousarray(x: NdarrayTensor, **kwargs) -> NdarrayOrTensor: """`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`). Args: @@ -378,3 +390,30 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor: if isinstance(x, torch.Tensor): return x.contiguous(**kwargs) return x + + +def stack(x: Sequence[NdarrayTensor], dim: int) -> NdarrayTensor: + """`np.stack` with equivalent implementation for torch. + + Args: + x: array/tensor + dim: dimension along which to perform the stack (referred to as `axis` by numpy) + """ + if isinstance(x[0], np.ndarray): + return np.stack(x, dim) # type: ignore + return torch.stack(x, dim) # type: ignore + + +def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor: + """`torch.mode` with equivalent implementation for numpy. + + Args: + x: array/tensor + dim: dimension along which to perform `mode` (referred to as `axis` by numpy) + to_long: convert input to long before performing mode. 
+ """ + dtype = torch.int64 if to_long else None + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) + o_t = torch.mode(x_t, dim).values + o, *_ = convert_to_dst_type(o_t, x) + return o diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index a1e9390c30..636ea15c8d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -56,22 +56,21 @@ list_to_dict, progress_bar, sample_slices, + save_obj, set_determinism, star_zip_with, zip_with, ) from .module import ( - ClassScanner, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, exact_version, export, - get_class, get_full_type_name, get_package_version, get_torch_version_tuple, - instantiate_class, + instantiate, load_submodules, look_up_option, min_version, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 7aa6c5bbc3..022cd2c58a 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -12,17 +12,21 @@ import collections.abc import inspect import itertools +import os import random +import shutil +import tempfile import types import warnings from ast import literal_eval from distutils.util import strtobool +from pathlib import Path from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, PathLike from monai.utils.module import version_leq __all__ = [ @@ -46,6 +50,7 @@ "is_module_ver_at_least", "has_option", "sample_slices", + "save_obj", ] _seed = None @@ -386,3 +391,51 @@ def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, slices[dim] = slicevals if as_indices else slice(*slicevals) # type: ignore return data[tuple(slices)] + + +def save_obj( + obj, path: PathLike, create_dir: bool = True, atomic: bool = True, func: Optional[Callable] = None, **kwargs +): + """ + Save an object to file with specified path. + Support to serialize to a temporary file first, then move to final destination, + so that files are guaranteed to not be damaged if exception occurs. + + Args: + obj: input object data to save. + path: target file path to save the input object. + create_dir: whether to create dictionary of the path if not existng, default to `True`. + atomic: if `True`, state is serialized to a temporary file first, then move to final destination. + so that files are guaranteed to not be damaged if exception occurs. default to `True`. + func: the function to save file, if None, default to `torch.save`. + kwargs: other args for the save `func` except for the checkpoint and filename. + default `func` is `torch.save()`, details of other args: + https://pytorch.org/docs/stable/generated/torch.save.html. 
+ + """ + path = Path(path) + path_dir = path.parent + if not path_dir.exists(): + if create_dir: + path_dir.mkdir(parents=True) + else: + raise ValueError(f"the directory of specified path is not existing: {path_dir}.") + if path.exists(): + # remove the existing file + os.remove(path) + + if func is None: + func = torch.save + + if not atomic: + func(obj=obj, f=path, **kwargs) + return + try: + # writing to a temporary directory and then using a nearly atomic rename operation + with tempfile.TemporaryDirectory() as tempdir: + temp_path: Path = Path(tempdir) / path.name + func(obj=obj, f=temp_path, **kwargs) + if temp_path.is_file(): + shutil.move(str(temp_path), path) + except PermissionError: # project-monai/monai issue #3613 + pass diff --git a/monai/utils/module.py b/monai/utils/module.py index 2994b80421..1dcbe6849f 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -15,12 +15,13 @@ import re import sys import warnings -from functools import wraps +from functools import partial, wraps from importlib import import_module from pkgutil import walk_packages +from pydoc import locate from re import match from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Sequence, Tuple, cast +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast import torch @@ -37,9 +38,7 @@ "optional_import", "require_pkg", "load_submodules", - "ClassScanner", - "get_class", - "instantiate_class", + "instantiate", "get_full_type_name", "get_package_version", "get_torch_version_tuple", @@ -48,7 +47,7 @@ ] -def look_up_option(opt_str, supported: Collection, default="no_default"): +def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default"): """ Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match. @@ -197,98 +196,27 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ return submodules, err_mod -class ClassScanner: +def instantiate(path: str, **kwargs): """ - Scan all the available classes in the specified packages and modules. - Map the all the class names and the module names in a table. + Method for creating an instance for the specified class / function path. + kwargs will be class args or default args for `partial` function. + The target component must be a class or a function, if not, return the component directly. Args: - pkgs: the expected packages to scan modules and parse class names in the config. - modules: the expected modules in the packages to scan for all the classes. - for example, to parser "LoadImage" in config, `pkgs` can be ["monai"], `modules` can be ["transforms"]. + path: full path of the target class or function component. + kwargs: arguments to initialize the class instance or set default args + for `partial` function. 
""" - def __init__(self, pkgs: Sequence[str], modules: Sequence[str]): - self.pkgs = pkgs - self.modules = modules - self._class_table = self._create_classes_table() + component = locate(path) + if inspect.isclass(component): + return component(**kwargs) + if inspect.isfunction(component): + return partial(component, **kwargs) - def _create_classes_table(self): - class_table = {} - for pkg in self.pkgs: - package = import_module(pkg) - - for _, modname, _ in walk_packages(path=package.__path__, prefix=package.__name__ + "."): - # if no modules specified, load all modules in the package - if len(self.modules) == 0 or any(name in modname for name in self.modules): - try: - module = import_module(modname) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and obj.__module__ == modname: - class_table[name] = modname - except ModuleNotFoundError: - pass - return class_table - - def get_class_module_name(self, class_name): - """ - Get the module name of the class with specified class name. - - Args: - class_name: name of the expected class. - - """ - return self._class_table.get(class_name, None) - - -def get_class(class_path: str): - """ - Get the class from specified class path. - - Args: - class_path (str): full path of the class. - - Raises: - ValueError: invalid class_path, missing the module name. - ValueError: class does not exist. - ValueError: module does not exist. - - """ - if len(class_path.split(".")) < 2: - raise ValueError(f"invalid class_path: {class_path}, missing the module name.") - module_name, class_name = class_path.rsplit(".", 1) - - try: - module_ = import_module(module_name) - - try: - class_ = getattr(module_, class_name) - except AttributeError as e: - raise ValueError(f"class {class_name} does not exist.") from e - - except AttributeError as e: - raise ValueError(f"module {module_name} does not exist.") from e - - return class_ - - -def instantiate_class(class_path: str, **kwargs): - """ - Method for creating an instance for the specified class. - - Args: - class_path: full path of the class. - kwargs: arguments to initialize the class instance. - - Raises: - ValueError: class has paramenters error. - """ - - try: - return get_class(class_path)(**kwargs) - except TypeError as e: - raise ValueError(f"class {class_path} has parameters error.") from e + warnings.warn(f"target component must be a valid class or function, but got {path}.") + return component def get_full_type_name(typeobj): @@ -523,9 +451,12 @@ def version_leq(lhs: str, rhs: str): """ lhs, rhs = str(lhs), str(rhs) - ver, has_ver = optional_import("pkg_resources", name="parse_version") + pkging, has_ver = optional_import("pkg_resources", name="packaging") if has_ver: - return ver(lhs) <= ver(rhs) + try: + return pkging.version.Version(lhs) <= pkging.version.Version(rhs) + except pkging.version.InvalidVersion: + return True def _try_cast(val: str): val = val.strip() diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 8686557176..d5944e265b 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -10,12 +10,12 @@ # limitations under the License. 
import re -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Type, Union import numpy as np import torch -from monai.config.type_definitions import DtypeLike, NdarrayOrTensor +from monai.config.type_definitions import DtypeLike, NdarrayTensor from monai.utils import optional_import from monai.utils.module import look_up_option @@ -212,18 +212,18 @@ def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool def convert_data_type( data: Any, - output_type: Optional[type] = None, + output_type: Optional[Type[NdarrayTensor]] = None, device: Optional[torch.device] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + dtype: Union[DtypeLike, torch.dtype] = None, wrap_sequence: bool = False, -) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: +) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. Args: data: data to be converted - output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) - device: if output is `torch.Tensor`, select device (if blank, unchanged) + output_type: `torch.Tensor` or `np.ndarray` (if `None`, unchanged) + device: if output is `torch.Tensor`, select device (if `None`, unchanged) dtype: dtype of output data. Converted to correct library type (e.g., `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). If left blank, it remains unchanged. @@ -241,7 +241,7 @@ def convert_data_type( (1.0, , None) """ - orig_type: Any + orig_type: type if isinstance(data, torch.Tensor): orig_type = torch.Tensor elif isinstance(data, np.ndarray): @@ -257,20 +257,22 @@ def convert_data_type( dtype_ = get_equivalent_dtype(dtype, output_type) - if output_type is torch.Tensor: - data = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) - elif output_type is np.ndarray: - data = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence) - elif has_cp and output_type is cp.ndarray: - data = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence) - else: - raise ValueError(f"Unsupported output type: {output_type}") - return data, orig_type, orig_device + data_: NdarrayTensor + if issubclass(output_type, torch.Tensor): + data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + if issubclass(output_type, np.ndarray): + data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + elif has_cp and issubclass(output_type, cp.ndarray): + data_ = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device + raise ValueError(f"Unsupported output type: {output_type}") def convert_to_dst_type( - src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, wrap_sequence: bool = False -) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]: + src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False +) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert source data to the same data type and device as the destination data. 
If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`, diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index b7e433d6fa..16fb64cb46 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -39,9 +39,9 @@ def _compute(data: np.ndarray) -> np.ndarray: return np.stack([scaler(i) for i in data], axis=0) if isinstance(x, torch.Tensor): - return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device) + return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device) # type: ignore - return _compute(x) + return _compute(x) # type: ignore class ModelWithHooks: diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 4d4e91c6fd..0af05adf32 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -49,8 +49,7 @@ def _image3_animated_gif( if len(image.shape) != 3: raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3") - image_np: np.ndarray - image_np, *_ = convert_data_type(image, output_type=np.ndarray) # type: ignore + image_np, *_ = convert_data_type(image, output_type=np.ndarray) ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)] ims = [GifImage.fromarray(im) for im in ims] img_str = b"" diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 23c8ee857a..d87b93396a 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -276,9 +276,7 @@ def _compute_occlusion_sensitivity(self, x, b_box): return sensitivity_ims, output_im_shape - def __call__( # type: ignore - self, x: torch.Tensor, b_box: Optional[Sequence] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Image to use for inference. Should be a tensor consisting of 1 batch. 
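Several call sites above can drop `# type: ignore` because `convert_data_type` and `convert_to_dst_type` are now generic over the output type; a small usage sketch (the array contents are arbitrary):

    import numpy as np
    import torch
    from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

    img = np.ones((2, 2), dtype=np.float64)
    img_t, orig_type, orig_device = convert_data_type(img, output_type=torch.Tensor, dtype=torch.float32)
    # img_t is a float32 tensor; orig_type / orig_device record where the data came from
    restored, *_ = convert_to_dst_type(img_t, dst=img)  # converts back to an np.ndarray matching `dst`
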
diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index 63cac7ea35..be4f3f60fc 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -93,7 +93,7 @@ def matshow3d( >>> plt.show() """ - vol: np.ndarray = convert_data_type(data=volume, output_type=np.ndarray)[0] # type: ignore + vol = convert_data_type(data=volume, output_type=np.ndarray)[0] if channel_dim is not None: if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]: raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.") @@ -109,11 +109,11 @@ def matshow3d( while len(vol.shape) < 3: vol = np.expand_dims(vol, 0) # so that we display 2d as well - if channel_dim is not None: - vol = np.moveaxis(vol, frame_dim, -4) # move the expected dim to construct frames with `B` dim + if channel_dim is not None: # move the expected dim to construct frames with `B` dim + vol = np.moveaxis(vol, frame_dim, -4) # type: ignore vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1])) else: - vol = np.moveaxis(vol, frame_dim, -3) + vol = np.moveaxis(vol, frame_dim, -3) # type: ignore vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1])) vmin = np.nanmin(vol) if vmin is None else vmin vmax = np.nanmax(vol) if vmax is None else vmax @@ -129,7 +129,7 @@ def matshow3d( if channel_dim is not None: width += [[0, 0]] # add pad width for the channel dim width += [[margin, margin]] * 2 - vol = np.pad(vol.astype(dtype, copy=False), width, mode="constant", constant_values=fill_value) + vol = np.pad(vol.astype(dtype, copy=False), width, mode="constant", constant_values=fill_value) # type: ignore im = np.block([[vol[i * cols + j] for j in range(cols)] for i in range(rows)]) if channel_dim is not None: # move channel dim to the end @@ -186,8 +186,7 @@ def blend_images( def get_label_rgb(cmap: str, label: NdarrayOrTensor): _cmap = cm.get_cmap(cmap) - label_np: np.ndarray - label_np, *_ = convert_data_type(label, np.ndarray) # type: ignore + label_np, *_ = convert_data_type(label, np.ndarray) label_rgb_np = _cmap(label_np[0]) label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3] label_rgb, *_ = convert_to_dst_type(label_rgb_np, label) diff --git a/tests/min_tests.py b/tests/min_tests.py index 783ab370c1..426650eb04 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -87,6 +87,7 @@ def run_testsuit(): "test_header_correct", "test_hilbert_transform", "test_image_dataset", + "test_image_rw", "test_img2tensorboard", "test_integration_fast_train", "test_integration_segmentation_3d", @@ -155,6 +156,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", + "test_parallel_execution_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_affined.py b/tests/test_affined.py index 355833b858..665c93d23f 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -28,6 +28,13 @@ p(np.arange(9).reshape(1, 3, 3)), ] ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device, dtype=None), + {"img": p(np.arange(9, dtype=float).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), + ] + ) TESTS.append( [ dict(keys="img", padding_mode="zeros", device=device), diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py index ff8480a02f..3174211f34 100644 --- a/tests/test_apply_filter.py +++ b/tests/test_apply_filter.py @@ -80,7 +80,7 @@ def test_wrong_args(self): with self.assertRaisesRegex(NotImplementedError, ""): 
apply_filter(torch.ones((1, 1, 1, 2, 3, 2)), torch.ones((2,))) with self.assertRaisesRegex(TypeError, ""): - apply_filter(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) # type: ignore + apply_filter(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) if __name__ == "__main__": diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index a742f5889a..7227f53e04 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -42,24 +42,12 @@ class TestCacheDataset(unittest.TestCase): def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] + test_data = [] + for i in ["1", "2"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True) data1 = dataset[0] data2 = dataset[1] @@ -68,9 +56,9 @@ def test_shape(self, transform, expected_shape): self.assertEqual(len(data3), 1) if transform is None: - self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data4["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz")) else: self.assertTupleEqual(data1["image"].shape, expected_shape) self.assertTupleEqual(data1["label"].shape, expected_shape) @@ -195,6 +183,46 @@ def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): self.assertListEqual(expected, [y.item() for y in loader]) self.assertListEqual(expected, [y.item() for y in loader]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_hash_as_key(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + test_data = [] + for i in ["1", "2", "2", "3", "3"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + + dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True) + self.assertEqual(len(dataset), 5) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 3) + self.assertEqual(dataset.cache_num, 3) + data1 
= dataset[0] + data2 = dataset[1] + data3 = dataset[-1] + # test slice indices + data4 = dataset[0:-1] + self.assertEqual(len(data4), 4) + + if transform is None: + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz")) + else: + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data3["image"].shape, expected_shape) + for d in data4: + self.assertTupleEqual(d["image"].shape, expected_shape) + + test_data2 = test_data[:3] + dataset.set_data(data=test_data2) + self.assertEqual(len(dataset), 3) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 2) + self.assertEqual(dataset.cache_num, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py new file mode 100644 index 0000000000..adff25457a --- /dev/null +++ b/tests/test_component_locator.py @@ -0,0 +1,35 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from pydoc import locate + +from monai.apps.mmars import ComponentLocator +from monai.utils import optional_import + +_, has_ignite = optional_import("ignite") + + +class TestComponentLocator(unittest.TestCase): + def test_locate(self): + locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) + # test init mapping table and get the module path of component + self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") + self.assertGreater(len(locator._components_table), 0) + for _, mods in locator._components_table.items(): + for i in mods: + self.assertGreater(len(mods), 0) + # ensure we can locate all the items by `name` + self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compose.py b/tests/test_compose.py index e0913f59e1..4d1bcfe01c 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -154,7 +154,7 @@ def __call__(self, data): c.randomize() def test_err_msg(self): - transforms = Compose([abs, AddChannel(), round]) + transforms = Compose([abs, AddChannel(), round], log_stats=False) with self.assertRaisesRegex(Exception, "AddChannel"): transforms(42.1) diff --git a/tests/test_config_component.py b/tests/test_config_component.py deleted file mode 100644 index d48a4e274f..0000000000 --- a/tests/test_config_component.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from typing import Callable, Iterator - -import torch -from parameterized import parameterized - -import monai -from monai.apps import ConfigComponent -from monai.data import DataLoader, Dataset -from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import ClassScanner, optional_import - -_, has_tv = optional_import("torchvision") - -TEST_CASE_1 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": {"keys": ["image"]}}, - LoadImaged, -] -# test python `` -TEST_CASE_2 = [ - dict(pkgs=[], modules=[]), - {"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, - LoadImaged, -] -# test `` -TEST_CASE_3 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": True, "": {"keys": ["image"]}}, - dict, -] -# test unresolved dependency -TEST_CASE_4 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "LoadImaged", "": {"keys": ["@key_name"]}}, - dict, -] -# test non-monai modules -TEST_CASE_5 = [ - dict(pkgs=["torch.optim", "monai"], modules=["adam"]), - {"": "Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] -# test args contains "name" field -TEST_CASE_6 = [ - dict(pkgs=["monai"], modules=["transforms"]), - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, - RandTorchVisiond, -] -# test dependencies of dict config -TEST_CASE_7 = [{"dataset": "@dataset", "batch_size": 2}, ["test#dataset", "dataset", "test#batch_size"]] -# test dependencies of list config -TEST_CASE_8 = [ - {"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, - ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans0", "test#transforms#1", "trans1"], -] -# test dependencies of execute code -TEST_CASE_9 = [ - {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, - ["test#dataset", "dataset", "test#transforms", "test#transforms#0", "trans"], -] - -# test dependencies of lambda function -TEST_CASE_10 = [ - {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, - ["test#lr_range", "num_epochs", "test#lr", "test#lr#0", "init_lr"], -] -# test instance with no dependencies -TEST_CASE_11 = [ - "transform#1", - {"": "LoadImaged", "": {"keys": ["image"]}}, - {"transform#1#": "LoadImaged", "transform#1#": {"keys": ["image"]}}, - LoadImaged, -] -# test dataloader refers to `@dataset`, here we don't test recursive dependencies, test that in `ConfigResolver` -TEST_CASE_12 = [ - "dataloader", - {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, - {"dataloader#": "DataLoader", "dataloader#": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}}, - DataLoader, -] -# test dependencies in code execution -TEST_CASE_13 = [ - "optimizer", - {"": "Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, - {"optimizer#": "Adam", "optimizer#": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] -# test replace dependencies with code execution result -TEST_CASE_14 = ["optimizer##params", "$@model.parameters()", {"model": 
torch.nn.PReLU()}, Iterator] -# test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_15 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] -# test lambda function, should not execute the lambda function, just change the string with dependent objects -TEST_CASE_16 = ["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] - - -class TestConfigComponent(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] + ([TEST_CASE_6] if has_tv else []) - ) - def test_build(self, input_param, test_input, output_type): - scanner = ClassScanner(**input_param) - configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) - ret = configer.build() - self.assertTrue(isinstance(ret, output_type)) - if isinstance(ret, LoadImaged): - self.assertEqual(ret.keys[0], "image") - if isinstance(ret, dict): - # test `` works fine - self.assertDictEqual(ret, test_input) - - @parameterized.expand([TEST_CASE_7, TEST_CASE_8, TEST_CASE_9, TEST_CASE_10]) - def test_dependent_ids(self, test_input, ref_ids): - scanner = ClassScanner(pkgs=[], modules=[]) - configer = ConfigComponent(id="test", config=test_input, class_scanner=scanner) - ret = configer.get_dependent_ids() - self.assertListEqual(ret, ref_ids) - - @parameterized.expand([TEST_CASE_11, TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16]) - def test_update_dependencies(self, id, test_input, deps, output_type): - scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) - configer = ConfigComponent( - id=id, config=test_input, class_scanner=scanner, globals={"monai": monai, "torch": torch} - ) - config = configer.get_updated_config(deps) - ret = configer.build(config) - self.assertTrue(isinstance(ret, output_type)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_config_item.py b/tests/test_config_item.py new file mode 100644 index 0000000000..af6a02802b --- /dev/null +++ b/tests/test_config_item.py @@ -0,0 +1,138 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import partial +from typing import Callable, Iterator + +import torch +from parameterized import parameterized + +import monai +from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, is_instantiable +from monai.data import DataLoader, Dataset +from monai.transforms import LoadImaged, RandTorchVisiond +from monai.utils import optional_import + +_, has_tv = optional_import("torchvision") + +TEST_CASE_1 = [{"<name>": "LoadImaged", "<args>": {"keys": ["image"]}}, LoadImaged] +# test python `<path>` +TEST_CASE_2 = [{"<path>": "monai.transforms.LoadImaged", "<args>": {"keys": ["image"]}}, LoadImaged] +# test `<disabled>` +TEST_CASE_3 = [{"<name>": "LoadImaged", "<disabled>": True, "<args>": {"keys": ["image"]}}, dict] +# test unresolved reference +TEST_CASE_4 = [{"<name>": "LoadImaged", "<args>": {"keys": ["@key_name"]}}] +# test non-monai modules and excludes +TEST_CASE_5 = [ + {"<path>": "torch.optim.Adam", "<args>": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, + torch.optim.Adam, +] +TEST_CASE_6 = [{"<name>": "decollate_batch", "<args>": {"detach": True, "pad": True}}, partial] +# test args contains "name" field +TEST_CASE_7 = [ + {"<name>": "RandTorchVisiond", "<args>": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, + RandTorchVisiond, +] +# test references of dict config +TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["dataset"]] +# test references of list config +TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] +# test references of execute code +TEST_CASE_10 = [ + {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, + ["dataset", "trans"], +] +# test references of lambda function +TEST_CASE_11 = [ + {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, + ["num_epochs", "init_lr"], +] +# test instance with no references +TEST_CASE_12 = ["transform#1", {"<name>": "LoadImaged", "<args>": {"keys": ["image"]}}, {}, LoadImaged] +# test dataloader refers to `@dataset`, here we don't test recursive references, test that in `ConfigResolver` +TEST_CASE_13 = [ + "dataloader", + {"<name>": "DataLoader", "<args>": {"dataset": "@dataset", "batch_size": 2}}, + {"dataset": Dataset(data=[1, 2])}, + DataLoader, +] +# test references in code execution +TEST_CASE_14 = [ + "optimizer", + {"<path>": "torch.optim.Adam", "<args>": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, + {"model": torch.nn.PReLU(), "learning_rate": 1e-4}, + torch.optim.Adam, +] +# test replace references with code execution result +TEST_CASE_15 = ["optimizer#<args>#params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] +# test execute some function in args, test pre-imported global packages `monai` +TEST_CASE_16 = ["dataloader#<args>#collate_fn", "$monai.data.list_data_collate", {}, Callable] +# test lambda function, should not execute the lambda function, just change the string with referring objects +TEST_CASE_17 = ["dataloader#<args>#collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] + + +class TestConfigItem(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) + ) + def test_instantiate(self, test_input, output_type): + locator = ComponentLocator(excludes=["metrics"]) + configer = ConfigComponent(id="test", config=test_input, locator=locator) + ret = configer.instantiate() + if test_input.get("<disabled>", False): + # test `<disabled>` works fine + self.assertEqual(ret, None) + return + self.assertTrue(isinstance(ret, output_type)) + if
isinstance(ret, LoadImaged): + self.assertEqual(ret.keys[0], "image") + + @parameterized.expand([TEST_CASE_4]) + def test_raise_error(self, test_input): + with self.assertRaises(KeyError): # has unresolved keys + configer = ConfigItem(id="test", config=test_input) + configer.resolve() + + @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) + def test_referring_ids(self, test_input, ref_ids): + configer = ConfigItem(id="test", config=test_input) # also test default locator + ret = configer.get_id_of_refs() + self.assertListEqual(ret, ref_ids) + + @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) + def test_resolve_references(self, id, test_input, refs, output_type): + configer = ConfigComponent( + id=id, config=test_input, locator=None, excludes=["utils"], globals={"monai": monai, "torch": torch} + ) + configer.resolve(refs=refs) + ret = configer.get_resolved_config() + if is_instantiable(ret): + ret = configer.instantiate(**{}) # also test kwargs + self.assertTrue(isinstance(ret, output_type)) + + def test_lazy_instantiation(self): + config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} + refs = {"dataset": Dataset(data=[1, 2])} + configer = ConfigComponent(config=config, locator=None) + init_config = configer.get_config() + # modify config content at runtime + init_config[""]["batch_size"] = 4 + configer.update_config(config=init_config) + + configer.resolve(refs=refs) + ret = configer.instantiate() + self.assertTrue(isinstance(ret, DataLoader)) + self.assertEqual(ret.batch_size, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 72c7b53506..c378a52f78 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -11,12 +11,11 @@ import os import unittest -from urllib.error import ContentTooShortError, HTTPError from monai.apps import CrossValidation, DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestCrossValidation(unittest.TestCase): @@ -51,14 +50,8 @@ def _test_dataset(dataset): download=True, ) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = cvdataset.get_dataset(folds=0) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 65a61ce7ec..c18abfcedc 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -29,7 +29,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:", @@ -43,7 +43,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: ", @@ -57,7 +57,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data 
statistics:\nType: \nShape: (2, 2)", @@ -71,7 +71,7 @@ "value_range": True, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", @@ -85,7 +85,7 @@ "value_range": True, "data_value": True, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", @@ -99,7 +99,7 @@ "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": None, + "name": "DataStats", }, np.array([[0, 1], [1, 2]]), ( @@ -116,7 +116,7 @@ "value_range": True, "data_value": True, "additional_info": lambda x: torch.mean(x.float()), - "logger_handler": None, + "name": "DataStats", }, torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu"), ( @@ -145,6 +145,9 @@ def test_file(self, input_data, expected_print): filename = os.path.join(tempdir, "test_data_stats.log") handler = logging.FileHandler(filename, mode="w") handler.setLevel(logging.INFO) + name = "DataStats" + logger = logging.getLogger(name) + logger.addHandler(handler) input_param = { "prefix": "test data", "data_type": True, @@ -152,14 +155,13 @@ def test_file(self, input_data, expected_print): "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": handler, + "name": name, } transform = DataStats(**input_param) _ = transform(input_data) - _logger = logging.getLogger(transform._logger_name) - for h in _logger.handlers[:]: + for h in logger.handlers[:]: h.close() - _logger.removeHandler(h) + logger.removeHandler(h) with open(filename) as f: content = f.read() if sys.platform != "win32": diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index 1f38db2b05..28da936cd0 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -30,7 +30,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:", @@ -45,7 +45,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: ", @@ -60,7 +60,7 @@ "value_range": False, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)", @@ -75,7 +75,7 @@ "value_range": True, "data_value": False, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", @@ -90,7 +90,7 @@ "value_range": True, "data_value": True, "additional_info": None, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", @@ -105,7 +105,7 @@ "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": None, + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]])}, ( @@ -123,7 +123,7 @@ "value_range": True, "data_value": True, "additional_info": lambda x: torch.mean(x.float()), - "logger_handler": None, + "name": "DataStats", }, {"img": torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu")}, ( @@ -141,6 
+141,7 @@ "value_range": (True, False), "data_value": (False, True), "additional_info": (np.mean, None), + "name": "DataStats", }, {"img": np.array([[0, 1], [1, 2]]), "affine": np.eye(2, 2)}, "affine statistics:\nType: \nShape: (2, 2)\nValue: [[1. 0.]\n [0. 1.]]", @@ -168,6 +169,9 @@ def test_file(self, input_data, expected_print): filename = os.path.join(tempdir, "test_stats.log") handler = logging.FileHandler(filename, mode="w") handler.setLevel(logging.INFO) + name = "DataStats" + logger = logging.getLogger(name) + logger.addHandler(handler) input_param = { "keys": "img", "prefix": "test data", @@ -175,14 +179,13 @@ def test_file(self, input_data, expected_print): "value_range": True, "data_value": True, "additional_info": np.mean, - "logger_handler": handler, + "name": name, } transform = DataStatsd(**input_param) _ = transform(input_data) - _logger = logging.getLogger(transform.printer._logger_name) - for h in _logger.handlers[:]: + for h in logger.handlers[:]: h.close() - _logger.removeHandler(h) + logger.removeHandler(h) del handler with open(filename) as f: content = f.read() diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index a5d9ce3e27..744dccefaa 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDecathlonDataset(unittest.TestCase): @@ -41,7 +40,7 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44)) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", @@ -50,12 +49,6 @@ def _test_dataset(dataset): download=True, copy_cache=False, ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) data = DecathlonDataset( diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index f896d4ae93..435b280022 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -16,7 +16,7 @@ from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDownloadAndExtract(unittest.TestCase): @@ -27,22 +27,15 @@ def test_actions(self): filepath = Path(testing_dir) / "MedNIST.tar.gz" output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" - try: + with skip_if_downloading_fails(): download_and_extract(url, filepath, output_dir, md5_value) download_and_extract(url, filepath, output_dir, md5_value) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - 
self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors wrong_md5 = "0" with self.assertLogs(logger="monai.apps", level="ERROR"): try: download_url(url, filepath, wrong_md5) except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) if isinstance(e, RuntimeError): # FIXME: skip MD5 check as current downloading method may fail self.assertTrue(str(e).startswith("md5 check")) @@ -56,7 +49,7 @@ def test_actions(self): @skip_if_quick def test_default(self): with tempfile.TemporaryDirectory() as tmp_dir: - try: + with skip_if_downloading_fails(): # icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing download_and_extract( "https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn", @@ -71,12 +64,6 @@ def test_default(self): hash_val="ac6e167ee40803577d98237f2b0241e5", file_type="zip", ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors if __name__ == "__main__": diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 0ab383fd56..a2a5e30750 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -13,7 +13,6 @@ import unittest from typing import TYPE_CHECKING from unittest import skipUnless -from urllib.error import ContentTooShortError, HTTPError import torch from parameterized import parameterized @@ -27,7 +26,7 @@ get_efficientnet_image_size, ) from monai.utils import optional_import -from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save if TYPE_CHECKING: import torchvision @@ -251,12 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): @@ -269,12 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_non_default_shapes(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # override input shape with different variations num_dims = len(input_shape) - 2 @@ -387,12 +378,8 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBNFeatures(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 
c7554e9421..dab46f366f 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -13,12 +13,18 @@ import torch from ignite.engine import EventEnum, Events +from parameterized import parameterized from monai.engines import EnsembleEvaluator +TEST_CASE_1 = [["pred_0", "pred_1", "pred_2", "pred_3", "pred_4"]] + +TEST_CASE_2 = [None] + class TestEnsembleEvaluator(unittest.TestCase): - def test_content(self): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_content(self, pred_keys): device = torch.device("cpu:0") class TestDataset(torch.utils.data.Dataset): @@ -52,7 +58,7 @@ class CustomEvents(EventEnum): device=device, val_data_loader=val_loader, networks=[net0, net1, net2, net3, net4], - pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"], + pred_keys=pred_keys, event_names=["bwd_event", "opt_event", CustomEvents], event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"}, ) @@ -61,7 +67,7 @@ class CustomEvents(EventEnum): def run_transform(engine): for i in range(5): expected_value = engine.state.iteration + i - torch.testing.assert_allclose(engine.state.output[0][f"pred{i}"].item(), expected_value) + torch.testing.assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value) @val_engine.on(Events.EPOCH_COMPLETED) def trigger_custom_event(): diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index a919396e3e..6aa5455352 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -16,7 +16,7 @@ from monai import transforms from monai.losses.image_dissimilarity import GlobalMutualInformationLoss -from tests.utils import SkipIfBeforePyTorchVersion, download_url_or_skip_test +from tests.utils import SkipIfBeforePyTorchVersion, download_url_or_skip_test, skip_if_quick device = "cuda" if torch.cuda.is_available() else "cpu" @@ -51,6 +51,7 @@ } +@skip_if_quick class TestGlobalMutualInformationLoss(unittest.TestCase): def setUp(self): download_url_or_skip_test(FILE_URL, FILE_PATH) @@ -100,6 +101,8 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. result = loss_fn(a2, a1).detach().cpu().numpy() np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3) + +class TestGlobalMutualInformationLossIll(unittest.TestCase): def test_ill_shape(self): loss = GlobalMutualInformationLoss() with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index ec43ed357b..2331602234 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -9,8 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
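# --- illustrative sketch (not part of the patch) -----------------------------
# The download tests above replace the repeated try/except blocks with the
# `skip_if_downloading_fails` context manager from tests/utils.py, whose body
# is not shown in this hunk. A plausible minimal version, assuming it skips the
# test on the same network/MD5 errors the old code tolerated, could be:
import unittest
from contextlib import contextmanager
from urllib.error import ContentTooShortError, HTTPError


@contextmanager
def skip_if_downloading_fails():
    try:
        yield
    except (ContentTooShortError, HTTPError, RuntimeError) as e:
        if isinstance(e, RuntimeError) and not str(e).startswith("md5 check"):
            raise  # unrelated runtime errors should still fail the test
        raise unittest.SkipTest(f"download failed: {e}") from e
# ------------------------------------------------------------------------------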
-import logging -import sys import tempfile import unittest @@ -23,7 +21,6 @@ class TestHandlerCheckpointLoader(unittest.TestCase): def test_one_save_one_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() data1 = net1.state_dict() data1["weight"] = torch.tensor([0.1]) @@ -58,7 +55,6 @@ def check_epoch(engine: Engine): self.assertEqual(engine3.state.max_epochs, 5) def test_two_save_one_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() optimizer = optim.SGD(net1.parameters(), lr=0.02) data1 = net1.state_dict() @@ -80,7 +76,6 @@ def test_two_save_one_load(self): torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1])) def test_save_single_device_load_multi_devices(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.PReLU() data1 = net1.state_dict() data1["weight"] = torch.tensor([0.1]) @@ -101,7 +96,6 @@ def test_save_single_device_load_multi_devices(self): torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1])) def test_partial_under_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) @@ -124,7 +118,6 @@ def test_partial_under_load(self): torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) def test_partial_over_load(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU()]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([0.1]) @@ -147,7 +140,6 @@ def test_partial_over_load(self): torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) def test_strict_shape(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 1b746184a4..c87866490c 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
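# --- illustrative sketch (not part of the patch) -----------------------------
# The DataStats/DataStatsd tests above and the handler tests that follow drop
# the `logger_handler` argument and instead attach handlers to a named logger,
# passing that `name` to the transform or handler. A minimal sketch of the
# pattern, assuming DataStats accepts `name` as the updated tests indicate:
import logging

import numpy as np
from monai.transforms import DataStats

name = "DataStats"
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
handler = logging.FileHandler("data_stats.log", mode="w")
handler.setLevel(logging.INFO)
logger.addHandler(handler)

transform = DataStats(prefix="test data", data_shape=True, value_range=True, name=name)
transform(np.array([[0, 1], [1, 2]]))  # statistics are emitted through the named logger

logger.removeHandler(handler)
handler.close()
# ------------------------------------------------------------------------------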
-import logging import os -import sys import tempfile import unittest @@ -131,7 +129,6 @@ def test_file( filenames, multi_devices=False, ): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) data = [0] * 8 # set up engine @@ -171,7 +168,6 @@ def _train_func(engine, batch): self.assertTrue(os.path.exists(os.path.join(tempdir, filename))) def test_exception(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net = torch.nn.PReLU() # set up engine @@ -190,7 +186,6 @@ def _train_func(engine, batch): self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt"))) def test_load_state_dict(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) net = torch.nn.PReLU() # set up engine diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index b260686315..15401fe1b2 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -12,7 +12,6 @@ import logging import os import re -import sys import tempfile import unittest @@ -25,7 +24,6 @@ class TestHandlerLrSchedule(unittest.TestCase): def test_content(self): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) data = [0] * 8 test_lr = 0.1 gamma = 0.1 @@ -59,20 +57,22 @@ def _reduce_lr_on_plateau(): # test with additional logging handler file_saver = logging.FileHandler(filename, mode="w") file_saver.setLevel(logging.INFO) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(file_saver) def _reduce_on_step(): optimizer = torch.optim.SGD(net.parameters(), test_lr) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma) - handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver) + handler = LrScheduleHandler(lr_scheduler, name=key_to_handler) handler.attach(train_engine) - handler.logger.setLevel(logging.INFO) return handler schedulers = _reduce_lr_on_plateau(), _reduce_on_step() train_engine.run(data, max_epochs=5) file_saver.close() - schedulers[1].logger.removeHandler(file_saver) + logger.removeHandler(file_saver) with open(filename) as f: output_str = f.read() diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 3632a98cfc..ee6566f6cb 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -39,7 +39,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ @@ -65,7 +67,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index ee0df74002..a93f2c5e5f 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -43,9 +43,11 @@ def _update_metric(engine): engine.state.metrics[key_to_print] = current_metric + 0.1 # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler) + logger = 
logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler) stats_handler.attach(engine) - stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) @@ -73,9 +75,11 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) stats_handler.attach(engine) - stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) @@ -103,11 +107,11 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler( - name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler - ) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}) stats_handler.attach(engine) - stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) @@ -137,9 +141,11 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) stats_handler.attach(engine) - stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) handler.close() @@ -190,11 +196,11 @@ def _update_metric(engine): engine.state.test2 += 0.2 # set up testing handler - stats_handler = StatsHandler( - name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler - ) + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(name=key_to_handler, state_attributes=["test1", "test2", "test3"]) stats_handler.attach(engine) - stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) @@ -208,6 +214,39 @@ def _update_metric(engine): content_count += 1 self.assertTrue(content_count > 0) + def test_default_logger(self): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_print = "myLoss" + + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] + + engine = Engine(_train_func) + engine.logger.addHandler(log_handler) + + # set up testing handler + stats_handler = StatsHandler(name=None, tag_name=key_to_print) + stats_handler.attach(engine) + # leverage `engine.logger` to print info + engine.logger.setLevel(logging.INFO) + level = logging.root.getEffectiveLevel() + logging.basicConfig(level=logging.INFO) + engine.run(range(3), max_epochs=2) + logging.basicConfig(level=level) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_image_rw.py 
b/tests/test_image_rw.py new file mode 100644 index 0000000000..e1079e63f7 --- /dev/null +++ b/tests/test_image_rw.py @@ -0,0 +1,136 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader, PILReader +from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer +from monai.transforms import LoadImage, SaveImage, moveaxis +from monai.utils import OptionalImportError +from tests.utils import TEST_NDARRAYS, assert_allclose + + +class TestLoadSaveNifti(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def nifti_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) - 1 + for p in TEST_NDARRAYS: + output_ext = ".nii.gz" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver( + p(test_data), + { + "filename_or_obj": f"{filepath}.png", + "affine": np.eye(4), + "original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + }, + ) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + if resample: + _test_data = moveaxis(_test_data, 0, 1) + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.nifti_rw(test_data, reader, writer, np.uint8) + self.nifti_rw(test_data, reader, writer, np.float32) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_3d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) + self.nifti_rw(test_data, reader, writer, int) + self.nifti_rw(test_data, reader, writer, int, False) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_4d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(2, 1, 3, 8) + self.nifti_rw(test_data, reader, writer, np.float16) + + +class TestLoadSavePNG(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def png_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) 
- 1 + for p in TEST_NDARRAYS: + output_ext = ".png" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver(p(test_data), {"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)}) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.png_rw(test_data, reader, writer, np.uint8) + + @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) + def test_rgb(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(3, 2, 8) + self.png_rw(test_data, reader, writer, np.uint8, False) + + +class TestRegRes(unittest.TestCase): + def test_0_default(self): + self.assertTrue(len(resolve_writer(".png")) > 0, "has png writer") + self.assertTrue(len(resolve_writer(".nrrd")) > 0, "has nrrd writer") + self.assertTrue(len(resolve_writer("unknown")) > 0, "has writer") + register_writer("unknown1", lambda: (_ for _ in ()).throw(OptionalImportError)) + with self.assertRaises(OptionalImportError): + resolve_writer("unknown1") + + def test_1_new(self): + register_writer("new", lambda x: x + 1) + register_writer("new2", lambda x: x + 1) + self.assertEqual(resolve_writer("new")[0](0), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 2c0c9e1f2e..0e79a26ea7 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -12,7 +12,6 @@ import os import unittest import warnings -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -39,7 +38,7 @@ ) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" @@ -69,7 +68,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), - RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), + RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64), RandFlip(spatial_axis=0, prob=0.5), RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), ToTensor(), @@ -186,14 +185,8 @@ def setUp(self): dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz") if not os.path.exists(data_dir): - try: + with skip_if_downloading_fails(): download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors assert 
os.path.exists(data_dir) diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py index 3522a57342..51b2ac1d3f 100644 --- a/tests/test_integration_fast_train.py +++ b/tests/test_integration_fast_train.py @@ -123,6 +123,7 @@ def test_train_timing(self): range_x=np.pi / 4, mode=("bilinear", "nearest"), align_corners=True, + dtype=np.float64, ), RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")), RandGaussianNoised(keys="image", prob=0.5), diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 718c9291fb..5c273d0a46 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -21,7 +21,7 @@ from torch.utils.tensorboard import SummaryWriter import monai -from monai.data import NiftiSaver, create_test_image_3d, decollate_batch +from monai.data import create_test_image_3d, decollate_batch from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric from monai.networks import eval_mode @@ -34,6 +34,7 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImage, ScaleIntensityd, Spacingd, ToTensor, @@ -213,17 +214,25 @@ def run_inference_test(root_dir, device="cuda:0"): with eval_mode(model): # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 - saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32) + saver = SaveImage( + output_dir=os.path.join(root_dir, "output"), + dtype=np.float32, + output_ext=".nii.gz", + output_postfix="seg", + mode="bilinear", + ) for val_data in val_loader: val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - # decollate prediction into a list and execute post processing for every item + # decollate prediction into a list val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)] + val_meta = decollate_batch(val_data[PostFix.meta("img")]) # compute metrics dice_metric(y_pred=val_outputs, y=val_labels) - saver.save_batch(val_outputs, val_data[PostFix.meta("img")]) + for img, meta in zip(val_outputs, val_meta): # save a decollated batch of files + saver(img, meta) return dice_metric.aggregate().item() diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 7c0ac1da45..5169289776 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -9,10 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import shutil -import sys import tempfile import unittest import warnings @@ -295,7 +293,6 @@ def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") monai.config.print_config() - logging.basicConfig(stream=sys.stdout, level=logging.INFO) def tearDown(self): set_determinism(seed=None) diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index 790b222ea0..c9306b349f 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -9,10 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
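# --- illustrative sketch (not part of the patch) -----------------------------
# The segmentation integration test above swaps NiftiSaver.save_batch for the
# SaveImage transform applied to each decollated prediction. A condensed sketch
# of that pattern (tensor shapes and metadata values are made up for
# illustration):
import numpy as np
import torch
from monai.data import decollate_batch
from monai.transforms import SaveImage

saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg", dtype=np.float32)
batch_pred = torch.zeros(2, 1, 8, 8, 8)  # a batch of two single-channel predictions
batch_meta = {"filename_or_obj": ["a.nii.gz", "b.nii.gz"], "affine": [np.eye(4), np.eye(4)]}
for img, meta in zip(decollate_batch(batch_pred), decollate_batch(batch_meta)):
    saver(img, meta)  # one output file per decollated item
# ------------------------------------------------------------------------------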
-import logging import os import shutil -import sys import tempfile import unittest from glob import glob @@ -138,7 +136,6 @@ def setUp(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") monai.config.print_config() - logging.basicConfig(stream=sys.stdout, level=logging.INFO) def tearDown(self): set_determinism(seed=None) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 4455009658..c04e9b0cd7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -209,10 +209,15 @@ TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) -TESTS.append(("RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0))) +TESTS.append(("RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0, dtype=np.float64))) TESTS.append( - ("Rotated 2d", "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False)) + ( + "Rotated 2d", + "2D", + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64), + ) ) TESTS.append( @@ -220,7 +225,7 @@ "Rotated 3d", "3D", 1e-1, - Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64), ) ) @@ -229,7 +234,7 @@ "RandRotated 3d", "3D", 1e-1, - RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64), # type: ignore ) ) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index 3c293070bb..4e8c6b58cc 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -50,7 +50,7 @@ RandAxisFlipd(keys=KEYS, prob=0.5), Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), + RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64), RandAffined( keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), @@ -65,7 +65,7 @@ RandAxisFlipd(keys=KEYS, prob=0.5), Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), + RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64), RandAffined( keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ), diff --git a/tests/test_invertd.py b/tests/test_invertd.py index f28587eb6b..64c26c4012 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -58,7 +58,7 @@ def test_invert(self): RandAxisFlipd(KEYS, prob=0.5), RandRotate90d(KEYS, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), - RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), + RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64), RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), ResizeWithPadOrCropd(KEYS, 100), # test EnsureTensor for complicated dict data and invert it diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..163fead76e --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 
+1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch + +from monai.data import ITKWriter +from monai.utils import optional_import + +itk, has_itk = optional_import("itk") +nib, has_nibabel = optional_import("nibabel") + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +class TestITKWriter(unittest.TestCase): + def test_channel_shape(self): + with tempfile.TemporaryDirectory() as tempdir: + for c in (0, 1, 2, 3): + fname = os.path.join(tempdir, f"testing{c}.nii") + itk_writer = ITKWriter() + itk_writer.set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) + itk_writer.set_metadata({}) + itk_writer.write(fname) + itk_obj = itk.imread(fname) + s = [1, 2, 3, 4] + s.pop(c) + np.testing.assert_allclose(itk.size(itk_obj), s) + + def test_rgb(self): + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "testing.png") + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write(fname) + + output = np.asarray(itk.imread(fname)) + np.testing.assert_allclose(output.shape, (5, 5, 3)) + np.testing.assert_allclose(output[1, 1], (5, 5, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index b624e5c4e3..33f27ee4bc 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -57,7 +57,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_6 = [ @@ -66,7 +66,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_7 = [ @@ -75,7 +75,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024**2}}, ] diff --git a/tests/test_load_image.py b/tests/test_load_image.py index f215c925d8..2c8638ebbe 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -116,7 +116,11 @@ def get_data(self, _obj): TEST_CASE_13 = [{"reader": "nibabelreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)] -TEST_CASE_14 = [{"reader": "nibabelreader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)] +TEST_CASE_14 = [ + {"reader": "nibabelreader", "channel_dim": -1, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 128, 3), +] TEST_CASE_15 = [{"reader": "nibabelreader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)] @@ -124,7 
+128,11 @@ def get_data(self, _obj): TEST_CASE_17 = [{"reader": "ITKReader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)] -TEST_CASE_18 = [{"reader": "ITKReader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)] +TEST_CASE_18 = [ + {"reader": "ITKReader", "channel_dim": 2, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 3, 128), +] class TestLoadImage(unittest.TestCase): @@ -290,7 +298,9 @@ def test_channel_dim(self, input_param, filename, expected_shape): nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) result = LoadImage(**input_param)(filename) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertTupleEqual( + result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape + ) self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128)) self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"]) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index f4499311d3..bc001cf2fd 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -82,7 +82,7 @@ class TestConsistency(unittest.TestCase): def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() - xforms = Compose([LoadImaged(keys, reader=reader_1), EnsureChannelFirstD(keys)]) + xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 78c94d4e41..a76808be20 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -24,6 +24,7 @@ from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fails if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -61,14 +62,15 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - train_ds = MedNISTDataset( - root_dir=self.root_dir, - transform=self.transforms, - section="validation", - val_frac=0.001, - download=True, - num_workers=10, - ) + with skip_if_downloading_fails(): + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) num_classes = train_ds.get_num_classes() diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 1da3b73de2..e7cc1a60ff 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick MEDNIST_FULL_DATASET_LENGTH = 58954 @@ -43,16 +42,10 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) - try: # will start downloading if testing_dir 
doesn't have the MedNIST files + with skip_if_downloading_fails(): data = MedNISTDataset( root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 2cae5969db..cab051e781 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -13,7 +13,6 @@ import tempfile import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -22,7 +21,7 @@ from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from monai.apps.mmars import MODEL_DESC from monai.apps.mmars.mmars import _get_val -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]] TEST_EXTRACT_CASES = [ @@ -105,7 +104,7 @@ class TestMMMARDownload(unittest.TestCase): @parameterized.expand(TEST_CASES) @skip_if_quick def test_download(self, idx): - try: + with skip_if_downloading_fails(): # test model specification cand = get_model_spec(idx) self.assertEqual(cand[RemoteMMARKeys.ID], idx) @@ -116,22 +115,12 @@ def test_download(self, idx): download_mmar(idx, mmar_dir=tmp_dir, progress=False) download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return # skipping this test due the network connection errors @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick def test_load_ckpt(self, input_args, expected_name, expected_val): - try: + with skip_if_downloading_fails(): output = load_from_mmar(**input_args) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return self.assertEqual(output.__class__.__name__, expected_name) x = next(output.parameters()) # verify the first element np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3) diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 5ec4aa9ff1..ea520c59f3 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -10,6 +10,7 @@ # limitations under the License. 
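# --- illustrative sketch (not part of the patch) -----------------------------
# The LoadImage/LoadImaged tests above now pass `ensure_channel_first=True` to
# the loader instead of chaining EnsureChannelFirst(D). A minimal example,
# assuming a multi-channel NIfTI file exists at the hypothetical path below:
from monai.transforms import LoadImage

loader = LoadImage(reader="nibabelreader", channel_dim=-1, ensure_channel_first=True)
img, meta = loader("test_image.nii.gz")  # array comes back channel-first regardless of on-disk layout
# ------------------------------------------------------------------------------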
import glob +import inspect import os import unittest @@ -33,6 +34,24 @@ def test_public_api(self): mod.append(code_folder) self.assertEqual(sorted(monai.__all__), sorted(mod)) + def test_transform_api(self): + """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" + to_exclude = {"MapTransform"} # except for these transforms + xforms = { + name: obj + for name, obj in monai.transforms.__dict__.items() + if inspect.isclass(obj) and issubclass(obj, monai.transforms.MapTransform) + } + names = sorted(x for x in xforms if x not in to_exclude) + remained = set(names) + for n in names: + if not n.endswith("d"): + continue + basename = n[:-1] # Transformd basename is Transform + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 1322fa6a45..2c0a8dc9a3 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -17,7 +17,7 @@ import numpy as np from parameterized import parameterized -from monai.data import write_nifti +from monai.data import NibabelWriter from monai.transforms import LoadImage, Orientation, Spacing from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image @@ -94,10 +94,13 @@ def test_orientation(self, array, affine, reader_param, expected): os.remove(test_image) # write test cases + writer_obj = NibabelWriter() + writer_obj.set_data_array(data_array, channel_dim=None) if header is not None: - write_nifti(data_array, test_image, header["affine"], header.get("original_affine", None)) + writer_obj.set_metadata(header) elif affine is not None: - write_nifti(data_array, test_image, affine) + writer_obj.set_metadata({"affine": affine}) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_affine = saved.affine saved_data = saved.get_fdata() @@ -116,36 +119,35 @@ def test_consistency(self): data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image): os.remove(test_image) - write_nifti(data[0], test_image, new_affine, original_affine, mode="nearest", padding_mode="border") + writer_obj = NibabelWriter() + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata( + meta_dict={"affine": new_affine, "original_affine": original_affine}, mode="nearest", padding_mode="border" + ) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_data = saved.get_fdata() np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7) if os.path.exists(test_image): os.remove(test_image) - write_nifti( - data[0], - test_image, - new_affine, - original_affine, + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata( + meta_dict={"affine": new_affine, "original_affine": original_affine, "spatial_shape": (1, 8, 8)}, mode="nearest", padding_mode="border", - output_spatial_shape=(1, 8, 8), ) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) saved_data = saved.get_fdata() np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7) if os.path.exists(test_image): os.remove(test_image) - # test the case that only correct orientation but don't resample - write_nifti(data[0], test_image, new_affine, original_affine, resample=False) + # test the case no resample + writer_obj.set_data_array(data[0], channel_dim=None) + writer_obj.set_metadata(meta_dict={"affine": new_affine, "original_affine": 
original_affine}, resample=False) + writer_obj.write(test_image, verbose=True) saved = nib.load(test_image) - # compute expected affine - start_ornt = nib.orientations.io_orientation(new_affine) - target_ornt = nib.orientations.io_orientation(original_affine) - ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data[0].shape - expected_affine = new_affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) - np.testing.assert_allclose(saved.affine, expected_affine) + np.testing.assert_allclose(saved.affine, new_affine) if os.path.exists(test_image): os.remove(test_image) @@ -154,16 +156,21 @@ def test_write_2d(self): image_name = os.path.join(out_dir, "test.nii.gz") for p in TEST_NDARRAYS: img = p(np.arange(6).reshape((2, 3))) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1]), "original_affine": np.diag([1.4, 1, 1])}) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) image_name = os.path.join(out_dir, "test1.nii.gz") img = np.arange(5).reshape((1, 5)) - write_nifti( - img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5]) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 1, 3, 5])} ) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) @@ -173,16 +180,21 @@ def test_write_3d(self): image_name = os.path.join(out_dir, "test.nii.gz") for p in TEST_NDARRAYS: img = p(np.arange(6).reshape((1, 2, 3))) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 1]), "original_affine": np.diag([1.4, 1, 1, 1])}) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) image_name = os.path.join(out_dir, "test1.nii.gz") img = p(np.arange(5).reshape((1, 1, 5))) - write_nifti( - img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3, 5])} ) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) @@ -192,16 +204,21 @@ def test_write_4d(self): image_name = os.path.join(out_dir, "test.nii.gz") for p in TEST_NDARRAYS: img = p(np.arange(6).reshape((1, 1, 3, 2))) - write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.set_metadata({"affine": np.diag([1.4, 1, 1, 1]), "original_affine": np.diag([1, 1.4, 1, 1])}) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) 
np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) image_name = os.path.join(out_dir, "test1.nii.gz") img = p(np.arange(5).reshape((1, 1, 5, 1))) - write_nifti( - img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False) + writer_obj.set_metadata( + {"affine": np.diag([1, 1, 1, 3, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3, 5])} ) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) @@ -211,7 +228,10 @@ def test_write_5d(self): image_name = os.path.join(out_dir, "test.nii.gz") for p in TEST_NDARRAYS: img = p(np.arange(12).reshape((1, 1, 3, 2, 2))) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + writer_obj = NibabelWriter() + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 1]), "original_affine": np.diag([1.4, 1, 1, 1])}) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) np.testing.assert_allclose( out.get_fdata(), @@ -221,11 +241,11 @@ def test_write_5d(self): image_name = os.path.join(out_dir, "test1.nii.gz") img = p(np.arange(10).reshape((1, 1, 5, 1, 2))) - write_nifti( - img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5]) - ) + writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None) + writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3])}) + writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) + np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]])) np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py new file mode 100644 index 0000000000..4ed223bf5b --- /dev/null +++ b/tests/test_ori_ras_lps.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
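The test_nifti_rw.py hunks above migrate from `write_nifti` to the `NibabelWriter` three-step workflow; a condensed sketch of that workflow, assuming the API exercised by the updated tests (array values and file name are illustrative):

import numpy as np

from monai.data import NibabelWriter

data = np.arange(48).reshape((1, 4, 4, 3))  # channel-first array with one channel
writer = NibabelWriter()
writer.set_data_array(data, channel_dim=0)  # declare the channel axis (use None when there is no channel)
# "affine" describes the current array grid; "original_affine" is the target space to resample into
writer.set_metadata({"affine": np.eye(4), "original_affine": np.diag([1.4, 1.0, 1.0, 1.0])})
writer.write("example.nii.gz", verbose=True)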
+ + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.utils import orientation_ras_lps +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES_AFFINE = [] +for p in TEST_NDARRAYS: + case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_1d) + case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_2d_1) + case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( + [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] + ) + TEST_CASES_AFFINE.append(case_2d_2) + case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( + [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] + ) + TEST_CASES_AFFINE.append(case_3d) + case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) + TEST_CASES_AFFINE.append(case_4d) + + +class TestITKWriter(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE) + def test_ras_to_lps(self, param, expected): + assert_allclose(orientation_ras_lps(param), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index a070fc760f..530e5f86a3 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -42,12 +42,12 @@ PadListDataCollate(method="end", mode="constant"), ]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) - TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")]))) TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) - TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False, dtype=np.float64))) TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()]))) diff --git a/tests/test_parallel_execution_dist.py b/tests/test_parallel_execution_dist.py new file mode 100644 index 0000000000..f067b71d14 --- /dev/null +++ b/tests/test_parallel_execution_dist.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
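For reference, the `orientation_ras_lps` utility covered by `TEST_CASES_AFFINE` above negates the first two spatial rows of an affine, converting between the RAS and LPS conventions; a quick sketch using the 2-D case from those test data:

import numpy as np

from monai.data.utils import orientation_ras_lps

affine = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
print(orientation_ras_lps(affine))
# per the 2-D case above: [[-1., 0., -1.], [0., -1., -1.], [1., 1., 1.]] (first two rows negated)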
+ +import unittest + +import torch +import torch.distributed as dist + +from monai.engines import create_multigpu_supervised_trainer +from tests.utils import DistCall, DistTestCase, skip_if_no_cuda + + +def fake_loss(y_pred, y): + return (y_pred[0] + y).sum() + + +def fake_data_stream(): + while True: + yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64)) + + +class DistributedTestParallelExecution(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + @skip_if_no_cuda + def test_distributed(self): + device = torch.device(f"cuda:{dist.get_rank()}") + net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device) + opt = torch.optim.Adam(net.parameters(), 1e-3) + + trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device], distributed=True) + trainer.run(fake_data_stream(), 2, 2) + # assert the trainer output is loss value + self.assertTrue(isinstance(trainer.state.output, float)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index 84251b391f..47b5571ac0 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -16,7 +16,7 @@ import numpy as np from PIL import Image -from monai.data import write_png +from monai.data.image_writer import PILWriter class TestPngWrite(unittest.TestCase): @@ -25,7 +25,9 @@ def test_write_gray(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -35,7 +37,9 @@ def test_write_gray_1height(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(1, 3) img_save_val = (65535 * img).astype(np.uint16) - write_png(img, image_name, scale=65535) + writer_obj = PILWriter(output_dtype=np.uint16, scale=65535) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -45,17 +49,22 @@ def test_write_gray_1channel(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 1) img_save_val = (255 * img).astype(np.uint8).squeeze(2) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8, scale=255) + writer_obj.set_data_array(img, channel_dim=None) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) def test_write_rgb(self): + """testing default kwargs and obj_kwargs""" with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 3) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.write(image_name) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -65,7 +74,9 @@ def test_write_2channels(self): image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 2) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + 
writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) @@ -74,7 +85,10 @@ def test_write_output_shape(self): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 2, 3) - write_png(img, image_name, (4, 4), scale=255) + writer_obj = PILWriter(output_dtype=np.uint8) + writer_obj.set_data_array(img, channel_dim=-1) + writer_obj.set_metadata({"spatial_shape": (4, 4)}, scale=255) + writer_obj.write(image_name, format="PNG") out = np.asarray(Image.open(image_name)) np.testing.assert_allclose(out.shape, (4, 4, 3)) diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index d4eed3753f..39ce779cb0 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -88,7 +88,7 @@ def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-1, atol=1e-1) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 84bc765ab0..c78ed1f42e 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -145,7 +145,7 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-2, atol=1e-2) if __name__ == "__main__": diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index b453b01884..7a85fce23b 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -71,6 +71,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) rotated = rotate_fn(im_type(self.imt[0])) @@ -104,6 +105,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) rotated = rotate_fn(im_type(self.imt[0])) diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 23314720c1..464b37d925 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -116,6 +116,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) @@ -151,6 +152,7 @@ def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, a mode=mode, padding_mode=padding_mode, align_corners=align_corners, + dtype=np.float64, ) rotate_fn.set_random_state(243) rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 42947e7f72..01842f6d73 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -46,7 +46,7 @@ class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, 
padding_mode, align_corners): - rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) + rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners, dtype=np.float64) rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) @@ -74,7 +74,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) + rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners, dtype=np.float64) rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) @@ -100,7 +100,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al @parameterized.expand(TEST_CASES_SHAPE_3D) def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): - rotate_fn = Rotate(angle, True, align_corners=align_corners) + rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64) rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 1b759cfef5..43b5a68f61 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -40,7 +40,9 @@ class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) + rotate_fn = Rotated( + ("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) @@ -69,7 +71,9 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) + rotate_fn = Rotated( + ("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) @@ -98,7 +102,9 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): - rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) + rotate_fn = Rotated( + ("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64 + ) rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) diff --git a/tests/test_save_image.py 
b/tests/test_save_image.py index d3671cf830..a1297c1e61 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -13,6 +13,7 @@ import tempfile import unittest +import numpy as np import torch from parameterized import parameterized @@ -22,17 +23,25 @@ TEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False] +TEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"}, ".nrrd", False] + +TEST_CASE_4 = [ + np.random.randint(0, 255, (3, 2, 4, 5), dtype=np.uint8), + {"filename_or_obj": "testfile0.dcm"}, + ".dcm", + False, +] + class TestSaveImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( output_dir=tempdir, output_ext=output_ext, resample=resample, - # test saving into the same folder - separate_folder=False, + separate_folder=False, # test saving into the same folder ) trans(test_data, meta_data) diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 6f0bb4c2ba..a6988683e5 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -35,9 +35,19 @@ False, ] +TEST_CASE_3 = [ + { + "img": torch.randint(0, 255, (1, 2, 3, 4)), + PostFix.meta("img"): {"filename_or_obj": "testfile0.nrrd"}, + "patch_index": 6, + }, + ".nrrd", + False, +] + class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged( diff --git a/tests/test_save_state.py b/tests/test_save_state.py new file mode 100644 index 0000000000..c48b12ebdc --- /dev/null +++ b/tests/test_save_state.py @@ -0,0 +1,70 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
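The new `.nrrd`/`.dcm` cases above exercise `SaveImage` with `separate_folder=False`; a condensed sketch of that call pattern (the metadata value is illustrative and mirrors the test case):

import tempfile

import torch

from monai.transforms import SaveImage

with tempfile.TemporaryDirectory() as tempdir:
    trans = SaveImage(output_dir=tempdir, output_ext=".nrrd", resample=False, separate_folder=False)
    # the meta dict supplies the original file name used to derive the output file name
    trans(torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"})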
+ +import os +import tempfile +import unittest + +import torch +import torch.optim as optim +from parameterized import parameterized + +from monai.networks import save_state + +TEST_CASE_1 = [torch.nn.PReLU(), ["weight"]] + +TEST_CASE_2 = [{"net": torch.nn.PReLU()}, ["net"]] + +TEST_CASE_3 = [{"net": torch.nn.PReLU(), "opt": optim.SGD(torch.nn.PReLU().parameters(), lr=0.02)}, ["net", "opt"]] + +TEST_CASE_4 = [torch.nn.DataParallel(torch.nn.PReLU()), ["weight"]] + +TEST_CASE_5 = [{"net": torch.nn.DataParallel(torch.nn.PReLU())}, ["net"]] + +TEST_CASE_6 = [torch.nn.PReLU(), ["weight"], True, True, None, {"pickle_protocol": 2}] + +TEST_CASE_7 = [torch.nn.PReLU().state_dict(), ["weight"]] + +TEST_CASE_8 = [torch.nn.PReLU(), ["weight"], False] + +TEST_CASE_9 = [torch.nn.PReLU(), ["weight"], True, False] + +TEST_CASE_10 = [torch.nn.PReLU(), ["weight"], True, True, torch.save] + + +class TestSaveState(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + ] + ) + def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None, kwargs=None): + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "test_ckpt.pt") + if kwargs is None: + kwargs = {} + save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **kwargs) + ckpt = dict(torch.load(path)) + for k in ckpt.keys(): + self.assertIn(k, expected_keys) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index b81351ed2e..bd1adac4f4 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -58,7 +58,7 @@ def test_int(self): def test_channel_wise(self): for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=1.0, maxv=2.0, channel_wise=True) - data = p(self.imt) + data = p(np.tile(self.imt, (3, 1, 1, 1))) result = scaler(data) mina = self.imt.min() maxa = self.imt.max() diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 6ccfaf7ba6..f8656dd929 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -66,7 +66,7 @@ def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255) def test_channel_wise(self): - img = self.imt[0] + img = np.tile(self.imt, (3, 1, 1, 1)) lower = 10 upper = 99 b_min = 0 @@ -78,8 +78,8 @@ def test_channel_wise(self): for c in img: a_min = np.percentile(c, lower) a_max = np.percentile(c, upper) - expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min) - expected = np.stack(expected).astype(np.uint8) + expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)) + expected = np.stack(expected) for p in TEST_NDARRAYS: result = scaler(p(img)) diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index cd0a5f8c35..5441832a77 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -75,7 +75,7 @@ def test_invalid_instantiation(self): s(self.imt) def test_channel_wise(self): - img = self.imt + img = np.tile(self.imt, (3, 1, 1, 1)) lower = 10 upper = 99 b_min = 0 diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py index 
167add9556..e152ad2c2b 100644 --- a/tests/test_separable_filter.py +++ b/tests/test_separable_filter.py @@ -78,7 +78,7 @@ def test_3d(self): def test_wrong_args(self): with self.assertRaisesRegex(TypeError, ""): - separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) # type: ignore + separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,))) if __name__ == "__main__": diff --git a/tests/test_spacing.py b/tests/test_spacing.py index ebff25712d..80df981b73 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple from tests.utils import TEST_NDARRAYS @@ -78,7 +79,7 @@ TESTS.append( [ p, - {"pixdim": (1.0, 1.0)}, + {"pixdim": (1.0, 1.0), "align_corners": True}, np.arange(24).reshape((2, 3, 4)), # data {}, np.array( @@ -192,6 +193,15 @@ np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), ] ) + TESTS.append( # 5D input + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, "align_corners": True}, + np.ones((1, 2, 2, 2, 1)), # data + {"affine": np.eye(4)}, + np.ones((1, 2, 2, 3, 1)), + ] + ) class TestSpacingCase(unittest.TestCase): @@ -203,13 +213,13 @@ def test_spacing(self, in_type, init_param, img, data_param, expected_output): self.assertEqual(_img.device, output_data.device) output_data = output_data.cpu() - np.testing.assert_allclose(output_data, expected_output, atol=1e-3, rtol=1e-3) - sr = len(output_data.shape) - 1 + np.testing.assert_allclose(output_data, expected_output, atol=1e-1, rtol=1e-1) + sr = min(len(output_data.shape) - 1, 3) if isinstance(init_param["pixdim"], float): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr] + norm = affine_to_spacing(new_affine, sr) np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm) diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py new file mode 100644 index 0000000000..9ee84de85b --- /dev/null +++ b/tests/test_spatial_resample.py @@ -0,0 +1,146 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
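The updated spacing assertions above go through `affine_to_spacing`, which reads the voxel spacing out of an affine (the norms of its first `r` columns); a quick sketch with illustrative values:

import numpy as np

from monai.data.utils import affine_to_spacing

affine = np.diag([1.4, 2.0, 1.0, 1.0])
print(affine_to_spacing(affine, r=3))  # -> approximately [1.4, 2.0, 1.0]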
+ +import itertools +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResample +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src_affine": p_src(np.eye(3)), + "dst_affine": p(dst), + "dtype": np.float32, + "align_corners": align, + "mode": interp_mode, + "padding_mode": "zeros", + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) if ind == 0 else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p_src in TEST_NDARRAYS: + for align in (True, False): + if align and USE_COMPILED: + interp = ("nearest", "bilinear", 0, 1) + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src_affine": p_src(np.eye(4)), + "dst_affine": p_src(dst), + "dtype": np.float64, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]) + if ind == 0 + else np.array( + [[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips(self, p_type, args): + init_param, img, data_param, expected_output = args + _img = p_type(img) + _expected_output = p_type(expected_output) + output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) + assert_allclose(output_data, _expected_output, rtol=1e-2, atol=1e-2) + expected_dst = ( + data_param.get("dst_affine") if data_param.get("dst_affine") is not None else data_param.get("src_affine") + ) + assert_allclose(output_dst, expected_dst, type_test=False, rtol=1e-2, atol=1e-2) + + @parameterized.expand(itertools.product([True, False], TEST_NDARRAYS)) + def test_4d_5d(self, is_5d, p_type): + new_shape = (1, 2, 2, 3, 1, 1) if is_5d else (1, 2, 2, 3, 1) + img = np.arange(12).reshape(new_shape) + img = np.tile(img, (1, 1, 1, 1, 2, 2) if is_5d else (1, 1, 1, 1, 2)) + _img = p_type(img) + dst = np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) + output_data, output_dst = SpatialResample(dtype=np.float32)( + img=_img, src_affine=p_type(np.eye(4)), dst_affine=dst + ) + expected_data = ( + np.asarray( + [ + [ + [[[0.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [1.5, 1.0]], [[1.0, 2.0], [2.0, 2.0]]], + [[[3.0, 3.0], [3.0, 4.0]], [[3.5, 3.0], [4.5, 4.0]], [[4.0, 5.0], [5.0, 5.0]]], + ], + [ + [[[6.0, 6.0], [6.0, 7.0]], [[6.5, 6.0], [7.5, 7.0]], [[7.0, 8.0], [8.0, 8.0]]], + [[[9.0, 9.0], [9.0, 10.0]], [[9.5, 9.0], [10.5, 10.0]], [[10.0, 11.0], [11.0, 
11.0]]], + ], + ], + dtype=np.float32, + ) + if is_5d + else np.asarray( + [ + [[[0.5, 0.0], [0.0, 2.0], [1.5, 1.0]], [[3.5, 3.0], [3.0, 5.0], [4.5, 4.0]]], + [[[6.5, 6.0], [6.0, 8.0], [7.5, 7.0]], [[9.5, 9.0], [9.0, 11.0], [10.5, 10.0]]], + ], + dtype=np.float32, + ) + ) + assert_allclose(output_data, p_type(expected_data[None]), rtol=1e-2, atol=1e-2) + assert_allclose(output_dst, dst, type_test=False, rtol=1e-2, atol=1e-2) + + def test_ill_affine(self): + img = np.arange(12).reshape(1, 2, 2, 3) + ill_affine = np.asarray( + [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]] + ) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src_affine=np.eye(4), dst_affine=ill_affine) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src_affine=ill_affine, dst_affine=np.eye(3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py new file mode 100644 index 0000000000..73f83791d9 --- /dev/null +++ b/tests/test_spatial_resampled.py @@ -0,0 +1,113 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResampleD +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src": p_src(np.eye(3)), + "dst": p(dst), + "dtype": np.float32, + "align_corners": align, + "mode": interp_mode, + "padding_mode": "zeros", + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) if ind == 0 else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p_src in TEST_NDARRAYS: + for align in (True, False): + if align and USE_COMPILED: + interp = ("nearest", "bilinear", 0, 1) + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src": p_src(np.eye(4)), + "dst": p_src(dst), + "dtype": np.float64, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]) + if ind == 0 + else np.array( + 
[[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips_inverse(self, p_type, args): + _, img, data_param, expected_output = args + _img = p_type(img) + _expected_output = p_type(expected_output) + input_dict = {"img": _img, "img_meta_dict": {"src": data_param.get("src"), "dst": data_param.get("dst")}} + xform = SpatialResampleD( + keys="img", + meta_src_keys="src", + meta_dst_keys="dst", + mode=data_param.get("mode"), + padding_mode=data_param.get("padding_mode"), + align_corners=data_param.get("align_corners"), + ) + output_data = xform(input_dict) + assert_allclose(output_data["img"], _expected_output, rtol=1e-2, atol=1e-2) + assert_allclose( + output_data["img_meta_dict"]["src"], data_param.get("dst"), type_test=False, rtol=1e-2, atol=1e-2 + ) + + inverted = xform.inverse(output_data) + self.assertEqual(inverted["img_transforms"], []) # no further invert after inverting + assert_allclose(inverted["img_meta_dict"]["src"], data_param.get("src"), type_test=False, rtol=1e-2, atol=1e-2) + assert_allclose(inverted["img"], _img, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 9e22a5609f..21186adc3c 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -74,12 +74,13 @@ def tearDown(self) -> None: set_determinism(None) def test_test_time_augmentation(self): - input_size = (20, 20) - device = "cuda" if torch.cuda.is_available() else "cpu" + input_size = (20, 40) # test different input data shape to pad list collate keys = ["image", "label"] num_training_ims = 10 + train_data = self.get_data(num_training_ims, input_size) test_data = self.get_data(1, input_size) + device = "cuda" if torch.cuda.is_available() else "cpu" transforms = Compose( [ @@ -125,21 +126,28 @@ def test_test_time_augmentation(self): post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) - def inferrer_fn(x): - return post_trans(model(x)) - - tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=model, + device=device, + to_tensor=True, + output_device="cpu", + post_func=post_trans, + ) mode, mean, std, vvc = tt_aug(test_data) self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) self.assertTrue(all(np.unique(mode) == (0, 1))) - self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertGreaterEqual(mean.min(), 0.0) + self.assertLessEqual(mean.max(), 1.0) self.assertEqual(std.shape, (1,) + input_size) self.assertIsInstance(vvc, float) - def test_fail_non_random(self): + def test_warn_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) - with self.assertRaises(RuntimeError): + with self.assertWarns(UserWarning): TestTimeAugmentation(transforms, None, None, None) def test_warn_random_but_has_no_invertible(self): diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 08fb5d96fe..09434de5e0 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -105,9 +105,9 @@ def make_image( for y in range(tile_count): tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) - tiles = 
np.stack(tiles_list, axis=0) # type: ignore + tiles = np.stack(tiles_list, axis=0) - if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index 9f1d67ac29..c6f35fe738 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -114,9 +114,9 @@ def make_image( for y in range(tile_count): tiles_list.append(image[:, x * step : x * step + tile_size, y * step : y * step + tile_size]) - tiles = np.stack(tiles_list, axis=0) # type: ignore + tiles = np.stack(tiles_list, axis=0) - if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 4db3056b7b..b13378debe 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -13,11 +13,18 @@ import numpy as np import torch +from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import percentile +from monai.transforms.utils_pytorch_numpy_unification import mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose +TEST_MODE = [] +for p in TEST_NDARRAYS: + TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) + class TestPytorchNumpyUnification(unittest.TestCase): def setUp(self) -> None: @@ -54,6 +61,11 @@ def test_dim(self): atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 assert_allclose(results[0], results[-1], type_test=False, atol=atol) + @parameterized.expand(TEST_MODE) + def test_mode(self, array, expected, to_long): + res = mode(array, to_long=to_long) + assert_allclose(res, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vit.py b/tests/test_vit.py index 870e4010ec..d5ae209e50 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -16,6 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -27,7 +28,7 @@ for mlp_dim in [3072]: for num_layers in [4]: for num_classes in [8]: - for pos_embed in ["conv"]: + for pos_embed in ["conv", "perceptron"]: for classification in [False, True]: for nd in (2, 3): test_case = [ @@ -133,6 +134,17 @@ def test_ill_arg(self): dropout_rate=0.3, ) + @parameterized.expand(TEST_CASE_Vit) + @SkipIfBeforePyTorchVersion((1, 9)) + def test_script(self, input_param, input_shape, _): + net = ViT(**(input_param)) + net.eval() + with torch.no_grad(): + torch.jit.script(net) + + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 232c9c9030..0af019f0b0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,10 +21,11 @@ import traceback import unittest import warnings +from contextlib import contextmanager from functools import 
partial from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple -from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -93,17 +94,25 @@ def assert_allclose( np.testing.assert_allclose(actual, desired, *args, **kwargs) -def test_pretrained_networks(network, input_param, device): +@contextmanager +def skip_if_downloading_fails(): try: + yield + except (ContentTooShortError, HTTPError, ConnectionError) as e: + raise unittest.SkipTest(f"error while downloading: {e}") from e + except RuntimeError as rt_e: + if "network issue" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "gdown dependency" in str(rt_e): # no gdown installed + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "md5 check" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + raise rt_e + + +def test_pretrained_networks(network, input_param, device): + with skip_if_downloading_fails(): return network(**input_param).to(device) - except (URLError, HTTPError) as e: - raise unittest.SkipTest(e) from e - except RuntimeError as r_error: - if "unexpected EOF" in f"{r_error}": # The file might be corrupted. - raise unittest.SkipTest(f"{r_error}") from r_error - if "network issue" in f"{r_error}": # The network is not available. - raise unittest.SkipTest(f"{r_error}") from r_error - raise def test_is_quick(): @@ -651,16 +660,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): def download_url_or_skip_test(*args, **kwargs): """``download_url`` and skip the tests if any downloading error occurs.""" - try: + with skip_if_downloading_fails(): download_url(*args, **kwargs) - except (ContentTooShortError, HTTPError) as e: - raise unittest.SkipTest(f"error while downloading: {e}") from e - except RuntimeError as rt_e: - if "network issue" in str(rt_e): - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - if "gdown dependency" in str(rt_e): # no gdown installed - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - raise rt_e def query_memory(n=2): From 393887dd7276e0dfaa262b11916cec337d9f9c98 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 13:25:40 +0800 Subject: [PATCH 21/30] [DLMED] update reference resolver Signed-off-by: Nic Ma --- docs/source/apps.rst | 17 +- monai/apps/__init__.py | 20 +- monai/apps/manifest/__init__.py | 12 + monai/apps/manifest/config_item.py | 355 ++++++++++++++++++++ monai/apps/manifest/reference_resolver.py | 249 ++++++++++++++ monai/apps/mmars/__init__.py | 10 - monai/apps/mmars/config_item.py | 378 ---------------------- monai/apps/mmars/config_parser.py | 280 ---------------- monai/apps/mmars/mmars.py | 31 +- monai/apps/mmars/utils.py | 169 ---------- monai/utils/module.py | 8 +- tests/test_component_locator.py | 2 +- tests/test_config_item.py | 107 ++---- 13 files changed, 691 insertions(+), 947 deletions(-) create mode 100644 monai/apps/manifest/__init__.py create mode 100644 monai/apps/manifest/config_item.py create mode 100644 monai/apps/manifest/reference_resolver.py delete mode 100644 monai/apps/mmars/config_item.py delete mode 100644 monai/apps/mmars/config_parser.py delete mode 100644 monai/apps/mmars/utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index dea6c56718..d0fe131d85 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -29,16 +29,25 @@ Clara MMARs :annotation: 
-Model Package -------------- +Model Manifest +-------------- -.. autoclass:: ConfigParser +.. autoclass:: ComponentLocator :members: .. autoclass:: ConfigComponent :members: -.. autoclass:: ConfigResolver +.. autoclass:: ConfigExpression + :members: + +.. autoclass:: ConfigItem + :members: + +.. autoclass:: ConfigParser + :members: + +.. autoclass:: ReferenceResolver :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 2c3297d023..df085bddea 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,22 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .mmars import ( - MODEL_DESC, - ComponentLocator, - ConfigComponent, - ConfigItem, - ConfigParser, - ConfigResolver, - RemoteMMARKeys, - download_mmar, - find_refs_in_config, - get_model_spec, - is_expression, - is_instantiable, - load_from_mmar, - match_refs_pattern, - resolve_config_with_refs, - resolve_refs_pattern, -) +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py new file mode 100644 index 0000000000..d8919a5249 --- /dev/null +++ b/monai/apps/manifest/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py new file mode 100644 index 0000000000..e8f49f4b03 --- /dev/null +++ b/monai/apps/manifest/config_item.py @@ -0,0 +1,355 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import sys +import warnings +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +from monai.utils import ensure_tuple, instantiate + +__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] + + +class Instantiable(ABC): + """ + Base class for instantiable object with module name and arguments. + + .. 
code-block:: python
+
+        if not is_disabled():
+            instantiate(module_name=resolve_module_name(), args=resolve_args())
+
+    """
+
+    @abstractmethod
+    def resolve_module_name(self, *args: Any, **kwargs: Any):
+        """
+        Resolve the target module name; it should return an object class (or function) to be instantiated.
+        """
+        raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
+
+    @abstractmethod
+    def resolve_args(self, *args: Any, **kwargs: Any):
+        """
+        Resolve the arguments; it should return the arguments to be passed to the object when instantiating.
+        """
+        raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
+
+    @abstractmethod
+    def is_disabled(self, *args: Any, **kwargs: Any) -> bool:
+        """
+        Return a boolean flag to indicate whether the object should be instantiated.
+        """
+        raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
+
+    @abstractmethod
+    def instantiate(self, *args: Any, **kwargs: Any):
+        """
+        Instantiate the target component.
+        """
+        raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
+
+
+class ComponentLocator:
+    """
+    Scan all the available classes and functions in the MONAI package and map them to their module paths in a table.
+    It's used to locate the module path for a provided component name.
+
+    Args:
+        excludes: if any string of the `excludes` exists in the full module name, don't import this module.
+
+    """
+
+    MOD_START = "monai"
+
+    def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None):
+        self.excludes = [] if excludes is None else ensure_tuple(excludes)
+        self._components_table: Optional[Dict[str, List]] = None
+
+    def _find_module_names(self) -> List[str]:
+        """
+        Find all the modules that start with MOD_START and don't contain any of `excludes`.
+
+        """
+        return [
+            m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes)
+        ]
+
+    def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]:
+        """
+        Find all the classes and functions in the modules with the specified `modnames`.
+
+        Args:
+            modnames: names of the target modules to find all the classes and functions.
+
+        """
+        table: Dict[str, List] = {}
+        # all the MONAI modules are already loaded by `load_submodules`
+        for modname in ensure_tuple(modnames):
+            try:
+                # scan all the classes and functions in the module
+                module = import_module(modname)
+                for name, obj in inspect.getmembers(module):
+                    if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname:
+                        if name not in table:
+                            table[name] = []
+                        table[name].append(modname)
+            except ModuleNotFoundError:
+                pass
+        return table
+
+    def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]:
+        """
+        Get the full module name of the class or function with the specified ``name``.
+        If the target component name exists in multiple packages or modules, return a list of full module names.
+
+        Args:
+            name: name of the expected class or function.
+ + """ + if not isinstance(name, str): + raise ValueError(f"`name` must be a valid string, but got: {name}.") + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods + + +class ConfigItem: + """ + Basic data structure to represent a configuration item. + + A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. + It has a build-in `config` property to store the configuration object. + + Args: + config: content of a config item, can be objects of any types, + a configuration resolver may interpret the content to generate a configuration object. + id: optional name of the current config item, defaults to `None`. + + """ + + def __init__(self, config: Any, id: Optional[str] = None) -> None: + self.config = config + self.id = id + + def get_id(self) -> Optional[str]: + """ + Get the ID name of current config item, useful to identify config items during parsing. + + """ + return self.id + + def update_config(self, config: Any): + """ + Replace the content of `self.config` with new `config`. + A typical usage is to modify the initial config content at runtime. + + Args: + config: content of a `ConfigItem`. + + """ + self.config = config + + def get_config(self): + """ + Get the config content of current config item. + + """ + return self.config + + +class ConfigComponent(ConfigItem, Instantiable): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to + represent a component of `class` or `function` and supports instantiation. + + Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: + + - class or function identifier of the python module, specified by one of the two keys. + - ``""``: indicates build-in python classes or functions such as "LoadImageDict". + - ``""``: full module name, such as "monai.transforms.LoadImageDict". + - ``""``: input arguments to the python module. + - ``""``: a flag to indicate whether to skip the instantiation. + + .. code-block:: python + + locator = ComponentLocator(excludes=["modules_to_exclude"]) + config = { + "": "LoadImaged", + "": { + "keys": ["image", "label"] + } + } + + configer = ConfigComponent(config, id="test", locator=locator) + image_loader = configer.instantiate() + print(image_loader) # + + Args: + config: content of a config item. + id: optional name of the current config item, defaults to `None`. + locator: a ``ComponentLocator`` to convert a module name string into the actual python module. + if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. + excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. + See also: :py:class:`monai.apps.manifest.ComponentLocator`. + + """ + + def __init__( + self, + config: Any, + id: Optional[str] = None, + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + ) -> None: + super().__init__(config=config, id=id) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + + @staticmethod + def is_instantiable(config: Any) -> bool: + """ + Check whether this config represents a `class` or `function` that is to be instantiated. + + Args: + config: input config content to check. 
+ + """ + return isinstance(config, Mapping) and ("" in config or "" in config) + + def resolve_module_name(self): + """ + Resolve the target module name from current config content. + The config content must have ``""`` or ``""``. + When both are specified, ``""`` will be used. + + """ + config = dict(self.get_config()) + path = config.get("") + if path is not None: + if not isinstance(path, str): + raise ValueError(f"'' must be a string, but got: {path}.") + if "" in config: + warnings.warn(f"both '' and '', default to use '': {path}.") + return path + + name = config.get("") + if not isinstance(name, str): + raise ValueError("must provide a string for `` or `` of target component to instantiate.") + + module = self.locator.get_component_module_name(name) + if module is None: + raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set its module path in `` directly." + ) + module = module[0] + return f"{module}.{name}" + + def resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments from current config content. + + """ + return self.get_config().get("", {}) + + def is_disabled(self) -> bool: # type: ignore + """ + Utility function used in `instantiate()` to check whether to skip the instantiation. + + """ + _is_disabled = self.get_config().get("", False) + return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) + + def instantiate(self, **kwargs) -> object: # type: ignore + """ + Instantiate component based on ``self.config`` content. + The target component must be a `class` or a `function`, otherwise, return `None`. + + Args: + kwargs: args to override / add the config args when instantiation. + + """ + if not ConfigComponent.is_instantiable(self.get_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` + return None + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) + + +class ConfigExpression(ConfigItem): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression + (execute based on ``eval()``). + + See also: + + - https://docs.python.org/3/library/functions.html#eval. + + For example: + + .. code-block:: python + + import monai + from monai.apps.manifest import ConfigExpression + + config = "$monai.__version__" + expression = ConfigExpression(config, id="test", globals={"monai": monai}) + print(expression.execute()) + + Args: + config: content of a config item. + id: optional name of current config item, defaults to `None`. + globals: additional global context to evaluate the string. + + """ + + def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: + super().__init__(config=config, id=id) + self.globals = globals + + def execute(self, locals: Optional[Dict] = None): + """ + Excute current config content and return the result if it is expression, based on python `eval()`. + For more details: https://docs.python.org/3/library/functions.html#eval. + + Args: + locals: besides ``globals``, may also have some local symbols used in the expression at runtime. 
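+
+        For example, a minimal sketch of evaluating an expression with extra local symbols:
+
+        .. code-block:: python
+
+            expression = ConfigExpression(config="$a + b", globals={"a": 1})
+            print(expression.execute(locals={"b": 2}))  # prints 3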
+ + """ + value = self.get_config() + if not ConfigExpression.is_expression(value): + return None + return eval(value[1:], self.globals, locals) + + @staticmethod + def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an executable expression string. + Currently A string starts with ``"$"`` character is interpreted as an expression. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$") diff --git a/monai/apps/manifest/reference_resolver.py b/monai/apps/manifest/reference_resolver.py new file mode 100644 index 0000000000..3c7a36b1c7 --- /dev/null +++ b/monai/apps/manifest/reference_resolver.py @@ -0,0 +1,249 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Callable, Dict, List, Optional, Union +import warnings + +from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem + + +class ReferenceResolver: + """ + Utility class to resolve the references between config items. + + Args: + components: config components to resolve, if None, can also `add()` component in runtime. + + """ + + def __init__(self, items: Optional[Dict[str, ConfigItem]] = None): + self.items = {} if items is None else items + self.resolved_content = {} + + def add(self, item: ConfigItem): + """ + Add a config item to the resolution graph. + + Args: + item: a config item to resolve. + + """ + id = item.get_id() + if id in self.items: + warnings.warn(f"id '{id}' is already added.") + return + self.items[id] = item + + def resolve_one_item(self, id: str, waiting_list: Optional[List[str]] = None): + """ + Resolve one item with specified id name. + If has unresolved references, recursively resolve the references first. + + Args: + id: id name of expected item to resolve. + waiting_list: list of items wait to resolve references. it's used to detect circular references. + when resolving references like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. 
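+
+        A small illustration of the circular reference detection (hypothetical ids "A" and "B"):
+
+        .. code-block:: python
+
+            resolver = ReferenceResolver()
+            resolver.add(ConfigItem(config="@B", id="A"))
+            resolver.add(ConfigItem(config="@A", id="B"))
+            resolver.resolve_one_item(id="A")  # raises ValueError: detected circular references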
+ + """ + if waiting_list is None: + waiting_list = [] + waiting_list.append(id) + item = self.items.get(id) + item_config = item.get_config() + ref_ids = self.find_refs_in_config(config=item_config, id=id) + + # if current item has reference already in the waiting list, that's circular references + for d in ref_ids: + if d in waiting_list: + raise ValueError(f"detected circular references for id='{d}' in the config content.") + + if len(ref_ids) > 0: + # # check whether the component has any unresolved deps + for ref_id in ref_ids: + if ref_id not in self.resolved_content: + # this reffring component is not resolved + if ref_id not in self.items: + raise RuntimeError(f"the referring item `{ref_id}` is not defined in config.") + # resolve the reference first + self.resolve_one_item(id=ref_id, waiting_list=waiting_list) + + # all references are resolved + new_config = self.resolve_config_with_refs(config=item_config, id=id, refs=self.resolved_content) + item.update_config(config=new_config) + if ConfigComponent.is_instantiable(new_config): + self.resolved_content[id] = item.instantiate() + if ConfigExpression.is_expression(new_config): + self.resolved_content[id] = item.execute(locals={"refs": self.resolved_content}) + else: + self.resolved_content[id] = new_config + + def resolve_all(self): + """ + Resolve the references for all the config items. + + """ + for k in self.items.keys(): + self.resolve_one_item(id=k) + + def get_resolved_content(self, id: str): + """ + Get the resolved content with specified id name. + If not resolved, try to resolve it first. + + Args: + id: id name of the expected item. + + """ + if id not in self.resolved_content: + self.resolve_one_item(id=id) + return self.resolved_content(id) + + def get_resolved_config(self, id: str): + """ + Get the resolved config content with specified id name, then can be used for lazy instantiation. + If not resolved, try to resolve it first. + + Args: + id: id name of the expected config item. + + """ + if id not in self.resolved_content: + self.resolve_one_item(id=id) + return self.items(id) + + @staticmethod + def match_refs_pattern(value: str) -> List[str]: + """ + Match regular expression for the input string to find the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". + + Args: + value: input value to match regular expression. + + """ + refs: List[str] = [] + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + if ConfigExpression.is_expression(value) or value == item: + # only check when string starts with "$" or the whole content is "@XXX" + ref_obj_id = item[1:] + if ref_obj_id not in refs: + refs.append(ref_obj_id) + return refs + + @staticmethod + def resolve_refs_pattern(value: str, refs: Dict) -> str: + """ + Match regular expression for the input string to update content with the references. + The reference part starts with "@", like: "@XXX#YYY#ZZZ". + References dictionary must contain the referring IDs as keys. + + Args: + value: input value to match regular expression. + refs: all the referring components with ids as keys, default to `None`. 
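+
+        For example (an illustrative sketch with a hypothetical "dataset" id):
+
+        .. code-block:: python
+
+            # a whole-string reference is replaced by the resolved object itself
+            ReferenceResolver.resolve_refs_pattern("@dataset", refs={"dataset": [1, 2]})  # [1, 2]
+            # inside an expression, the reference is rewritten to index into ``refs`` for a later ``eval()``
+            ReferenceResolver.resolve_refs_pattern("$len(@dataset)", refs={"dataset": [1, 2]})  # "$len(refs['dataset'])"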
+ + """ + # regular expression pattern to match "@XXX" or "@XXX#YYY" + result = re.compile(r"@\w*[\#\w]*").findall(value) + for item in result: + ref_id = item[1:] + if ConfigExpression.is_expression(value): + # replace with local code and execute later + value = value.replace(item, f"refs['{ref_id}']") + elif value == item: + if ref_id not in refs: + raise KeyError(f"can not find expected ID '{ref_id}' in the references.") + value = refs[ref_id] + return value + + @staticmethod + def find_refs_in_config( + config: Union[Dict, List, str], + id: Optional[str] = None, + refs: Optional[List[str]] = None, + match_fn: Callable = match_refs_pattern, + ) -> List[str]: + """ + Recursively search all the content of input config item to get the ids of references. + References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: + `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config + is instantiable, treat it as reference because must instantiate it before resolving current config. + For `dict` and `list`, recursively check the sub-items. + + Args: + config: input config content to search. + id: ID name for the input config, default to `None`. + refs: list of the ID name of existing references, default to `None`. + match_fn: callable function to match config item for references, take `config` as parameter. + + """ + refs_: List[str] = [] if refs is None else refs + if isinstance(config, str): + refs_ += match_fn(value=config) + + if isinstance(config, list): + for i, v in enumerate(config): + sub_id = f"{id}#{i}" if id is not None else f"{i}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + refs_.append(sub_id) + refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_, match_fn) + if isinstance(config, dict): + for k, v in config.items(): + sub_id = f"{id}#{k}" if id is not None else f"{k}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + refs_.append(sub_id) + refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_, match_fn) + return refs_ + + @staticmethod + def resolve_config_with_refs( + config: Union[Dict, List, str], + id: Optional[str] = None, + refs: Optional[Dict] = None, + match_fn: Callable = resolve_refs_pattern, + ): + """ + With all the references in `refs`, resolve the config content with them and return new config. + + Args: + config: input config content to resolve. + id: ID name for the input config, default to `None`. + refs: all the referring components with ids, default to `None`. + match_fn: callable function to match config item for references, take `config`, + `refs` and `globals` as parameters. 
+ + """ + refs_: Dict = {} if refs is None else refs + if isinstance(config, str): + config = match_fn(config, refs, globals) + if isinstance(config, list): + # all the items in the list should be replaced with the references + ret_list: List = [] + for i, v in enumerate(config): + sub_id = f"{id}#{i}" if id is not None else f"{i}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + ret_list.append(refs_[sub_id]) + else: + ret_list.append(ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_, match_fn)) + return ret_list + if isinstance(config, dict): + # all the items in the dict should be replaced with the references + ret_dict: Dict = {} + for k, v in config.items(): + sub_id = f"{id}#{k}" if id is not None else f"{k}" + if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): + ret_dict[k] = refs_[sub_id] + else: + ret_dict[k] = ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_, match_fn) + return ret_dict + return config diff --git a/monai/apps/mmars/__init__.py b/monai/apps/mmars/__init__.py index 081c2e511d..8f1448bb06 100644 --- a/monai/apps/mmars/__init__.py +++ b/monai/apps/mmars/__init__.py @@ -9,15 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_item import ComponentLocator, ConfigComponent, ConfigItem -from .config_parser import ConfigParser, ConfigResolver from .mmars import download_mmar, get_model_spec, load_from_mmar from .model_desc import MODEL_DESC, RemoteMMARKeys -from .utils import ( - find_refs_in_config, - is_expression, - is_instantiable, - match_refs_pattern, - resolve_config_with_refs, - resolve_refs_pattern, -) diff --git a/monai/apps/mmars/config_item.py b/monai/apps/mmars/config_item.py deleted file mode 100644 index 536743c5e9..0000000000 --- a/monai/apps/mmars/config_item.py +++ /dev/null @@ -1,378 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import sys -import warnings -from abc import ABC, abstractmethod -from importlib import import_module -from typing import Any, Dict, List, Optional, Sequence, Union - -from monai.apps.mmars.utils import find_refs_in_config, is_instantiable, resolve_config_with_refs -from monai.utils import ensure_tuple, instantiate - -__all__ = ["ComponentLocator", "ConfigItem", "ConfigComponent"] - - -class ComponentLocator: - """ - Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. - It's used to locate the module path for provided component name. - - Args: - excludes: if any string of the `excludes` exists in the full module name, don't import this module. 
- - """ - - MOD_START = "monai" - - def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): - self.excludes = [] if excludes is None else ensure_tuple(excludes) - self._components_table: Optional[Dict[str, List]] = None - - def _find_module_names(self) -> List[str]: - """ - Find all the modules start with MOD_START and don't contain any of `excludes`. - - """ - return [ - m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) - ] - - def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: - """ - Find all the classes and functions in the modules with specified `modnames`. - - Args: - modnames: names of the target modules to find all the classes and functions. - - """ - table: Dict[str, List] = {} - # all the MONAI modules are already loaded by `load_submodules` - for modname in ensure_tuple(modnames): - try: - # scan all the classes and functions in the module - module = import_module(modname) - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: - if name not in table: - table[name] = [] - table[name].append(modname) - except ModuleNotFoundError: - pass - return table - - def get_component_module_name(self, name) -> Optional[Union[List[str], str]]: - """ - Get the full module name of the class / function with specified name. - If target component name exists in multiple packages or modules, return a list of full module names. - - Args: - name: name of the expected class or function. - - """ - if self._components_table is None: - # init component and module mapping table - self._components_table = self._find_classes_or_functions(self._find_module_names()) - - mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) - if isinstance(mods, list) and len(mods) == 1: - mods = mods[0] - return mods - - -class ConfigItem: - """ - Utility class to manage every item of the whole config content. - When recursively parsing a complicated config content, every item (like: dict, list, string, int, float, etc.) - can be treated as an "config item" then construct a `ConfigItem`. - For example, below are 5 config items when recursively parsing: - - a dict: `{"preprocessing": ["@transform1", "@transform2", "$lambda x: x"]}` - - a list: `["@transform1", "@transform2", "$lambda x: x"]` - - a string: `"transform1"` - - a string: `"transform2"` - - a string: `"$lambda x: x"` - - `ConfigItem` can set optional unique ID name, then another config item may refer to it. - The references mean the IDs of other config items used as "@XXX" in the current config item, for example: - config item with ID="A" is a list `[1, 2, 3]`, another config item can be `"args": {"input_list": "@A"}`. - If sub-item in the config is instantiable, also treat it as reference because must instantiate it before - resolving current config. - It can search the config content and find out all the references, and resolve the config content - when all the references are resolved. - - Here we predefined 3 kinds special marks (`#`, `@`, `$`) when parsing the whole config content: - - "XXX#YYY": join nested config IDs, like "transforms#5" is ID name of the 6th transform in a list ID="transforms". - - "@XXX": current config item refers to another config item XXX, like `{"args": {"data": "@dataset"}}` uses - resolved config content of `dataset` as the parameter "data". 
- - "$XXX": execute the string after "$" as python code with `eval()` function, like "$@model.parameters()". - - The typical usage of the APIs: - - Initialize with config content. - - If having references, get the IDs of its referring components. - - When all the referring components are resolved, resolve the config content with them, - and execute expressions in the config. - - .. code-block:: python - - config = {"lr": "$@epoch / 1000"} - - configer = ConfigComponent(config, id="test") - dep_ids = configer.get_id_of_refs() - configer.resolve_config(refs={"epoch": 10}) - lr = configer.get_resolved_config() - - Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - id: ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. - globals: to support executable string in the config, sometimes we need to provide the global variables - which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `{"collate_fn": "$monai.data.list_data_collate"}`. - - """ - - def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: - self.config = config - self.resolved_config = None - self.is_resolved = False - self.id = id - self.globals = globals - self.update_config(config=config) - - def get_id(self) -> Optional[str]: - """ - ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. - - """ - return self.id - - def update_config(self, config: Any): - """ - Update the config content for a config item at runtime. - If having references, need resolve the config later. - A typical usage is to modify the initial config content at runtime and set back. - - Args: - config: content of a config item, can be a `dict`, `list`, `string`, `float`, `int`, etc. - - """ - self.config = config - self.resolved_config = None - self.is_resolved = False - if not self.get_id_of_refs(): - # if no references, can resolve the config immediately - self.resolve(refs=None) - - def get_config(self): - """ - Get the initial config content of current config item, usually set at the constructor. - It can be useful to dynamically update the config content before resolving. - - """ - return self.config - - def get_id_of_refs(self) -> List[str]: - """ - Recursively search all the content of current config item to get the IDs of references. - It's used to detect and resolve all the references before resolving current config item. - For `dict` and `list`, recursively check the sub-items. - For example: `{"args": {"lr": "$@epoch / 1000"}}`, the reference IDs: `["epoch"]`. - - """ - return find_refs_in_config(self.config, id=self.id) - - def resolve(self, refs: Optional[Dict] = None): - """ - If all the references are resolved in `refs`, resolve the config content with them to construct `resolved_config`. - - Args: - refs: all the resolved referring items with ID as keys, default to `None`. 
- - """ - self.resolved_config = resolve_config_with_refs(self.config, id=self.id, refs=refs, globals=self.globals) - self.is_resolved = True - - def get_resolved_config(self): - """ - Get the resolved config content, constructed in `resolve_config()`. The returned config has no references, - then use it in the program, for example: initial config item `{"intervals": "@epoch / 10"}` and references - `{"epoch": 100}`, the resolved config will be `{"intervals": 10}`. - - """ - return self.resolved_config - - -class Instantiable(ABC): - """ - Base class for instantiable object with module name and arguments. - - """ - - @abstractmethod - def resolve_module_name(self, *args: Any, **kwargs: Any): - """ - Utility function to resolve the target module name. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def resolve_args(self, *args: Any, **kwargs: Any): - """ - Utility function to resolve the arguments. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def is_disabled(self, *args: Any, **kwargs: Any): - """ - Utility function to check whether the target component is disabled. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def instantiate(self, *args: Any, **kwargs: Any): - """ - Instantiate the target component. - - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - -class ConfigComponent(ConfigItem, Instantiable): - """ - Subclass of :py:class:`monai.apps.ConfigItem`, the config item represents a target component of class - or function, and support to instantiate the component. Example of config item: - `{"": "LoadImage", "": {"keys": "image"}}` - - Here we predefined 4 keys: ``, ``, ``, `` for component config: - - '' - class / function name in the modules of packages. - - '' - directly specify the module path, based on PYTHONPATH, ignore '' if specified. - - '' - arguments to initialize the component instance. - - '' - if defined `'': True`, will skip the buiding, useful for development or tuning. - - The typical usage of the APIs: - - Initialize with config content. - - If no references, `instantiate` the component if having "" or "" keywords and return the instance. - - If having references, get the IDs of its referring components. - - When all the referring components are resolved, resolve the config content with them, execute expressions in - the config and `instantiate`. - - .. code-block:: python - - locator = ComponentLocator(excludes=[""]) - config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} - - configer = ConfigComponent(config, id="test_config", locator=locator) - configer.resolve_config(refs={"dataset": Dataset(data=[1, 2])}) - configer.get_resolved_config() - dataloader: DataLoader = configer.instantiate() - - Args: - config: content of a component config item, should be a dict with `` or `` key. - id: ID name of current config item, useful to construct referring config items. - for example, config item A may have ID "transforms#A" and config item B refers to A - and uses the resolved config content of A as an arg `{"args": {"other": "@transforms#A"}}`. - `id` defaults to `None`, if some component refers to current component, `id` must be a `string`. - locator: `ComponentLocator` to help locate the module path of `` in the config and instantiate. 
- if `None`, will create a new `ComponentLocator` with specified `excludes`. - excludes: if `locator` is None, create a new `ComponentLocator` with `excludes`. any string of the `excludes` - exists in the module name, don't import this module. - globals: to support executable string in the config, sometimes we need to provide the global variables - which are referred in the executable string. for example: `globals={"monai": monai} will be useful - for config `"collate_fn": "$monai.data.list_data_collate"`. - - """ - - def __init__( - self, - config: Any, - id: Optional[str] = None, - locator: Optional[ComponentLocator] = None, - excludes: Optional[Union[Sequence[str], str]] = None, - globals: Optional[Dict] = None, - ) -> None: - super().__init__(config=config, id=id, globals=globals) - self.locator = ComponentLocator(excludes=excludes) if locator is None else locator - - def resolve_module_name(self): - """ - Utility function used in `instantiate()` to resolve the target module name from provided config content. - The config content must have `` or ``. - - """ - config = self.get_resolved_config() - path = config.get("", None) - if path is not None: - if "" in config: - warnings.warn(f"should not set both '' and '', default to use '': {path}.") - return path - - name = config.get("", None) - if name is None: - raise ValueError("must provide `` or `` of target component to instantiate.") - - module = self.locator.get_component_module_name(name) - if module is None: - raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") - if isinstance(module, list): - warnings.warn( - f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." - f" if want to use others, please set its module path in `` directly." - ) - module = module[0] - return f"{module}.{name}" - - def resolve_args(self): - """ - Utility function used in `instantiate()` to resolve the arguments from config content of target component. - - """ - return self.get_resolved_config().get("", {}) - - def is_disabled(self): - """ - Utility function used in `instantiate()` to check whether the target component is disabled. - - """ - return self.get_resolved_config().get("", False) - - def instantiate(self, **kwargs) -> object: # type: ignore - """ - Instantiate component based on the resolved config content. - The target component must be a `class` or a `function`, otherwise, return `None`. - - Args: - kwargs: args to override / add the config args when instantiation. - - """ - if not self.is_resolved: - warnings.warn( - "the config content of current component has not been resolved," - " please try to resolve the references first." - ) - return None - if not is_instantiable(self.get_resolved_config()) or self.is_disabled(): - # if not a class or function or marked as `disabled`, skip parsing and return `None` - return None - - modname = self.resolve_module_name() - args = self.resolve_args() - args.update(kwargs) - return instantiate(modname, **args) diff --git a/monai/apps/mmars/config_parser.py b/monai/apps/mmars/config_parser.py deleted file mode 100644 index 768a2c89b1..0000000000 --- a/monai/apps/mmars/config_parser.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -from typing import Any, Dict, List, Optional, Sequence, Union - -from monai.apps.mmars.config_item import ConfigComponent, ConfigItem, ComponentLocator -from monai.apps.mmars.utils import is_instantiable - -class ConfigResolver: - """ - Utility class to resolve the dependencies between config components and build instance for specified `id`. - - Args: - components: config components to resolve, if None, can also `add()` component in runtime. - - """ - - def __init__(self, components: Optional[Dict[str, ConfigComponent]] = None): - self.resolved_configs: Dict[str, str] = {} - self.resolved_components: Dict[str, Any] = {} - self.components = {} if components is None else components - - def add(self, component: ConfigComponent): - """ - Add a component to the resolution graph. - - Args: - component: a config component to resolve. - - """ - id = component.get_id() - if id in self.components: - raise ValueError(f"id '{id}' is already added.") - self.components[id] = component - - def _resolve_one_component(self, id: str, instantiate: bool = True, waiting_list: Optional[List[str]] = None): - """ - Resolve one component with specified id name. - If has unresolved dependencies, recursively resolve the dependencies first. - - Args: - id: id name of expected component to resolve. - instantiate: after resolving all the dependencies, whether to build instance. - if False, can support lazy instantiation with the resolved config later. - default to `True`. - waiting_list: list of components wait to resolve dependencies. it's used to detect circular dependencies - when resolving dependencies like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. - - """ - if waiting_list is None: - waiting_list = [] - waiting_list.append(id) - com = self.components[id] - dep_ids = com.get_dependent_ids() - # if current component has dependency already in the waiting list, that's circular dependencies - for d in dep_ids: - if d in waiting_list: - raise ValueError(f"detected circular dependencies for id='{d}' in the config content.") - - deps = {} - if len(dep_ids) > 0: - # # check whether the component has any unresolved deps - for comp_id in dep_ids: - if comp_id not in self.resolved_components: - # this dependent component is not resolved - if comp_id not in self.components: - raise RuntimeError(f"the dependent component `{comp_id}` is not in config.") - # resolve the dependency first - self._resolve_one_component(id=comp_id, instantiate=True, waiting_list=waiting_list) - deps[comp_id] = self.resolved_components[comp_id] - # all dependent components are resolved already - updated_config = com.get_updated_config(deps) - resolved_com = None - - if instantiate: - resolved_com = com.build(updated_config) - self.resolved_configs[id] = updated_config - self.resolved_components[id] = resolved_com - - return updated_config, resolved_com - - def resolve_all(self): - """ - Resolve all the components and build instances. 
- - """ - for k in self.components.keys(): - self._resolve_one_component(id=k, instantiate=True) - - def get_resolved_component(self, id: str): - """ - Get the resolved instance component with specified id name. - If not resolved, try to resolve it first. - - Args: - id: id name of the expected component. - - """ - if id not in self.resolved_components: - self._resolve_one_component(id=id, instantiate=True) - return self.resolved_components[id] - - def get_resolved_config(self, id: str): - """ - Get the resolved config component with specified id name, then can be used for lazy instantiation. - If not resolved, try to resolve it with `instantiation=False` first. - - Args: - id: id name of the expected config component. - - """ - if id not in self.resolved_configs: - config, _ = self._resolve_one_component(id=id, instantiate=False) - else: - config = self.resolved_configs[id] - return config - - -class ConfigParser: - """ - Parse a nested config and build components. - A typical usage is a config dictionary contains all the necessary components to define training workflow in JSON. - For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. - - Args: - excludes: if any string of the `excludes` exists in the full module name, don't import this module. - global_imports: pre-import packages as global variables to execute the python `eval` commands. - for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. - default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` - are MONAI mininum requirements. - config: config content to parse. - - """ - - def __init__( - self, - excludes: Optional[Union[Sequence[str], str]] = None, - global_imports: Optional[Dict[str, Any]] = None, - config: Optional[Any] = None, - ): - self.config = None - if config is not None: - self.set_config(config=config) - self.locator = ComponentLocator(excludes=excludes) - self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} - if global_imports is not None: - for k, v in global_imports.items(): - self.global_imports[k] = importlib.import_module(v) - self.config_resolver: ConfigResolver = ConfigResolver() - self.resolved = False - - def _get_last_config_and_key(self, config: Union[Dict, List], id: str): - """ - Utility to get the last config item and the id from the whole config content with nested id name. - - Args: - config: the whole config content. - id: nested id name to get the last item, joined by "#" mark, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - keys = id.split("#") - for k in keys[:-1]: - config = config[k] if isinstance(config, dict) else config[int(k)] - key = keys[-1] if isinstance(config, dict) else int(keys[-1]) - return config, key - - def set_config(self, config: Any, id: Optional[str] = None): - """ - Set config content for the parser, if `id` provided, `config` will used to replace the config item with `id`. - - Args: - config: target config content to set. - id: nested id name to specify the target position, joined by "#" mark, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. 
- - """ - if isinstance(id, str) and isinstance(self.config, (dict, list)): - conf_, key = self._get_last_config_and_key(config=self.config, id=id) - conf_[key] = config - else: - self.config = config - self.resolved = False - - def get_config(self, id: Optional[str] = None): - """ - Get config content from the parser, if `id` provided, get the config item with `id`. - - Args: - id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - if isinstance(id, str) and isinstance(self.config, (dict, list)): - conf_, key = self._get_last_config_and_key(config=self.config, id=id) - return conf_[key] - return self.config - - def _do_parse(self, config, id: Optional[str] = None): - """ - Recursively parse the nested config content, add every config item as component to the resolver. - For example, `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items: - - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]` - - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}` - - `id="preprocessing#0#", config="LoadImage"` - - `id="preprocessing#0#", config={"keys": "image"}` - - `id="preprocessing#0##keys", config="image"` - - Args: - config: config content to parse. - id: id name of current config item, nested ids are joined by "#" mark. defaults to None. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - if isinstance(config, dict): - for k, v in config.items(): - sub_id = k if id is None else f"{id}#{k}" - self._do_parse(config=v, id=sub_id) - if isinstance(config, list): - for i, v in enumerate(config): - sub_id = i if id is None else f"{id}#{i}" - self._do_parse(config=v, id=sub_id) - if id is not None: - if is_instantiable(config): - self.config_resolver.add( - ConfigComponent(id=id, config=config, locator=self.locator, globals=self.global_imports) - ) - else: - self.config_resolver.add(ConfigItem(id=id, config=config, globals=self.global_imports)) - - def parse_config(self, resolve_all: bool = False): - """ - Parse the config content, add every config item as component to the resolver. - - Args: - resolve_all: if True, resolve all the components and build instances directly. - - """ - self.config_resolver = ConfigResolver() - self._do_parse(config=self.config) - - if resolve_all: - self.config_resolver.resolve_all() - self.resolved = True - - def get_resolved_config(self, id: str): - """ - Get the resolved instance component, if not resolved, try to resolve it first. - - Args: - id: id name of expected config component, nested ids are joined by "#" mark. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - if self.config_resolver is None or not self.resolved: - self.parse_config() - return self.config_resolver.get_resolved_config(id=id) - - def get_resolved_component(self, id: str): - """ - Get the resolved config component, if not resolved, try to resolve it first. - It can be used to modify the config again and support lazy instantiation. - - Args: - id: id name of expected config component, nested ids are joined by "#" mark. - for example: "transforms#5", "transforms#5##keys", etc. 
- - """ - if self.config_resolver is None or not self.resolved: - self.parse_config() - return self.config_resolver.get_resolved_component(id=id) diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 801b826bd1..9d88c754d6 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -17,6 +17,7 @@ """ import json +import os import warnings from pathlib import Path from typing import Mapping, Optional, Union @@ -41,10 +42,9 @@ def get_model_spec(idx: Union[int, str]): if isinstance(idx, str): key = idx.strip().lower() for cand in MODEL_DESC: - if str(cand[Keys.ID]).strip().lower() == key: + if str(cand.get(Keys.ID)).strip().lower() == key: return cand - logger.info(f"Available specs are: {MODEL_DESC}.") - raise ValueError(f"Unknown MODEL_DESC request: {idx}") + return idx def _get_all_ngc_models(pattern, page_index=0, page_size=50): @@ -100,7 +100,7 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""): def download_mmar( - item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1 + item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = True, version: int = -1 ): """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. @@ -136,7 +136,7 @@ def download_mmar( raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") mmar_dir = Path(mmar_dir) if api: - model_dict = _get_all_ngc_models(item) + model_dict = _get_all_ngc_models(item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}") if len(model_dict) == 0: raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.") model_dir_list = [] @@ -155,11 +155,12 @@ def download_mmar( progress=progress, ) model_dir_list.append(model_dir) - return model_dir_list + if not model_dir_list: + raise ValueError(f"api query download no item for pattern {item}. Please change or shorten it.") + return model_dir_list[0] if not isinstance(item, Mapping): item = get_model_spec(item) - ver = item.get(Keys.VERSION, 1) if version > 0: ver = str(version) @@ -188,6 +189,8 @@ def load_from_mmar( pretrained=True, weights_only=False, model_key: str = "model", + api: bool = True, + model_file=None, ): """ Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train. @@ -203,6 +206,8 @@ def load_from_mmar( model_key: a key to search in the model file or config file for the model dictionary. Currently this function assumes that the model dictionary has `{"[name|path]": "test.module", "args": {'kw': 'test'}}`. + api: whether to query NGC API to get model infomation. + model_file: the relative path to the model file within an MMAR. 
Examples:: >>> from monai.apps import load_from_mmar @@ -212,11 +217,15 @@ def load_from_mmar( See Also: https://docs.nvidia.com/clara/ """ + if api: + item = {Keys.NAME: get_model_spec(item)[Keys.NAME] if isinstance(item, int) else f"{item}"} if not isinstance(item, Mapping): item = get_model_spec(item) - model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version) - model_file = model_dir / item[Keys.MODEL_FILE] - logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.') + model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api) + if model_file is None: + model_file = os.path.join("models", "model.pt") + model_file = model_dir / item.get(Keys.MODEL_FILE, model_file) + logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.') # loading with `torch.jit.load` if model_file.name.endswith(".ts"): @@ -235,7 +244,7 @@ def load_from_mmar( model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={}) if not model_config: # 2. search json CONFIG_FILE for model config spec. - json_path = model_dir / item.get(Keys.CONFIG_FILE, "config_train.json") + json_path = model_dir / item.get(Keys.CONFIG_FILE, os.path.join("config", "config_train.json")) with open(json_path) as f: conf_dict = json.load(f) conf_dict = dict(conf_dict) diff --git a/monai/apps/mmars/utils.py b/monai/apps/mmars/utils.py deleted file mode 100644 index bc7b516886..0000000000 --- a/monai/apps/mmars/utils.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Callable, Dict, List, Optional, Union - - -def match_refs_pattern(value: str) -> List[str]: - """ - Match regular expression for the input string to find the references. - The reference part starts with "@", like: "@XXX#YYY#ZZZ". - - Args: - value: input value to match regular expression. - - """ - refs: List[str] = [] - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - if is_expression(value) or value == item: - # only check when string starts with "$" or the whole content is "@XXX" - ref_obj_id = item[1:] - if ref_obj_id not in refs: - refs.append(ref_obj_id) - return refs - - -def resolve_refs_pattern(value: str, refs: Dict, globals: Optional[Dict] = None) -> str: - """ - Match regular expression for the input string to update content with the references. - The reference part starts with "@", like: "@XXX#YYY#ZZZ". - References dictionary must contain the referring IDs as keys. - - Args: - value: input value to match regular expression. - refs: all the referring components with ids as keys, default to `None`. - globals: predefined global variables to execute code string with `eval()`. 
- - """ - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - ref_id = item[1:] - if is_expression(value): - # replace with local code and execute later - value = value.replace(item, f"refs['{ref_id}']") - elif value == item: - if ref_id not in refs: - raise KeyError(f"can not find expected ID '{ref_id}' in the references.") - value = refs[ref_id] - if is_expression(value): - # execute the code string with python `eval()` - value = eval(value[1:], globals, {"refs": refs}) - return value - - -def find_refs_in_config( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[List[str]] = None, - match_fn: Callable = match_refs_pattern, -) -> List[str]: - """ - Recursively search all the content of input config item to get the ids of references. - References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: - `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config - is instantiable, treat it as reference because must instantiate it before resolving current config. - For `dict` and `list`, recursively check the sub-items. - - Args: - config: input config content to search. - id: ID name for the input config, default to `None`. - refs: list of the ID name of existing references, default to `None`. - match_fn: callable function to match config item for references, take `config` as parameter. - - """ - refs_: List[str] = [] if refs is None else refs - if isinstance(config, str): - refs_ += match_fn(value=config) - - if isinstance(config, list): - for i, v in enumerate(config): - sub_id = f"{id}#{i}" if id is not None else f"{i}" - if is_instantiable(v): - refs_.append(sub_id) - refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) - if isinstance(config, dict): - for k, v in config.items(): - sub_id = f"{id}#{k}" if id is not None else f"{k}" - if is_instantiable(v): - refs_.append(sub_id) - refs_ = find_refs_in_config(v, sub_id, refs_, match_fn) - return refs_ - - -def resolve_config_with_refs( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[Dict] = None, - globals: Optional[Dict] = None, - match_fn: Callable = resolve_refs_pattern, -): - """ - With all the references in `refs`, resolve the config content with them and return new config. - - Args: - config: input config content to resolve. - id: ID name for the input config, default to `None`. - refs: all the referring components with ids, default to `None`. - globals: predefined global variables to execute code string with `eval()`. - match_fn: callable function to match config item for references, take `config`, - `refs` and `globals` as parameters. 
- - """ - refs_: Dict = {} if refs is None else refs - if isinstance(config, str): - config = match_fn(config, refs, globals) - if isinstance(config, list): - # all the items in the list should be replaced with the references - ret_list: List = [] - for i, v in enumerate(config): - sub = f"{id}#{i}" if id is not None else f"{i}" - ret_list.append( - refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn) - ) - return ret_list - if isinstance(config, dict): - # all the items in the dict should be replaced with the references - ret_dict: Dict = {} - for k, v in config.items(): - sub = f"{id}#{k}" if id is not None else f"{k}" - ret_dict.update( - {k: refs_[sub] if is_instantiable(v) else resolve_config_with_refs(v, sub, refs_, globals, match_fn)} - ) - return ret_dict - return config - - -def is_instantiable(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config represents a `class` or `function` to instantiate - with specified "" or "". - - Args: - config: input config content to check. - - """ - return isinstance(config, dict) and ("" in config or "" in config) - - -def is_expression(config: Union[Dict, List, str]) -> bool: - """ - Check whether the content of the config is executable expression string. - If True, the string should start with "$" mark. - - Args: - config: input config content to check. - - """ - return isinstance(config, str) and config.startswith("$") diff --git a/monai/utils/module.py b/monai/utils/module.py index 1dcbe6849f..8b7745c3ee 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -198,8 +198,8 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[ def instantiate(path: str, **kwargs): """ - Method for creating an instance for the specified class / function path. - kwargs will be class args or default args for `partial` function. + Create an object instance or partial function from a class or function represented by string. + `kwargs` will be part of the input arguments to the class constructor or function. The target component must be a class or a function, if not, return the component directly. Args: @@ -210,12 +210,14 @@ def instantiate(path: str, **kwargs): """ component = locate(path) + if component is None: + raise ModuleNotFoundError(f"Cannot locate '{path}'.") if inspect.isclass(component): return component(**kwargs) if inspect.isfunction(component): return partial(component, **kwargs) - warnings.warn(f"target component must be a valid class or function, but got {path}.") + warnings.warn(f"Component to instantiate must represent a valid class or function, but got {path}.") return component diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py index adff25457a..eafb2152d1 100644 --- a/tests/test_component_locator.py +++ b/tests/test_component_locator.py @@ -12,7 +12,7 @@ import unittest from pydoc import locate -from monai.apps.mmars import ComponentLocator +from monai.apps.manifest import ComponentLocator from monai.utils import optional_import _, has_ignite = optional_import("ignite") diff --git a/tests/test_config_item.py b/tests/test_config_item.py index af6a02802b..108e5d7aa6 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -1,4 +1,4 @@ -# Copyright 2020 - 2021 MONAI Consortium +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -11,80 +11,59 @@ import unittest from functools import partial -from typing import Callable, Iterator +from typing import Callable import torch from parameterized import parameterized import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigItem, is_instantiable +from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.data import DataLoader, Dataset from monai.transforms import LoadImaged, RandTorchVisiond from monai.utils import optional_import _, has_tv = optional_import("torchvision") -TEST_CASE_1 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +TEST_CASE_1 = [{"lr": 0.001}, 0.0001] + +TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] # test python `` -TEST_CASE_2 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] +# test `` +TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] # test `` -TEST_CASE_3 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test unresolved reference -TEST_CASE_4 = [{"": "LoadImaged", "": {"keys": ["@key_name"]}}] +TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] # test non-monai modules and excludes -TEST_CASE_5 = [ +TEST_CASE_6 = [ {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, torch.optim.Adam, ] -TEST_CASE_6 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] +TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] # test args contains "name" field -TEST_CASE_7 = [ +TEST_CASE_8 = [ {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, RandTorchVisiond, ] -# test references of dict config -TEST_CASE_8 = [{"dataset": "@dataset", "batch_size": 2}, ["dataset"]] -# test references of list config -TEST_CASE_9 = [{"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]}, ["dataset", "trans0", "trans1"]] -# test references of execute code -TEST_CASE_10 = [ - {"dataset": "$@dataset.test_func()", "transforms": ["$torch.zeros([2, 2]) + @trans"]}, - ["dataset", "trans"], -] -# test references of lambda function -TEST_CASE_11 = [ - {"lr_range": "$lambda x: x + @num_epochs", "lr": ["$lambda x: torch.zeros([2, 2]) + @init_lr"]}, - ["num_epochs", "init_lr"], -] -# test instance with no references -TEST_CASE_12 = ["transform#1", {"": "LoadImaged", "": {"keys": ["image"]}}, {}, LoadImaged] -# test dataloader refers to `@dataset`, here we don't test recursive references, test that in `ConfigResolver` -TEST_CASE_13 = [ - "dataloader", - {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}}, - {"dataset": Dataset(data=[1, 2])}, - DataLoader, -] -# test references in code execution -TEST_CASE_14 = [ - "optimizer", - {"": "torch.optim.Adam", "": {"params": "$@model.parameters()", "lr": "@learning_rate"}}, - {"model": torch.nn.PReLU(), "learning_rate": 1e-4}, - torch.optim.Adam, -] -# test replace references with code execution result -TEST_CASE_15 = ["optimizer##params", "$@model.parameters()", {"model": torch.nn.PReLU()}, Iterator] # test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_16 = ["dataloader##collate_fn", "$monai.data.list_data_collate", {}, Callable] -# test lambda function, should not execute the lambda function, just change the string with referring objects -TEST_CASE_17 = 
["dataloader##collate_fn", "$lambda x: monai.data.list_data_collate(x) + 100", {}, Callable] +TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] +# test lambda function, should not execute the lambda function, just change the string +TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] class TestConfigItem(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_item(self, test_input, expected): + item = ConfigItem(config=test_input) + conf = item.get_config() + conf["lr"] = 0.0001 + item.update_config(config=conf) + self.assertEqual(item.get_config()["lr"], expected) + @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_5, TEST_CASE_6] + ([TEST_CASE_7] if has_tv else []) + [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] + + ([TEST_CASE_8] if has_tv else []) ) - def test_instantiate(self, test_input, output_type): + def test_component(self, test_input, output_type): locator = ComponentLocator(excludes=["metrics"]) configer = ConfigComponent(id="test", config=test_input, locator=locator) ret = configer.instantiate() @@ -96,39 +75,21 @@ def test_instantiate(self, test_input, output_type): if isinstance(ret, LoadImaged): self.assertEqual(ret.keys[0], "image") - @parameterized.expand([TEST_CASE_4]) - def test_raise_error(self, test_input): - with self.assertRaises(KeyError): # has unresolved keys - configer = ConfigItem(id="test", config=test_input) - configer.resolve() - - @parameterized.expand([TEST_CASE_8, TEST_CASE_9, TEST_CASE_10, TEST_CASE_11]) - def test_referring_ids(self, test_input, ref_ids): - configer = ConfigItem(id="test", config=test_input) # also test default locator - ret = configer.get_id_of_refs() - self.assertListEqual(ret, ref_ids) - - @parameterized.expand([TEST_CASE_12, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17]) - def test_resolve_references(self, id, test_input, refs, output_type): - configer = ConfigComponent( - id=id, config=test_input, locator=None, excludes=["utils"], globals={"monai": monai, "torch": torch} - ) - configer.resolve(refs=refs) - ret = configer.get_resolved_config() - if is_instantiable(ret): - ret = configer.instantiate(**{}) # also test kwargs - self.assertTrue(isinstance(ret, output_type)) + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) + def test_expression(self, id, test_input): + configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) + var = 100 + ret = configer.execute(locals={"var": var}) + self.assertTrue(isinstance(ret, Callable)) def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": "@dataset", "batch_size": 2}} - refs = {"dataset": Dataset(data=[1, 2])} + config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} configer = ConfigComponent(config=config, locator=None) init_config = configer.get_config() # modify config content at runtime init_config[""]["batch_size"] = 4 configer.update_config(config=init_config) - configer.resolve(refs=refs) ret = configer.instantiate() self.assertTrue(isinstance(ret, DataLoader)) self.assertEqual(ret.batch_size, 4) From b40f3e5337cefb334852c2d6e5c21bfaddb1b7d2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 14:50:30 +0800 Subject: [PATCH 22/30] [DLMED] update file names Signed-off-by: Nic Ma --- monai/apps/__init__.py | 2 +- monai/apps/manifest/__init__.py | 1 + monai/apps/manifest/reference_resolver.py | 4 ++-- tests/{test_config_resolver.py => 
test_reference_resolver.py} | 0 4 files changed, 4 insertions(+), 3 deletions(-) rename tests/{test_config_resolver.py => test_reference_resolver.py} (100%) diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index df085bddea..0f233bc3ef 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ReferenceResolver from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index d8919a5249..79c4376d5c 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -10,3 +10,4 @@ # limitations under the License. from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .reference_resolver import ReferenceResolver diff --git a/monai/apps/manifest/reference_resolver.py b/monai/apps/manifest/reference_resolver.py index 3c7a36b1c7..8a24fa24b9 100644 --- a/monai/apps/manifest/reference_resolver.py +++ b/monai/apps/manifest/reference_resolver.py @@ -79,9 +79,9 @@ def resolve_one_item(self, id: str, waiting_list: Optional[List[str]] = None): # all references are resolved new_config = self.resolve_config_with_refs(config=item_config, id=id, refs=self.resolved_content) item.update_config(config=new_config) - if ConfigComponent.is_instantiable(new_config): + if isinstance(item, ConfigComponent): self.resolved_content[id] = item.instantiate() - if ConfigExpression.is_expression(new_config): + if isinstance(item, ConfigExpression): self.resolved_content[id] = item.execute(locals={"refs": self.resolved_content}) else: self.resolved_content[id] = new_config diff --git a/tests/test_config_resolver.py b/tests/test_reference_resolver.py similarity index 100% rename from tests/test_config_resolver.py rename to tests/test_reference_resolver.py From 0198f9bb5babdf22b74ee869bc624cd99db9bcbc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Feb 2022 16:10:28 +0800 Subject: [PATCH 23/30] [DLMED] fix ReferenceResolver Signed-off-by: Nic Ma --- monai/apps/manifest/reference_resolver.py | 33 ++++++------- tests/test_reference_resolver.py | 58 ++++++++++++++--------- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/monai/apps/manifest/reference_resolver.py b/monai/apps/manifest/reference_resolver.py index 8a24fa24b9..1988f35535 100644 --- a/monai/apps/manifest/reference_resolver.py +++ b/monai/apps/manifest/reference_resolver.py @@ -10,7 +10,7 @@ # limitations under the License. 
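For context on the `ConfigExpression` items handled by the resolver change above (and exercised by TEST_CASE_9/10 in the tests earlier): an expression keeps its `$...` string as-is and is only evaluated on demand, with the pre-imported globals plus any runtime locals. A rough usage sketch, assuming the class and method names used in these tests rather than any final API:

    import monai
    import torch
    from monai.apps.manifest.config_item import ConfigExpression  # module path as used in this patch series

    # "$" marks the string as python code to be evaluated lazily
    expr = ConfigExpression(
        config="$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)",
        id="collate_fn",
        globals={"monai": monai, "torch": torch},
    )
    collate = expr.execute(locals={"var": 100})  # returns the lambda itself, not its result
    assert callable(collate)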
import re -from typing import Callable, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import warnings from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem @@ -81,7 +81,7 @@ def resolve_one_item(self, id: str, waiting_list: Optional[List[str]] = None): item.update_config(config=new_config) if isinstance(item, ConfigComponent): self.resolved_content[id] = item.instantiate() - if isinstance(item, ConfigExpression): + elif isinstance(item, ConfigExpression): self.resolved_content[id] = item.execute(locals={"refs": self.resolved_content}) else: self.resolved_content[id] = new_config @@ -105,20 +105,20 @@ def get_resolved_content(self, id: str): """ if id not in self.resolved_content: self.resolve_one_item(id=id) - return self.resolved_content(id) + return self.resolved_content.get(id) - def get_resolved_config(self, id: str): + def get_item(self, id: str, resolve: bool = False): """ - Get the resolved config content with specified id name, then can be used for lazy instantiation. - If not resolved, try to resolve it first. + Get the config item with specified id name, then can be used for lazy instantiation. + If `resolve=True`, try to resolve it first. Args: id: id name of the expected config item. """ - if id not in self.resolved_content: + if resolve and id not in self.resolved_content: self.resolve_one_item(id=id) - return self.items(id) + return self.items.get(id) @staticmethod def match_refs_pattern(value: str) -> List[str]: @@ -171,7 +171,6 @@ def find_refs_in_config( config: Union[Dict, List, str], id: Optional[str] = None, refs: Optional[List[str]] = None, - match_fn: Callable = match_refs_pattern, ) -> List[str]: """ Recursively search all the content of input config item to get the ids of references. @@ -184,25 +183,24 @@ def find_refs_in_config( config: input config content to search. id: ID name for the input config, default to `None`. refs: list of the ID name of existing references, default to `None`. - match_fn: callable function to match config item for references, take `config` as parameter. """ refs_: List[str] = [] if refs is None else refs if isinstance(config, str): - refs_ += match_fn(value=config) + refs_ += ReferenceResolver.match_refs_pattern(value=config) if isinstance(config, list): for i, v in enumerate(config): sub_id = f"{id}#{i}" if id is not None else f"{i}" if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): refs_.append(sub_id) - refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_, match_fn) + refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) if isinstance(config, dict): for k, v in config.items(): sub_id = f"{id}#{k}" if id is not None else f"{k}" if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): refs_.append(sub_id) - refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_, match_fn) + refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) return refs_ @staticmethod @@ -210,7 +208,6 @@ def resolve_config_with_refs( config: Union[Dict, List, str], id: Optional[str] = None, refs: Optional[Dict] = None, - match_fn: Callable = resolve_refs_pattern, ): """ With all the references in `refs`, resolve the config content with them and return new config. @@ -219,13 +216,11 @@ def resolve_config_with_refs( config: input config content to resolve. id: ID name for the input config, default to `None`. refs: all the referring components with ids, default to `None`. 
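For intuition, the reference scan above amounts to walking the config content and collecting every id mentioned via `@...` (plus the ids of nested instantiable or expression items). A simplified, self-contained sketch of that idea; the actual pattern matching lives in `match_refs_pattern` and may differ:

    import re
    from typing import Dict, List, Union

    def find_refs(config: Union[Dict, List, str], id: str = "", refs: List[str] = None) -> List[str]:
        # collect ids referred to by "@xxx" tokens, recursing into dicts and lists
        refs = [] if refs is None else refs
        if isinstance(config, str):
            refs += re.findall(r"@([\w#]+)", config)  # assumed pattern, for illustration only
        elif isinstance(config, (list, dict)):
            items = enumerate(config) if isinstance(config, list) else config.items()
            for k, v in items:
                refs = find_refs(v, f"{id}#{k}" if id else f"{k}", refs)
        return refs

    # find_refs({"dataset": "@dataset", "transforms": ["@trans0", "@trans1"]})
    # is expected to give ["dataset", "trans0", "trans1"]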
- match_fn: callable function to match config item for references, take `config`, - `refs` and `globals` as parameters. """ refs_: Dict = {} if refs is None else refs if isinstance(config, str): - config = match_fn(config, refs, globals) + config = ReferenceResolver.resolve_refs_pattern(config, refs) if isinstance(config, list): # all the items in the list should be replaced with the references ret_list: List = [] @@ -234,7 +229,7 @@ def resolve_config_with_refs( if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): ret_list.append(refs_[sub_id]) else: - ret_list.append(ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_, match_fn)) + ret_list.append(ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_)) return ret_list if isinstance(config, dict): # all the items in the dict should be replaced with the references @@ -244,6 +239,6 @@ def resolve_config_with_refs( if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): ret_dict[k] = refs_[sub_id] else: - ret_dict[k] = ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_, match_fn) + ret_dict[k] = ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_) return ret_dict return config diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py index 95200d0b81..f6a8dc416c 100644 --- a/tests/test_reference_resolver.py +++ b/tests/test_reference_resolver.py @@ -10,15 +10,16 @@ # limitations under the License. import unittest +from monai.apps.manifest.config_item import ComponentLocator, ConfigExpression, ConfigItem import torch from parameterized import parameterized import monai -from monai.apps import ConfigComponent, ConfigResolver +from monai.apps import ConfigComponent, ReferenceResolver from monai.data import DataLoader from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import ClassScanner, optional_import +from monai.utils import optional_import _, has_tv = optional_import("torchvision") @@ -38,7 +39,7 @@ # test depends on other component and executable code TEST_CASE_2 = [ { - # all the recursively parsed config items + # some the recursively parsed config items "dataloader": { "": "DataLoader", "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, @@ -76,38 +77,49 @@ ] -class TestConfigComponent(unittest.TestCase): +class TestReferenceResolver(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) def test_resolve(self, configs, expected_id, output_type): - scanner = ClassScanner(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) - resolver = ConfigResolver() + locator = ComponentLocator() + resolver = ReferenceResolver() + # add items to resolver for k, v in configs.items(): - resolver.add( - ConfigComponent(id=k, config=v, class_scanner=scanner, globals={"monai": monai, "torch": torch}) - ) - ins = resolver.get_resolved_component(expected_id) - self.assertTrue(isinstance(ins, output_type)) + if ConfigComponent.is_instantiable(v): + resolver.add(ConfigComponent(config=v, id=k, locator=locator)) + elif ConfigExpression.is_expression(v): + resolver.add(ConfigExpression(config=v, id=k, globals={"monai": monai, "torch": torch})) + else: + resolver.add(ConfigItem(config=v, id=k)) + + result = resolver.get_resolved_content(expected_id) + self.assertTrue(isinstance(result, output_type)) + # test resolve all - resolver.resolved_configs = {} - resolver.resolved_components = {} + resolver.resolved_content = {} resolver.resolve_all() - ins = 
resolver.get_resolved_component(expected_id) - self.assertTrue(isinstance(ins, output_type)) + result = resolver.get_resolved_content(expected_id) + self.assertTrue(isinstance(result, output_type)) + # test lazy instantiation - config = resolver.get_resolved_config(expected_id) + item = resolver.get_item(expected_id, resolve=True) + config = item.get_config() config[""] = False - ins = ConfigComponent(id=expected_id, class_scanner=scanner, config=config).build() - self.assertTrue(isinstance(ins, output_type)) + item.update_config(config=config) + if isinstance(item, ConfigComponent): + result = item.instantiate() + else: + result = item.get_config() + self.assertTrue(isinstance(result, output_type)) - def test_circular_dependencies(self): - scanner = ClassScanner(pkgs=[], modules=[]) - resolver = ConfigResolver() + def test_circular_references(self): + locator = ComponentLocator() + resolver = ReferenceResolver() configs = {"A": "@B", "B": "@C", "C": "@A"} for k, v in configs.items(): - resolver.add(ConfigComponent(id=k, config=v, class_scanner=scanner)) + resolver.add(ConfigComponent(config=v, id=k, locator=locator)) for k in ["A", "B", "C"]: with self.assertRaises(ValueError): - resolver.get_resolved_component(k) + resolver.get_resolved_content(k) if __name__ == "__main__": From 1ffe52af2d87b2c6df2220bb7dbba23555135413 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Feb 2022 11:05:49 +0800 Subject: [PATCH 24/30] [DLMED] update ConfigParser Signed-off-by: Nic Ma --- monai/apps/__init__.py | 2 +- monai/apps/manifest/__init__.py | 1 + monai/apps/manifest/config_parser.py | 170 +++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 monai/apps/manifest/config_parser.py diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 0f233bc3ef..51c4003458 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ReferenceResolver +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ConfigParser, ReferenceResolver from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index 79c4376d5c..b8ddf57f93 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -10,4 +10,5 @@ # limitations under the License. from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem +from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver diff --git a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py new file mode 100644 index 0000000000..8b476140b2 --- /dev/null +++ b/monai/apps/manifest/config_parser.py @@ -0,0 +1,170 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
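A note on the `test_circular_references` case above: detection relies on `resolve_one_item` carrying a list of ids that are still being resolved (the `waiting_list` argument shown earlier), so a reference back into that list fails fast. A back-of-the-envelope sketch of that guard, with the bookkeeping simplified:

    def resolve(id, configs, resolved, waiting=None):
        # configs: raw config per id, e.g. {"A": "@B", "B": "@C", "C": "@A"}
        waiting = set() if waiting is None else waiting
        waiting.add(id)
        value = configs[id]
        if isinstance(value, str) and value.startswith("@"):
            ref = value[1:]
            if ref in waiting:
                raise ValueError(f"detected circular reference to id='{ref}'")
            if ref not in resolved:
                resolve(ref, configs, resolved, waiting)
            value = resolved[ref]
        resolved[id] = value  # the real resolver instantiates / evaluates here
        return value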
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Any, Dict, List, Optional, Sequence, Union + +from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem, ComponentLocator +from monai.apps.manifest.reference_resolver import ReferenceResolver + + +class ConfigParser: + """ + Parse a nested config and build components. + A typical usage is a config dictionary contains all the necessary components to define training workflow in JSON. + For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. + + Args: + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + global_imports: pre-import packages as global variables to execute the python `eval` commands. + for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. + default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` + are MONAI mininum requirements. + config: config content to parse. + + """ + + def __init__( + self, + excludes: Optional[Union[Sequence[str], str]] = None, + global_imports: Optional[Dict[str, Any]] = None, + config: Optional[Any] = None, + ): + self.config = None + if config is not None: + self.set_config(config=config) + self.locator = ComponentLocator(excludes=excludes) + self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} + if global_imports is not None: + for k, v in global_imports.items(): + self.global_imports[k] = importlib.import_module(v) + self.reference_resolver: Optional[ReferenceResolver] = None + self.parsed = False + + def _get_last_config_and_key(self, config: Union[Dict, List], id: str): + """ + Utility to get the last config item and the id from the whole config content with nested id name. + + Args: + config: the whole config content. + id: nested id name to get the last item, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. + + """ + keys = id.split("#") + for k in keys[:-1]: + config = config[k] if isinstance(config, dict) else config[int(k)] + key = keys[-1] if isinstance(config, dict) else int(keys[-1]) + return config, key + + def set_config(self, config: Any, id: Optional[str] = None): + """ + Set config content for the parser, if `id` provided, `config` will used to replace the config item with `id`. + + Args: + config: target config content to set. + id: nested id name to specify the target position, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. + + """ + if isinstance(id, str) and isinstance(self.config, (dict, list)): + conf_, key = self._get_last_config_and_key(config=self.config, id=id) + conf_[key] = config + else: + self.config = config + self.parsed = False + + def get_config(self, id: Optional[str] = None): + """ + Get config content from the parser, if `id` provided, get the config item with `id`. + + Args: + id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. + for example: "transforms#5", "transforms#5##keys", etc. 
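The "#" id convention documented here can be pictured with a tiny standalone example (a made-up config, same splitting rule as `_get_last_config_and_key` above):

    config = {"transforms": [{"keys": ["image", "label"]}, {"prob": 0.5}]}

    def get_by_id(cfg, id):
        # walk the "#"-joined path, using int indices for lists
        for k in id.split("#"):
            cfg = cfg[k] if isinstance(cfg, dict) else cfg[int(k)]
        return cfg

    get_by_id(config, "transforms#0#keys#1")  # -> "label"
    get_by_id(config, "transforms#1#prob")    # -> 0.5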
+ + """ + if isinstance(id, str) and isinstance(self.config, (dict, list)): + conf_, key = self._get_last_config_and_key(config=self.config, id=id) + return conf_[key] + return self.config + + def _do_parse(self, config, id: Optional[str] = None): + """ + Recursively parse the nested config content, add every config item as component to the resolver. + For example, `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items: + - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]` + - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}` + - `id="preprocessing#0#", config="LoadImage"` + - `id="preprocessing#0#", config={"keys": "image"}` + - `id="preprocessing#0##keys", config="image"` + + Args: + config: config content to parse. + id: id name of current config item, nested ids are joined by "#" mark. defaults to None. + for example: "transforms#5", "transforms#5##keys", etc. + + """ + if isinstance(config, dict): + for k, v in config.items(): + sub_id = k if id is None else f"{id}#{k}" + self._do_parse(config=v, id=sub_id) + if isinstance(config, list): + for i, v in enumerate(config): + sub_id = i if id is None else f"{id}#{i}" + self._do_parse(config=v, id=sub_id) + if id is not None: + if ConfigComponent.is_instantiable(config): + self.reference_resolver.add( + ConfigComponent(config=config, id=id, locator=self.locator) + ) + elif ConfigExpression.is_expression(config): + self.reference_resolver.add(ConfigExpression(config=config, id=id, globals=self.global_imports)) + else: + self.reference_resolver.add(ConfigItem(config=config, id=id)) + + def parse_config(self): + """ + Parse the config content, add every config item as component to the resolver. + + Args: + resolve_all: if True, resolve all the components and build instances directly. + + """ + self.reference_resolver = ReferenceResolver() + self._do_parse(config=self.config) + self.parsed = True + + def get_resolved_content(self, id: str): + """ + Get the resolved instance component, if not resolved, try to resolve it first. + + Args: + id: id name of expected config component, nested ids are joined by "#" mark. + for example: "transforms#5", "transforms#5##keys", etc. + + """ + if not self.parsed: + self.parse_config() + return self.reference_resolver.get_resolved_content(id=id) + + def get_resolved_config(self, id: str): + """ + Get the resolved config component, if not resolved, try to resolve it first. + It can be used to modify the config again and support lazy instantiation. + + Args: + id: id name of expected config component, nested ids are joined by "#" mark. + for example: "transforms#5", "transforms#5##keys", etc. 
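Putting the pieces above together, a rough usage sketch of this parser; the config values and the expected result are illustrative only, and the exact reference/expression handling is defined by `ReferenceResolver`:

    from monai.apps.manifest.config_parser import ConfigParser  # module added in this commit

    # "@" refers to another config item, "$" marks a python expression
    config = {"init_lr": 0.001, "scaled_lr": "$@init_lr * 10"}
    parser = ConfigParser(config=config)
    scaled = parser.get_resolved_content("scaled_lr")  # expected to evaluate to roughly 0.01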
+ + """ + if not self.parsed: + self.parse_config() + return self.reference_resolver.get_item(id=id, resolve=True) From 627f7d5f0b7deb0a9f4b3a26f12257f61942397a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Feb 2022 11:15:42 +0800 Subject: [PATCH 25/30] [DLMED] fix unit tests Signed-off-by: Nic Ma --- monai/apps/manifest/config_parser.py | 4 ++-- tests/test_config_parser.py | 20 ++++++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py index 8b476140b2..e35933870a 100644 --- a/monai/apps/manifest/config_parser.py +++ b/monai/apps/manifest/config_parser.py @@ -155,7 +155,7 @@ def get_resolved_content(self, id: str): self.parse_config() return self.reference_resolver.get_resolved_content(id=id) - def get_resolved_config(self, id: str): + def get_config_item(self, id: str, resolve: bool = False): """ Get the resolved config component, if not resolved, try to resolve it first. It can be used to modify the config again and support lazy instantiation. @@ -167,4 +167,4 @@ def get_resolved_config(self, id: str): """ if not self.parsed: self.parse_config() - return self.reference_resolver.get_item(id=id, resolve=True) + return self.reference_resolver.get_item(id=id, resolve=resolve) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 0cb1ec92b7..203ba18d59 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -11,6 +11,7 @@ import unittest from unittest import skipUnless +from monai.apps.manifest.config_item import ConfigComponent from parameterized import parameterized @@ -49,7 +50,7 @@ class TestConfigComponent(unittest.TestCase): def test_config_content(self): - parser = ConfigParser(pkgs=["torch.optim", "monai"], modules=["data", "transforms", "adam"]) + parser = ConfigParser() test_config = {"preprocessing": [{"name": "LoadImage"}], "dataset": {"name": "Dataset"}} parser.set_config(config=test_config) self.assertEqual(str(parser.get_config()), str(test_config)) @@ -57,21 +58,16 @@ def test_config_content(self): self.assertDictEqual(parser.get_config(id="preprocessing#0#datasets"), {"name": "CacheDataset"}) @parameterized.expand([TEST_CASE_1]) - @skipUnless(has_tv, "Requires tifffile.") + @skipUnless(has_tv, "Requires torchvision.") def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser( - pkgs=["torch.optim", "monai"], - modules=["data", "transforms", "adam"], - global_imports={"monai": "monai"}, - config=config, - ) + parser = ConfigParser(global_imports={"monai": "monai"}, config=config) for id, cls in zip(expected_ids, output_types): - config = parser.get_resolved_config(id) + item = parser.get_config_item(id, resolve=True) # test lazy instantiation - self.assertTrue(isinstance(config, dict)) - self.assertTrue(isinstance(parser.build(config), cls)) + if isinstance(item, ConfigComponent): + self.assertTrue(isinstance(item.instantiate(), cls)) # test get instance directly - self.assertTrue(isinstance(parser.get_resolved_component(id), cls)) + self.assertTrue(isinstance(parser.get_resolved_content(id), cls)) if __name__ == "__main__": From 1684c23112da457a6c568e9e66098acd9f83614a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Feb 2022 19:24:54 +0800 Subject: [PATCH 26/30] [DLMED] update ConfigParse logic Signed-off-by: Nic Ma --- monai/apps/manifest/config_parser.py | 104 +++++++++++++-------------- tests/test_config_parser.py | 2 +- 2 files changed, 51 insertions(+), 55 deletions(-) diff --git 
a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py index e35933870a..6a686dd8d1 100644 --- a/monai/apps/manifest/config_parser.py +++ b/monai/apps/manifest/config_parser.py @@ -10,7 +10,7 @@ # limitations under the License. import importlib -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem, ComponentLocator from monai.apps.manifest.reference_resolver import ReferenceResolver @@ -18,87 +18,83 @@ class ConfigParser: """ - Parse a nested config and build components. + Parse a config content, access or update the items of config content with unique ID. A typical usage is a config dictionary contains all the necessary components to define training workflow in JSON. For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. Args: - excludes: if any string of the `excludes` exists in the full module name, don't import this module. - global_imports: pre-import packages as global variables to execute the python `eval` commands. + config: config content to parse. + excludes: when importing modules to instantiate components, if any string of the `excludes` exists + in the full module name, don't import this module. + globals: pre-import packages as global variables to evaluate the python `eval` expressions. for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` are MONAI mininum requirements. - config: config content to parse. + if the value in global is string, will import it immediately. """ def __init__( self, - excludes: Optional[Union[Sequence[str], str]] = None, - global_imports: Optional[Dict[str, Any]] = None, config: Optional[Any] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + globals: Optional[Dict[str, Any]] = None, ): self.config = None if config is not None: self.set_config(config=config) + + self.globals: Dict[str, Any] = {} + globals = {"monai": "monai", "torch": "torch", "np": "numpy"} if globals is None else globals + if globals is not None: + for k, v in globals.items(): + self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v + self.locator = ComponentLocator(excludes=excludes) - self.global_imports: Dict[str, Any] = {"monai": "monai", "torch": "torch", "np": "numpy"} - if global_imports is not None: - for k, v in global_imports.items(): - self.global_imports[k] = importlib.import_module(v) self.reference_resolver: Optional[ReferenceResolver] = None + # flag to identify the parsing status of current config content self.parsed = False - def _get_last_config_and_key(self, config: Union[Dict, List], id: str): - """ - Utility to get the last config item and the id from the whole config content with nested id name. - - Args: - config: the whole config content. - id: nested id name to get the last item, joined by "#" mark, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - keys = id.split("#") - for k in keys[:-1]: - config = config[k] if isinstance(config, dict) else config[int(k)] - key = keys[-1] if isinstance(config, dict) else int(keys[-1]) - return config, key - def set_config(self, config: Any, id: Optional[str] = None): """ - Set config content for the parser, if `id` provided, `config` will used to replace the config item with `id`. 
+ Set config content for the parser, if `id` provided, `config` will replace the config item with `id`. Args: config: target config content to set. - id: nested id name to specify the target position, joined by "#" mark, use index from 0 for list. + id: id name to specify the target position, joined by "#" mark for nested content, use index from 0 for list. for example: "transforms#5", "transforms#5##keys", etc. """ if isinstance(id, str) and isinstance(self.config, (dict, list)): - conf_, key = self._get_last_config_and_key(config=self.config, id=id) - conf_[key] = config + keys = id.split("#") + # get the last second config item and replace it + last_id = "#".join(keys[:-1]) + conf_ = self.get_config(id=last_id) + conf_[keys[-1] if isinstance(conf_, dict) else int(keys[-1])] = config else: self.config = config + # must totally parse again as the content is modified self.parsed = False def get_config(self, id: Optional[str] = None): """ - Get config content from the parser, if `id` provided, get the config item with `id`. + Get config content of current config, if `id` provided, get the config item with `id`. Args: id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. for example: "transforms#5", "transforms#5##keys", etc. """ - if isinstance(id, str) and isinstance(self.config, (dict, list)): - conf_, key = self._get_last_config_and_key(config=self.config, id=id) - return conf_[key] - return self.config + config = self.config + if isinstance(id, str) and len(id) > 0 and isinstance(config, (dict, list)): + keys = id.split("#") + for k in keys: + config = config[k] if isinstance(config, dict) else config[int(k)] + return config def _do_parse(self, config, id: Optional[str] = None): """ - Recursively parse the nested config content, add every config item as component to the resolver. + Recursively parse the nested config content, add every config item to the resolver. For example, `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items: - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]` - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}` @@ -120,22 +116,19 @@ def _do_parse(self, config, id: Optional[str] = None): for i, v in enumerate(config): sub_id = i if id is None else f"{id}#{i}" self._do_parse(config=v, id=sub_id) - if id is not None: - if ConfigComponent.is_instantiable(config): - self.reference_resolver.add( - ConfigComponent(config=config, id=id, locator=self.locator) - ) - elif ConfigExpression.is_expression(config): - self.reference_resolver.add(ConfigExpression(config=config, id=id, globals=self.global_imports)) - else: - self.reference_resolver.add(ConfigItem(config=config, id=id)) + # add config items to the resolver + if ConfigComponent.is_instantiable(config): + self.reference_resolver.add( + ConfigComponent(config=config, id=id, locator=self.locator) + ) + elif ConfigExpression.is_expression(config): + self.reference_resolver.add(ConfigExpression(config=config, id=id, globals=self.global_imports)) + else: + self.reference_resolver.add(ConfigItem(config=config, id=id)) def parse_config(self): """ - Parse the config content, add every config item as component to the resolver. - - Args: - resolve_all: if True, resolve all the components and build instances directly. + Parse the config content, add every config item to the resolver and mark as `parsed`. 
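Since `update_config` above resets the `parsed` flag, edits made by id are picked up the next time anything is resolved. A small assumed example (ids and values are made up):

    # assume the parser was built with config={"optimizer": {"lr": 0.001}, ...}
    parser.update_config(config=0.01, id="optimizer#lr")
    parser.get_config(id="optimizer#lr")        # -> 0.01
    parser.get_resolved_content("optimizer")    # triggers a fresh parse before resolving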
""" self.reference_resolver = ReferenceResolver() @@ -144,10 +137,13 @@ def parse_config(self): def get_resolved_content(self, id: str): """ - Get the resolved instance component, if not resolved, try to resolve it first. + Get the resolved result of config items with specified id, if not resolved, try to resolve it first. + If the config item is instantiable, the resolved result is the instance. + If the config item is an expression, the resolved result is output of when evaluating the expression. + Otherwise, the resolved result is the updated config content of the config item. Args: - id: id name of expected config component, nested ids are joined by "#" mark. + id: id name of expected config item, nested ids are joined by "#" mark. for example: "transforms#5", "transforms#5##keys", etc. """ @@ -157,8 +153,8 @@ def get_resolved_content(self, id: str): def get_config_item(self, id: str, resolve: bool = False): """ - Get the resolved config component, if not resolved, try to resolve it first. - It can be used to modify the config again and support lazy instantiation. + Get the parsed config item, if `resolve=True` and not resolved, try to resolve it first. + It can be used to modify the config in other program and support lazy instantiation. Args: id: id name of expected config component, nested ids are joined by "#" mark. diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 203ba18d59..69b025b7d6 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -60,7 +60,7 @@ def test_config_content(self): @parameterized.expand([TEST_CASE_1]) @skipUnless(has_tv, "Requires torchvision.") def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser(global_imports={"monai": "monai"}, config=config) + parser = ConfigParser(config=config, globals={"monai": "monai"}) for id, cls in zip(expected_ids, output_types): item = parser.get_config_item(id, resolve=True) # test lazy instantiation From 7d4d60c93a8facdeb0545981be9cf68410d8c2e2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 18 Feb 2022 21:13:31 +0800 Subject: [PATCH 27/30] [DLMED] add docs Signed-off-by: Nic Ma --- monai/apps/manifest/config_parser.py | 42 ++++++++++++++++++---------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py index 6a686dd8d1..bdf35e2e7a 100644 --- a/monai/apps/manifest/config_parser.py +++ b/monai/apps/manifest/config_parser.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy import importlib from typing import Any, Dict, Optional, Sequence, Union @@ -19,11 +20,27 @@ class ConfigParser: """ Parse a config content, access or update the items of config content with unique ID. - A typical usage is a config dictionary contains all the necessary components to define training workflow in JSON. + A typical usage is a config dictionary contains all the necessary information to define training workflow in JSON. For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. + It can recursively parse the config content, treat every item as a `ConfigItem` with unique ID, the ID is joined + by "#" mark for nested content. 
For example:
+    The config content `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items:
+    - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]`
+    - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}`
+    - `id="preprocessing#0#", config="LoadImage"`
+    - `id="preprocessing#0#", config={"keys": "image"}`
+    - `id="preprocessing#0##keys", config="image"`
+
+    There are three levels of config information during parsing:
+    - The input config content: the whole content, or any part of it specified by id, can be retrieved and updated,
+      which is useful for lazy instantiation, etc.
+    - After parsing, every config item is an independent `ConfigItem`, accessible before or after resolving references.
+    - After resolving, the resolved output of every `ConfigItem` is a python object or instance that can be used
+      directly in other programs.
+
     Args:
-        config: config content to parse.
+        config: input config content to parse.
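Assuming a `parser` built from the `preprocessing` example above, those three levels map onto accessors defined in this file (names as in this patch series; the outputs are illustrative):

    parser.get_config(id="preprocessing#0")            # level 1: the raw dict config of the first transform
    parser.get_config_item("preprocessing#0")          # level 2: the parsed ConfigItem / ConfigComponent
    parser.get_resolved_content("preprocessing#0")     # level 3: the instantiated transform object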
@@ -116,15 +127,16 @@ def _do_parse(self, config, id: Optional[str] = None): for i, v in enumerate(config): sub_id = i if id is None else f"{id}#{i}" self._do_parse(config=v, id=sub_id) - # add config items to the resolver - if ConfigComponent.is_instantiable(config): + # copy every config item to make them independent and add them to the resolver + item_conf = deepcopy(config) + if ConfigComponent.is_instantiable(item_conf): self.reference_resolver.add( - ConfigComponent(config=config, id=id, locator=self.locator) + ConfigComponent(config=item_conf, id=id, locator=self.locator) ) - elif ConfigExpression.is_expression(config): - self.reference_resolver.add(ConfigExpression(config=config, id=id, globals=self.global_imports)) + elif ConfigExpression.is_expression(item_conf): + self.reference_resolver.add(ConfigExpression(config=item_conf, id=id, globals=self.global_imports)) else: - self.reference_resolver.add(ConfigItem(config=config, id=id)) + self.reference_resolver.add(ConfigItem(config=item_conf, id=id)) def parse_config(self): """ From 9a06ce99e561bd29692a06338f67cbc0d98eb391 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 19 Feb 2022 00:26:39 +0800 Subject: [PATCH 28/30] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/apps/manifest/config_parser.py | 31 ++++++++++++++-------------- tests/test_config_parser.py | 13 +++++++----- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py index bdf35e2e7a..c8d5f62bd1 100644 --- a/monai/apps/manifest/config_parser.py +++ b/monai/apps/manifest/config_parser.py @@ -41,6 +41,7 @@ class ConfigParser: Args: config: input config content to parse. + id: specified ID name for the config content. excludes: when importing modules to instantiate components, if any string of the `excludes` exists in the full module name, don't import this module. globals: pre-import packages as global variables to evaluate the python `eval` expressions. @@ -53,13 +54,12 @@ class ConfigParser: def __init__( self, - config: Optional[Any] = None, + config: Any, excludes: Optional[Union[Sequence[str], str]] = None, globals: Optional[Dict[str, Any]] = None, ): self.config = None - if config is not None: - self.set_config(config=config) + self.update_config(config=config) self.globals: Dict[str, Any] = {} globals = {"monai": "monai", "torch": "torch", "np": "numpy"} if globals is None else globals @@ -72,7 +72,7 @@ def __init__( # flag to identify the parsing status of current config content self.parsed = False - def update_config(self, config: Any, id: Optional[str] = None): + def update_config(self, config: Any, id: str = ""): """ Set config content for the parser, if `id` provided, `config` will replace the config item with `id`. @@ -80,9 +80,10 @@ def update_config(self, config: Any, id: Optional[str] = None): config: target config content to set. id: id name to specify the target position, joined by "#" mark for nested content, use index from 0 for list. for example: "transforms#5", "transforms#5##keys", etc. + default to update all the config content. 
""" - if isinstance(id, str) and isinstance(self.config, (dict, list)): + if len(id) > 0 and isinstance(self.config, (dict, list)): keys = id.split("#") # get the last second config item and replace it last_id = "#".join(keys[:-1]) @@ -93,23 +94,24 @@ def update_config(self, config: Any, id: Optional[str] = None): # must totally parse again as the content is modified self.parsed = False - def get_config(self, id: Optional[str] = None): + def get_config(self, id: str = ""): """ Get config content of current config, if `id` provided, get the config item with `id`. Args: id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. for example: "transforms#5", "transforms#5##keys", etc. + default to get all the config content. """ config = self.config - if isinstance(id, str) and len(id) > 0 and isinstance(config, (dict, list)): + if len(id) > 0 and isinstance(config, (dict, list)): keys = id.split("#") for k in keys: config = config[k] if isinstance(config, dict) else config[int(k)] return config - def _do_parse(self, config, id: Optional[str] = None): + def _do_parse(self, config, id: str = ""): """ Recursively parse the nested config content, add every config item to the resolver. @@ -117,16 +119,15 @@ def _do_parse(self, config, id: Optional[str] = None): config: config content to parse. id: id name of current config item, nested ids are joined by "#" mark. defaults to None. for example: "transforms#5", "transforms#5##keys", etc. + default to empty string. """ - if isinstance(config, dict): - for k, v in config.items(): - sub_id = k if id is None else f"{id}#{k}" - self._do_parse(config=v, id=sub_id) - if isinstance(config, list): - for i, v in enumerate(config): - sub_id = i if id is None else f"{id}#{i}" + if isinstance(config, (dict, list)): + subs = enumerate(config) if isinstance(config, list) else config.items() + for k, v in subs: + sub_id = f"{id}#{k}" if len(id) > 0 else k self._do_parse(config=v, id=sub_id) + # copy every config item to make them independent and add them to the resolver item_conf = deepcopy(config) if ConfigComponent.is_instantiable(item_conf): diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 69b025b7d6..90d1733dc4 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from distutils.command.config import config import unittest from unittest import skipUnless from monai.apps.manifest.config_item import ConfigComponent @@ -50,12 +51,14 @@ class TestConfigComponent(unittest.TestCase): def test_config_content(self): - parser = ConfigParser() - test_config = {"preprocessing": [{"name": "LoadImage"}], "dataset": {"name": "Dataset"}} - parser.set_config(config=test_config) + parser = ConfigParser(config={}) + test_config = {"preprocessing": [{"": "LoadImage"}], "dataset": {"": "Dataset"}} + parser.update_config(config=test_config) self.assertEqual(str(parser.get_config()), str(test_config)) - parser.set_config(config={"name": "CacheDataset"}, id="preprocessing#0#datasets") - self.assertDictEqual(parser.get_config(id="preprocessing#0#datasets"), {"name": "CacheDataset"}) + parser.update_config(config={"": "CacheDataset"}, id="dataset") + self.assertDictEqual(parser.get_config(id="dataset"), {"": "CacheDataset"}) + parser.update_config(config="Dataset", id="dataset#") + self.assertEqual(parser.get_config(id="dataset#"), "Dataset") @parameterized.expand([TEST_CASE_1]) @skipUnless(has_tv, "Requires torchvision.") From b9c26e1687e3c45ad9cdb43bedab902fde26fa18 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 19 Feb 2022 12:47:52 +0800 Subject: [PATCH 29/30] [DLMED] update to latest Signed-off-by: Nic Ma --- monai/apps/__init__.py | 1 - monai/apps/manifest/__init__.py | 4 - monai/apps/manifest/config_item.py | 355 ---------------------- monai/apps/manifest/config_parser.py | 179 ----------- monai/apps/manifest/export.py | 59 ++++ monai/apps/manifest/inference.py | 68 +++++ monai/apps/manifest/reference_resolver.py | 244 --------------- monai/apps/manifest/schema/metadata.json | 71 +++++ monai/apps/manifest/verify_network.py | 60 ++++ tests/test_component_locator.py | 35 --- tests/test_config_item.py | 99 ------ tests/test_config_parser.py | 77 ----- tests/test_reference_resolver.py | 126 -------- 13 files changed, 258 insertions(+), 1120 deletions(-) delete mode 100644 monai/apps/manifest/config_item.py delete mode 100644 monai/apps/manifest/config_parser.py create mode 100644 monai/apps/manifest/export.py create mode 100644 monai/apps/manifest/inference.py delete mode 100644 monai/apps/manifest/reference_resolver.py create mode 100644 monai/apps/manifest/schema/metadata.json create mode 100644 monai/apps/manifest/verify_network.py delete mode 100644 tests/test_component_locator.py delete mode 100644 tests/test_config_item.py delete mode 100644 tests/test_config_parser.py delete mode 100644 tests/test_reference_resolver.py diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 51c4003458..893f7877d2 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,6 +10,5 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset -from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, ConfigParser, ReferenceResolver from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index b8ddf57f93..1e97f89407 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -8,7 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - -from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem -from .config_parser import ConfigParser -from .reference_resolver import ReferenceResolver diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py deleted file mode 100644 index 1f0c06c8fc..0000000000 --- a/monai/apps/manifest/config_item.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import sys -import warnings -from abc import ABC, abstractmethod -from importlib import import_module -from typing import Any, Dict, List, Mapping, Optional, Sequence, Union - -from monai.utils import ensure_tuple, instantiate - -__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] - - -class Instantiable(ABC): - """ - Base class for instantiable object with module name and arguments. - - .. code-block:: python - - if not is_disabled(): - instantiate(module_name=resolve_module_name(), args=resolve_args()) - - """ - - @abstractmethod - def resolve_module_name(self, *args: Any, **kwargs: Any): - """ - Resolve the target module name, it should return an object class (or function) to be instantiated. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def resolve_args(self, *args: Any, **kwargs: Any): - """ - Resolve the arguments, it should return arguments to be passed to the object when instantiating. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def is_disabled(self, *args: Any, **kwargs: Any) -> bool: - """ - Return a boolean flag to indicate whether the object should be instantiated. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def instantiate(self, *args: Any, **kwargs: Any): - """ - Instantiate the target component. - """ - raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - - -class ComponentLocator: - """ - Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. - It's used to locate the module path for provided component name. - - Args: - excludes: if any string of the `excludes` exists in the full module name, don't import this module. - - """ - - MOD_START = "monai" - - def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): - self.excludes = [] if excludes is None else ensure_tuple(excludes) - self._components_table: Optional[Dict[str, List]] = None - - def _find_module_names(self) -> List[str]: - """ - Find all the modules start with MOD_START and don't contain any of `excludes`. 
- - """ - return [ - m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) - ] - - def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: - """ - Find all the classes and functions in the modules with specified `modnames`. - - Args: - modnames: names of the target modules to find all the classes and functions. - - """ - table: Dict[str, List] = {} - # all the MONAI modules are already loaded by `load_submodules` - for modname in ensure_tuple(modnames): - try: - # scan all the classes and functions in the module - module = import_module(modname) - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: - if name not in table: - table[name] = [] - table[name].append(modname) - except ModuleNotFoundError: - pass - return table - - def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: - """ - Get the full module name of the class or function with specified ``name``. - If target component name exists in multiple packages or modules, return a list of full module names. - - Args: - name: name of the expected class or function. - - """ - if not isinstance(name, str): - raise ValueError(f"`name` must be a valid string, but got: {name}.") - if self._components_table is None: - # init component and module mapping table - self._components_table = self._find_classes_or_functions(self._find_module_names()) - - mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) - if isinstance(mods, list) and len(mods) == 1: - mods = mods[0] - return mods - - -class ConfigItem: - """ - Basic data structure to represent a configuration item. - - A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. - It has a build-in `config` property to store the configuration object. - - Args: - config: content of a config item, can be objects of any types, - a configuration resolver may interpret the content to generate a configuration object. - id: optional name of the current config item, defaults to `None`. - - """ - - def __init__(self, config: Any, id: Optional[str] = None) -> None: - self.config = config - self.id = id - - def get_id(self) -> Optional[str]: - """ - Get the ID name of current config item, useful to identify config items during parsing. - - """ - return self.id - - def update_config(self, config: Any): - """ - Replace the content of `self.config` with new `config`. - A typical usage is to modify the initial config content at runtime. - - Args: - config: content of a `ConfigItem`. - - """ - self.config = config - - def get_config(self): - """ - Get the config content of current config item. - - """ - return self.config - - -class ConfigComponent(ConfigItem, Instantiable): - """ - Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to - represent a component of `class` or `function` and supports instantiation. - - Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: - - - class or function identifier of the python module, specified by one of the two keys. - - ``""``: indicates build-in python classes or functions such as "LoadImageDict". - - ``""``: full module name, such as "monai.transforms.LoadImageDict". - - ``""``: input arguments to the python module. - - ``""``: a flag to indicate whether to skip the instantiation. - - .. 
code-block:: python - - locator = ComponentLocator(excludes=["modules_to_exclude"]) - config = { - "": "LoadImaged", - "": { - "keys": ["image", "label"] - } - } - - configer = ConfigComponent(config, id="test", locator=locator) - image_loader = configer.instantiate() - print(image_loader) # - - Args: - config: content of a config item. - id: optional name of the current config item, defaults to `None`. - locator: a ``ComponentLocator`` to convert a module name string into the actual python module. - if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. - excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. - See also: :py:class:`monai.apps.manifest.ComponentLocator`. - - """ - - def __init__( - self, - config: Any, - id: Optional[str] = None, - locator: Optional[ComponentLocator] = None, - excludes: Optional[Union[Sequence[str], str]] = None, - ) -> None: - super().__init__(config=config, id=id) - self.locator = ComponentLocator(excludes=excludes) if locator is None else locator - - @staticmethod - def is_instantiable(config: Any) -> bool: - """ - Check whether this config represents a `class` or `function` that is to be instantiated. - - Args: - config: input config content to check. - - """ - return isinstance(config, Mapping) and ("" in config or "" in config) - - def resolve_module_name(self): - """ - Resolve the target module name from current config content. - The config content must have ``""`` or ``""``. - When both are specified, ``""`` will be used. - - """ - config = dict(self.get_config()) - path = config.get("") - if path is not None: - if not isinstance(path, str): - raise ValueError(f"'' must be a string, but got: {path}.") - if "" in config: - warnings.warn(f"both '' and '', default to use '': {path}.") - return path - - name = config.get("") - if not isinstance(name, str): - raise ValueError("must provide a string for `` or `` of target component to instantiate.") - - module = self.locator.get_component_module_name(name) - if module is None: - raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") - if isinstance(module, list): - warnings.warn( - f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." - f" if want to use others, please set its module path in `` directly." - ) - module = module[0] - return f"{module}.{name}" - - def resolve_args(self): - """ - Utility function used in `instantiate()` to resolve the arguments from current config content. - - """ - return self.get_config().get("", {}) - - def is_disabled(self) -> bool: # type: ignore - """ - Utility function used in `instantiate()` to check whether to skip the instantiation. - - """ - _is_disabled = self.get_config().get("", False) - return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) - - def instantiate(self, **kwargs) -> object: # type: ignore - """ - Instantiate component based on ``self.config`` content. - The target component must be a `class` or a `function`, otherwise, return `None`. - - Args: - kwargs: args to override / add the config args when instantiation. 
- - """ - if not self.is_instantiable(self.get_config()) or self.is_disabled(): - # if not a class or function or marked as `disabled`, skip parsing and return `None` - return None - - modname = self.resolve_module_name() - args = self.resolve_args() - args.update(kwargs) - return instantiate(modname, **args) - - -class ConfigExpression(ConfigItem): - """ - Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression - (execute based on ``eval()``). - - See also: - - - https://docs.python.org/3/library/functions.html#eval. - - For example: - - .. code-block:: python - - import monai - from monai.apps.manifest import ConfigExpression - - config = "$monai.__version__" - expression = ConfigExpression(config, id="test", globals={"monai": monai}) - print(expression.execute()) - - Args: - config: content of a config item. - id: optional name of current config item, defaults to `None`. - globals: additional global context to evaluate the string. - - """ - - def __init__(self, config: Any, id: Optional[str] = None, globals: Optional[Dict] = None) -> None: - super().__init__(config=config, id=id) - self.globals = globals - - def evaluate(self, locals: Optional[Dict] = None): - """ - Excute current config content and return the result if it is expression, based on python `eval()`. - For more details: https://docs.python.org/3/library/functions.html#eval. - - Args: - locals: besides ``globals``, may also have some local symbols used in the expression at runtime. - - """ - value = self.get_config() - if not ConfigExpression.is_expression(value): - return None - return eval(value[1:], self.globals, locals) - - @staticmethod - def is_expression(config: Union[Dict, List, str]) -> bool: - """ - Check whether the config is an executable expression string. - Currently A string starts with ``"$"`` character is interpreted as an expression. - - Args: - config: input config content to check. - - """ - return isinstance(config, str) and config.startswith("$") diff --git a/monai/apps/manifest/config_parser.py b/monai/apps/manifest/config_parser.py deleted file mode 100644 index c8d5f62bd1..0000000000 --- a/monai/apps/manifest/config_parser.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from copy import deepcopy -import importlib -from typing import Any, Dict, Optional, Sequence, Union - -from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem, ComponentLocator -from monai.apps.manifest.reference_resolver import ReferenceResolver - - -class ConfigParser: - """ - Parse a config content, access or update the items of config content with unique ID. - A typical usage is a config dictionary contains all the necessary information to define training workflow in JSON. - For more details of the config format, please check :py:class:`monai.apps.ConfigComponent`. 
- - It can recursively parse the config content, treat every item as a `ConfigItem` with unique ID, the ID is joined - by "#" mark for nested content. For example: - The config content `{"preprocessing": [{"": "LoadImage", "": {"keys": "image"}}]}` is parsed as items: - - `id="preprocessing", config=[{"": "LoadImage", "": {"keys": "image"}}]` - - `id="preprocessing#0", config={"": "LoadImage", "": {"keys": "image"}}` - - `id="preprocessing#0#", config="LoadImage"` - - `id="preprocessing#0#", config={"keys": "image"}` - - `id="preprocessing#0##keys", config="image"` - - There are 3 levels config information during the parsing: - - For the input config content, it supports to `get` and `update` the whole content or part of it specified with id, - it can be useful for lazy instantiation, etc. - - After parsing, all the config items are independent `ConfigItem`, can get it before / after resolving references. - - After resolving, the resolved output of every `ConfigItem` is python objects or instances, can be used in other - programs directly. - - Args: - config: input config content to parse. - id: specified ID name for the config content. - excludes: when importing modules to instantiate components, if any string of the `excludes` exists - in the full module name, don't import this module. - globals: pre-import packages as global variables to evaluate the python `eval` expressions. - for example, pre-import `monai`, then execute `eval("monai.data.list_data_collate")`. - default to `{"monai": "monai", "torch": "torch", "np": "numpy"}` as `numpy` and `torch` - are MONAI mininum requirements. - if the value in global is string, will import it immediately. - - """ - - def __init__( - self, - config: Any, - excludes: Optional[Union[Sequence[str], str]] = None, - globals: Optional[Dict[str, Any]] = None, - ): - self.config = None - self.update_config(config=config) - - self.globals: Dict[str, Any] = {} - globals = {"monai": "monai", "torch": "torch", "np": "numpy"} if globals is None else globals - if globals is not None: - for k, v in globals.items(): - self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v - - self.locator = ComponentLocator(excludes=excludes) - self.reference_resolver: Optional[ReferenceResolver] = None - # flag to identify the parsing status of current config content - self.parsed = False - - def update_config(self, config: Any, id: str = ""): - """ - Set config content for the parser, if `id` provided, `config` will replace the config item with `id`. - - Args: - config: target config content to set. - id: id name to specify the target position, joined by "#" mark for nested content, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. - default to update all the config content. - - """ - if len(id) > 0 and isinstance(self.config, (dict, list)): - keys = id.split("#") - # get the last second config item and replace it - last_id = "#".join(keys[:-1]) - conf_ = self.get_config(id=last_id) - conf_[keys[-1] if isinstance(conf_, dict) else int(keys[-1])] = config - else: - self.config = config - # must totally parse again as the content is modified - self.parsed = False - - def get_config(self, id: str = ""): - """ - Get config content of current config, if `id` provided, get the config item with `id`. - - Args: - id: nested id name to specify the expected position, joined by "#" mark, use index from 0 for list. - for example: "transforms#5", "transforms#5##keys", etc. - default to get all the config content. 
- - """ - config = self.config - if len(id) > 0 and isinstance(config, (dict, list)): - keys = id.split("#") - for k in keys: - config = config[k] if isinstance(config, dict) else config[int(k)] - return config - - def _do_parse(self, config, id: str = ""): - """ - Recursively parse the nested config content, add every config item to the resolver. - - Args: - config: config content to parse. - id: id name of current config item, nested ids are joined by "#" mark. defaults to None. - for example: "transforms#5", "transforms#5##keys", etc. - default to empty string. - - """ - if isinstance(config, (dict, list)): - subs = enumerate(config) if isinstance(config, list) else config.items() - for k, v in subs: - sub_id = f"{id}#{k}" if len(id) > 0 else k - self._do_parse(config=v, id=sub_id) - - # copy every config item to make them independent and add them to the resolver - item_conf = deepcopy(config) - if ConfigComponent.is_instantiable(item_conf): - self.reference_resolver.add( - ConfigComponent(config=item_conf, id=id, locator=self.locator) - ) - elif ConfigExpression.is_expression(item_conf): - self.reference_resolver.add(ConfigExpression(config=item_conf, id=id, globals=self.global_imports)) - else: - self.reference_resolver.add(ConfigItem(config=item_conf, id=id)) - - def parse_config(self): - """ - Parse the config content, add every config item to the resolver and mark as `parsed`. - - """ - self.reference_resolver = ReferenceResolver() - self._do_parse(config=self.config) - self.parsed = True - - def get_resolved_content(self, id: str): - """ - Get the resolved result of config items with specified id, if not resolved, try to resolve it first. - If the config item is instantiable, the resolved result is the instance. - If the config item is an expression, the resolved result is output of when evaluating the expression. - Otherwise, the resolved result is the updated config content of the config item. - - Args: - id: id name of expected config item, nested ids are joined by "#" mark. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - if not self.parsed: - self.parse_config() - return self.reference_resolver.get_resolved_content(id=id) - - def get_config_item(self, id: str, resolve: bool = False): - """ - Get the parsed config item, if `resolve=True` and not resolved, try to resolve it first. - It can be used to modify the config in other program and support lazy instantiation. - - Args: - id: id name of expected config component, nested ids are joined by "#" mark. - for example: "transforms#5", "transforms#5##keys", etc. - - """ - if not self.parsed: - self.parse_config() - return self.reference_resolver.get_item(id=id, resolve=resolve) diff --git a/monai/apps/manifest/export.py b/monai/apps/manifest/export.py new file mode 100644 index 0000000000..acfeb98291 --- /dev/null +++ b/monai/apps/manifest/export.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json + +import torch +from monai.apps import ConfigParser +from ignite.handlers import Checkpoint +from monai.data import save_net_with_metadata +from monai.networks import convert_to_torchscript + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', '-w', type=str, help='file path of the trained model weights', required=True) + parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + args = parser.parse_args() + + # load config file + with open(args.config) as f: + config_dict = json.load(f) + # load meta data + with open(args.meta) as f: + meta_dict = json.load(f) + + net: torch.nn.Module = None + # TODO: parse network definiftion from config file and construct network instance + config_parser = ConfigParser(config_dict) + net = config_parser.get_instance("network") + + checkpoint = torch.load(args.weights) + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={"model": net}, checkpoint=checkpoint) + + # convert to TorchScript model and save with meta data + net = convert_to_torchscript(model=net) + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream="model.ts", + include_config_vals=False, + append_timestamp=False, + meta_values=meta_dict, + more_extra_files={args.config: json.dumps(config_dict).encode()}, + ) + + +if __name__ == '__main__': + main() diff --git a/monai/apps/manifest/inference.py b/monai/apps/manifest/inference.py new file mode 100644 index 0000000000..41b9792869 --- /dev/null +++ b/monai/apps/manifest/inference.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
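+#
+# Run inference with components built entirely from a config file.
+# Illustrative invocation (file names are examples only):
+#
+#   python inference.py --config inference.json --meta metadata.json
+#
+# The config is expected to define at least the "model", "dataloader", "inferer" and
+# "postprocessing" sections used below; the meta data file is loaded first and can be
+# overridden by entries of the same name in the config file.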
+ +import argparse +import json + +import torch +from monai.apps import ConfigParser +from monai.data import decollate_batch +from monai.inferers import Inferer +from monai.transforms import Transform +from monai.utils.enums import CommonKeys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + parser.add_argument('--override', '-o', type=str, help='config file that override components', required=False) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + configs = {} + + # load meta data + with open(args.meta) as f: + configs.update(json.load(f)) + # load config file, can override meta data in config + with open(args.config) as f: + configs.update(json.load(f)) + + model: torch.nn.Module = None + dataloader: torch.utils.data.DataLoader = None + inferer: Inferer = None + postprocessing: Transform = None + # TODO: parse inference config file and construct instances + config_parser = ConfigParser(configs) + + # change JSON config content in python code, lazy instantiation + model_conf = config_parser.get_config("model") + model_conf["disabled"] = False + model = config_parser.build(model_conf).to(device) + + # instantialize the components immediately + dataloader = config_parser.get_instance("dataloader") + inferer = config_parser.get_instance("inferer") + postprocessing = config_parser.get_instance("postprocessing") + + model.eval() + with torch.no_grad(): + for d in dataloader: + images = d[CommonKeys.IMAGE].to(device) + # define sliding window size and batch size for windows inference + d[CommonKeys.PRED] = inferer(inputs=images, predictor=model) + # decollate the batch data into a list of dictionaries, then execute postprocessing transforms + [postprocessing(i) for i in decollate_batch(d)] + + +if __name__ == '__main__': + main() diff --git a/monai/apps/manifest/reference_resolver.py b/monai/apps/manifest/reference_resolver.py deleted file mode 100644 index 1988f35535..0000000000 --- a/monai/apps/manifest/reference_resolver.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Dict, List, Optional, Union -import warnings - -from monai.apps.manifest.config_item import ConfigComponent, ConfigExpression, ConfigItem - - -class ReferenceResolver: - """ - Utility class to resolve the references between config items. - - Args: - components: config components to resolve, if None, can also `add()` component in runtime. - - """ - - def __init__(self, items: Optional[Dict[str, ConfigItem]] = None): - self.items = {} if items is None else items - self.resolved_content = {} - - def add(self, item: ConfigItem): - """ - Add a config item to the resolution graph. - - Args: - item: a config item to resolve. 
- - """ - id = item.get_id() - if id in self.items: - warnings.warn(f"id '{id}' is already added.") - return - self.items[id] = item - - def resolve_one_item(self, id: str, waiting_list: Optional[List[str]] = None): - """ - Resolve one item with specified id name. - If has unresolved references, recursively resolve the references first. - - Args: - id: id name of expected item to resolve. - waiting_list: list of items wait to resolve references. it's used to detect circular references. - when resolving references like: `{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`. - - """ - if waiting_list is None: - waiting_list = [] - waiting_list.append(id) - item = self.items.get(id) - item_config = item.get_config() - ref_ids = self.find_refs_in_config(config=item_config, id=id) - - # if current item has reference already in the waiting list, that's circular references - for d in ref_ids: - if d in waiting_list: - raise ValueError(f"detected circular references for id='{d}' in the config content.") - - if len(ref_ids) > 0: - # # check whether the component has any unresolved deps - for ref_id in ref_ids: - if ref_id not in self.resolved_content: - # this reffring component is not resolved - if ref_id not in self.items: - raise RuntimeError(f"the referring item `{ref_id}` is not defined in config.") - # resolve the reference first - self.resolve_one_item(id=ref_id, waiting_list=waiting_list) - - # all references are resolved - new_config = self.resolve_config_with_refs(config=item_config, id=id, refs=self.resolved_content) - item.update_config(config=new_config) - if isinstance(item, ConfigComponent): - self.resolved_content[id] = item.instantiate() - elif isinstance(item, ConfigExpression): - self.resolved_content[id] = item.execute(locals={"refs": self.resolved_content}) - else: - self.resolved_content[id] = new_config - - def resolve_all(self): - """ - Resolve the references for all the config items. - - """ - for k in self.items.keys(): - self.resolve_one_item(id=k) - - def get_resolved_content(self, id: str): - """ - Get the resolved content with specified id name. - If not resolved, try to resolve it first. - - Args: - id: id name of the expected item. - - """ - if id not in self.resolved_content: - self.resolve_one_item(id=id) - return self.resolved_content.get(id) - - def get_item(self, id: str, resolve: bool = False): - """ - Get the config item with specified id name, then can be used for lazy instantiation. - If `resolve=True`, try to resolve it first. - - Args: - id: id name of the expected config item. - - """ - if resolve and id not in self.resolved_content: - self.resolve_one_item(id=id) - return self.items.get(id) - - @staticmethod - def match_refs_pattern(value: str) -> List[str]: - """ - Match regular expression for the input string to find the references. - The reference part starts with "@", like: "@XXX#YYY#ZZZ". - - Args: - value: input value to match regular expression. - - """ - refs: List[str] = [] - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - if ConfigExpression.is_expression(value) or value == item: - # only check when string starts with "$" or the whole content is "@XXX" - ref_obj_id = item[1:] - if ref_obj_id not in refs: - refs.append(ref_obj_id) - return refs - - @staticmethod - def resolve_refs_pattern(value: str, refs: Dict) -> str: - """ - Match regular expression for the input string to update content with the references. 
- The reference part starts with "@", like: "@XXX#YYY#ZZZ". - References dictionary must contain the referring IDs as keys. - - Args: - value: input value to match regular expression. - refs: all the referring components with ids as keys, default to `None`. - - """ - # regular expression pattern to match "@XXX" or "@XXX#YYY" - result = re.compile(r"@\w*[\#\w]*").findall(value) - for item in result: - ref_id = item[1:] - if ConfigExpression.is_expression(value): - # replace with local code and execute later - value = value.replace(item, f"refs['{ref_id}']") - elif value == item: - if ref_id not in refs: - raise KeyError(f"can not find expected ID '{ref_id}' in the references.") - value = refs[ref_id] - return value - - @staticmethod - def find_refs_in_config( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[List[str]] = None, - ) -> List[str]: - """ - Recursively search all the content of input config item to get the ids of references. - References mean (1) referring to the ID of other item, can be extracted by `match_fn`, for example: - `{"lr": "$@epoch / 100"}` with "@" mark, the referring IDs: `["epoch"]`. (2) if sub-item in the config - is instantiable, treat it as reference because must instantiate it before resolving current config. - For `dict` and `list`, recursively check the sub-items. - - Args: - config: input config content to search. - id: ID name for the input config, default to `None`. - refs: list of the ID name of existing references, default to `None`. - - """ - refs_: List[str] = [] if refs is None else refs - if isinstance(config, str): - refs_ += ReferenceResolver.match_refs_pattern(value=config) - - if isinstance(config, list): - for i, v in enumerate(config): - sub_id = f"{id}#{i}" if id is not None else f"{i}" - if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): - refs_.append(sub_id) - refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) - if isinstance(config, dict): - for k, v in config.items(): - sub_id = f"{id}#{k}" if id is not None else f"{k}" - if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): - refs_.append(sub_id) - refs_ = ReferenceResolver.find_refs_in_config(v, sub_id, refs_) - return refs_ - - @staticmethod - def resolve_config_with_refs( - config: Union[Dict, List, str], - id: Optional[str] = None, - refs: Optional[Dict] = None, - ): - """ - With all the references in `refs`, resolve the config content with them and return new config. - - Args: - config: input config content to resolve. - id: ID name for the input config, default to `None`. - refs: all the referring components with ids, default to `None`. 
- - """ - refs_: Dict = {} if refs is None else refs - if isinstance(config, str): - config = ReferenceResolver.resolve_refs_pattern(config, refs) - if isinstance(config, list): - # all the items in the list should be replaced with the references - ret_list: List = [] - for i, v in enumerate(config): - sub_id = f"{id}#{i}" if id is not None else f"{i}" - if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): - ret_list.append(refs_[sub_id]) - else: - ret_list.append(ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_)) - return ret_list - if isinstance(config, dict): - # all the items in the dict should be replaced with the references - ret_dict: Dict = {} - for k, v in config.items(): - sub_id = f"{id}#{k}" if id is not None else f"{k}" - if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v): - ret_dict[k] = refs_[sub_id] - else: - ret_dict[k] = ReferenceResolver.resolve_config_with_refs(v, sub_id, refs_) - return ret_dict - return config diff --git a/monai/apps/manifest/schema/metadata.json b/monai/apps/manifest/schema/metadata.json new file mode 100644 index 0000000000..babee8b30e --- /dev/null +++ b/monai/apps/manifest/schema/metadata.json @@ -0,0 +1,71 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://monai.io/mmar_metadata_schema.json", + "title": "metadata", + "description": "metadata that defines the context information for MMAR.", + "type": "object", + "properties": { + "version": { + "description": "version number of this MMAR.", + "type": "string" + }, + "monai_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "pytorch_version": { + "description": "version number of PyTorch used in this MMAR.", + "type": "string" + }, + "numpy_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "network_data_format": { + "description": "define the input and output data format for network.", + "type": "object", + "properties": { + "inputs": { + "type": "object", + "properties": { + "image": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "the data format for `image`." + }, + "format": { + "type": "string" + }, + "num_channels": { + "type": "integer", + "minimum": 1 + }, + "spatial_shape": { + "type": "array", + "items": { + "type": "integer", + "minumum": 1 + } + }, + "dtype": { + "type": "string" + }, + "value_range": { + "type": "array", + "items": { + "type": "number", + "unuqueItems": true + } + }, + "required": ["num_channels", "spatial_shape", "value_range"] + } + } + } + } + } + }, + "required": ["monai_version", "pytorch_version", "network_data_format"] + } +} diff --git a/monai/apps/manifest/verify_network.py b/monai/apps/manifest/verify_network.py new file mode 100644 index 0000000000..c0dbf178dc --- /dev/null +++ b/monai/apps/manifest/verify_network.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json + +import torch +from monai.apps import ConfigParser +from monai.utils.type_conversion import get_equivalent_dtype + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + configs = {} + + # load meta data + with open(args.meta) as f: + configs.update(json.load(f)) + # load config file, can override meta data in config + with open(args.config) as f: + configs.update(json.load(f)) + + model: torch.nn.Module = None + # TODO: parse inference config file and construct instances + config_parser = ConfigParser(configs) + + model = config_parser.get_instance("model") + input_channels = config_parser.get_config("network_data_format#inputs#image#num_channels") + input_spatial_shape = tuple(config_parser.get_config("network_data_format#inputs#image#spatial_shape")) + dtype = config_parser.get_config("network_data_format#inputs#image#dtype") + dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) + + output_channels = config_parser.get_config("network_data_format#outputs#pred#num_channels") + output_spatial_shape = tuple(config_parser.get_config("network_data_format#outputs#pred#spatial_shape")) + + model.eval() + with torch.no_grad(): + test_data = torch.rand(*(input_channels, *input_spatial_shape), dtype=dtype, device=device) + output = model(test_data) + if output.shape[0] != output_channels: + raise ValueError(f"channel number of output data doesn't match expection: {output_channels}.") + if output.shape[1:] != output_spatial_shape: + raise ValueError(f"spatial shape of output data doesn't match expection: {output_spatial_shape}.") + + +if __name__ == '__main__': + main() diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py deleted file mode 100644 index eafb2152d1..0000000000 --- a/tests/test_component_locator.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from pydoc import locate - -from monai.apps.manifest import ComponentLocator -from monai.utils import optional_import - -_, has_ignite = optional_import("ignite") - - -class TestComponentLocator(unittest.TestCase): - def test_locate(self): - locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) - # test init mapping table and get the module path of component - self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") - self.assertGreater(len(locator._components_table), 0) - for _, mods in locator._components_table.items(): - for i in mods: - self.assertGreater(len(mods), 0) - # ensure we can locate all the items by `name` - self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_config_item.py b/tests/test_config_item.py deleted file mode 100644 index b2c2fec6c6..0000000000 --- a/tests/test_config_item.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Callable - -import torch -from parameterized import parameterized - -import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem -from monai.data import DataLoader, Dataset -from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import - -_, has_tv = optional_import("torchvision") - -TEST_CASE_1 = [{"lr": 0.001}, 0.0001] - -TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test python `` -TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test `` -TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test `` -TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] -# test non-monai modules and excludes -TEST_CASE_6 = [ - {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] -TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] -# test args contains "name" field -TEST_CASE_8 = [ - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, - RandTorchVisiond, -] -# test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] -# test lambda function, should not execute the lambda function, just change the string -TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] - - -class TestConfigItem(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_item(self, test_input, expected): - item = ConfigItem(config=test_input) - conf = item.get_config() - conf["lr"] = 0.0001 - item.update_config(config=conf) - self.assertEqual(item.get_config()["lr"], expected) - - @parameterized.expand( - [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, 
TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] - + ([TEST_CASE_8] if has_tv else []) - ) - def test_component(self, test_input, output_type): - locator = ComponentLocator(excludes=["metrics"]) - configer = ConfigComponent(id="test", config=test_input, locator=locator) - ret = configer.instantiate() - if test_input.get("", False): - # test `` works fine - self.assertEqual(ret, None) - return - self.assertTrue(isinstance(ret, output_type)) - if isinstance(ret, LoadImaged): - self.assertEqual(ret.keys[0], "image") - - @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) - def test_expression(self, id, test_input): - configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) - var = 100 - ret = configer.evaluate(locals={"var": var}) - self.assertTrue(isinstance(ret, Callable)) - - def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} - configer = ConfigComponent(config=config, locator=None) - init_config = configer.get_config() - # modify config content at runtime - init_config[""]["batch_size"] = 4 - configer.update_config(config=init_config) - - ret = configer.instantiate() - self.assertTrue(isinstance(ret, DataLoader)) - self.assertEqual(ret.batch_size, 4) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py deleted file mode 100644 index 90d1733dc4..0000000000 --- a/tests/test_config_parser.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from distutils.command.config import config -import unittest -from unittest import skipUnless -from monai.apps.manifest.config_item import ConfigComponent - -from parameterized import parameterized - -from monai.apps import ConfigParser -from monai.data import DataLoader, Dataset -from monai.transforms import Compose, LoadImaged, RandTorchVisiond -from monai.utils import optional_import - -_, has_tv = optional_import("torchvision") - -# test the resolved and parsed instances -TEST_CASE_1 = [ - { - "transform": { - "": "Compose", - "": { - "transforms": [ - {"": "LoadImaged", "": {"keys": "image"}}, - { - "": "RandTorchVisiond", - "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, - }, - ] - }, - }, - "dataset": {"": "Dataset", "": {"data": [1, 2], "transform": "@transform"}}, - "dataloader": { - "": "DataLoader", - "": {"dataset": "@dataset", "batch_size": 2, "collate_fn": "monai.data.list_data_collate"}, - }, - }, - ["transform", "transform##transforms#0", "transform##transforms#1", "dataset", "dataloader"], - [Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader], -] - - -class TestConfigComponent(unittest.TestCase): - def test_config_content(self): - parser = ConfigParser(config={}) - test_config = {"preprocessing": [{"": "LoadImage"}], "dataset": {"": "Dataset"}} - parser.update_config(config=test_config) - self.assertEqual(str(parser.get_config()), str(test_config)) - parser.update_config(config={"": "CacheDataset"}, id="dataset") - self.assertDictEqual(parser.get_config(id="dataset"), {"": "CacheDataset"}) - parser.update_config(config="Dataset", id="dataset#") - self.assertEqual(parser.get_config(id="dataset#"), "Dataset") - - @parameterized.expand([TEST_CASE_1]) - @skipUnless(has_tv, "Requires torchvision.") - def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser(config=config, globals={"monai": "monai"}) - for id, cls in zip(expected_ids, output_types): - item = parser.get_config_item(id, resolve=True) - # test lazy instantiation - if isinstance(item, ConfigComponent): - self.assertTrue(isinstance(item.instantiate(), cls)) - # test get instance directly - self.assertTrue(isinstance(parser.get_resolved_content(id), cls)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py deleted file mode 100644 index f6a8dc416c..0000000000 --- a/tests/test_reference_resolver.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from monai.apps.manifest.config_item import ComponentLocator, ConfigExpression, ConfigItem - -import torch -from parameterized import parameterized - -import monai -from monai.apps import ConfigComponent, ReferenceResolver -from monai.data import DataLoader -from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import - -_, has_tv = optional_import("torchvision") - -# test instance with no dependencies -TEST_CASE_1 = [ - { - # all the recursively parsed config items - "transform#1": {"": "LoadImaged", "": {"keys": ["image"]}}, - "transform#1#": "LoadImaged", - "transform#1#": {"keys": ["image"]}, - "transform#1##keys": ["image"], - "transform#1##keys#0": "image", - }, - "transform#1", - LoadImaged, -] -# test depends on other component and executable code -TEST_CASE_2 = [ - { - # some the recursively parsed config items - "dataloader": { - "": "DataLoader", - "": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, - }, - "dataset": {"": "Dataset", "": {"data": [1, 2]}}, - "dataloader#": "DataLoader", - "dataloader#": {"dataset": "@dataset", "collate_fn": "$monai.data.list_data_collate"}, - "dataloader##dataset": "@dataset", - "dataloader##collate_fn": "$monai.data.list_data_collate", - "dataset#": "Dataset", - "dataset#": {"data": [1, 2]}, - "dataset##data": [1, 2], - "dataset##data#0": 1, - "dataset##data#1": 2, - }, - "dataloader", - DataLoader, -] -# test config has key `name` -TEST_CASE_3 = [ - { - # all the recursively parsed config items - "transform#1": { - "": "RandTorchVisiond", - "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, - }, - "transform#1#": "RandTorchVisiond", - "transform#1#": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}, - "transform#1##keys": "image", - "transform#1##name": "ColorJitter", - "transform#1##brightness": 0.25, - }, - "transform#1", - RandTorchVisiond, -] - - -class TestReferenceResolver(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else [])) - def test_resolve(self, configs, expected_id, output_type): - locator = ComponentLocator() - resolver = ReferenceResolver() - # add items to resolver - for k, v in configs.items(): - if ConfigComponent.is_instantiable(v): - resolver.add(ConfigComponent(config=v, id=k, locator=locator)) - elif ConfigExpression.is_expression(v): - resolver.add(ConfigExpression(config=v, id=k, globals={"monai": monai, "torch": torch})) - else: - resolver.add(ConfigItem(config=v, id=k)) - - result = resolver.get_resolved_content(expected_id) - self.assertTrue(isinstance(result, output_type)) - - # test resolve all - resolver.resolved_content = {} - resolver.resolve_all() - result = resolver.get_resolved_content(expected_id) - self.assertTrue(isinstance(result, output_type)) - - # test lazy instantiation - item = resolver.get_item(expected_id, resolve=True) - config = item.get_config() - config[""] = False - item.update_config(config=config) - if isinstance(item, ConfigComponent): - result = item.instantiate() - else: - result = item.get_config() - self.assertTrue(isinstance(result, output_type)) - - def test_circular_references(self): - locator = ComponentLocator() - resolver = ReferenceResolver() - configs = {"A": "@B", "B": "@C", "C": "@A"} - for k, v in configs.items(): - resolver.add(ConfigComponent(config=v, id=k, locator=locator)) - for k in ["A", "B", "C"]: - with self.assertRaises(ValueError): - resolver.get_resolved_content(k) - - -if __name__ == "__main__": - 
unittest.main() From 01fdf0a4b08c45927722b645fcbb946bd143a956 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 19 Feb 2022 12:51:09 +0800 Subject: [PATCH 30/30] [DLMED] resolve conflicts Signed-off-by: Nic Ma --- docs/source/apps.rst | 3 - monai/apps/__init__.py | 1 + monai/apps/manifest/__init__.py | 2 + monai/apps/manifest/config_item.py | 335 ++++++++++++++++++++++++++++ 4 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 monai/apps/manifest/config_item.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index d0fe131d85..4b1cdc6f43 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -44,9 +44,6 @@ Model Manifest .. autoclass:: ConfigItem :members: -.. autoclass:: ConfigParser - :members: - .. autoclass:: ReferenceResolver :members: diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 893f7877d2..df085bddea 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -10,5 +10,6 @@ # limitations under the License. from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset +from .manifest import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/manifest/__init__.py b/monai/apps/manifest/__init__.py index 1e97f89407..d8919a5249 100644 --- a/monai/apps/manifest/__init__.py +++ b/monai/apps/manifest/__init__.py @@ -8,3 +8,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py new file mode 100644 index 0000000000..075d00b961 --- /dev/null +++ b/monai/apps/manifest/config_item.py @@ -0,0 +1,335 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import sys +import warnings +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +from monai.utils import ensure_tuple, instantiate + +__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] + + +class Instantiable(ABC): + """ + Base class for an instantiable object. + """ + + @abstractmethod + def is_disabled(self, *args: Any, **kwargs: Any) -> bool: + """ + Return a boolean flag to indicate whether the object should be instantiated. + """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def instantiate(self, *args: Any, **kwargs: Any) -> object: + """ + Instantiate the target component and return the instance. 
+ """ + raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") + + +class ComponentLocator: + """ + Scan all the available classes and functions in the MONAI package and map them with the module paths in a table. + It's used to locate the module path for provided component name. + + Args: + excludes: if any string of the `excludes` exists in the full module name, don't import this module. + + """ + + MOD_START = "monai" + + def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None): + self.excludes = [] if excludes is None else ensure_tuple(excludes) + self._components_table: Optional[Dict[str, List]] = None + + def _find_module_names(self) -> List[str]: + """ + Find all the modules start with MOD_START and don't contain any of `excludes`. + + """ + return [ + m for m in sys.modules.keys() if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes) + ] + + def _find_classes_or_functions(self, modnames: Union[Sequence[str], str]) -> Dict[str, List]: + """ + Find all the classes and functions in the modules with specified `modnames`. + + Args: + modnames: names of the target modules to find all the classes and functions. + + """ + table: Dict[str, List] = {} + # all the MONAI modules are already loaded by `load_submodules` + for modname in ensure_tuple(modnames): + try: + # scan all the classes and functions in the module + module = import_module(modname) + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname: + if name not in table: + table[name] = [] + table[name].append(modname) + except ModuleNotFoundError: + pass + return table + + def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]]: + """ + Get the full module name of the class or function with specified ``name``. + If target component name exists in multiple packages or modules, return a list of full module names. + + Args: + name: name of the expected class or function. + + """ + if not isinstance(name, str): + raise ValueError(f"`name` must be a valid string, but got: {name}.") + if self._components_table is None: + # init component and module mapping table + self._components_table = self._find_classes_or_functions(self._find_module_names()) + + mods: Optional[Union[List[str], str]] = self._components_table.get(name, None) + if isinstance(mods, list) and len(mods) == 1: + mods = mods[0] + return mods + + +class ConfigItem: + """ + Basic data structure to represent a configuration item. + + A `ConfigItem` instance can optionally have a string id, so that other items can refer to it. + It has a build-in `config` property to store the configuration object. + + Args: + config: content of a config item, can be objects of any types, + a configuration resolver may interpret the content to generate a configuration object. + id: name of the current config item, defaults to empty string. + + """ + + def __init__(self, config: Any, id: str = "") -> None: + self.config = config + self.id = id + + def get_id(self) -> Optional[str]: + """ + Get the ID name of current config item, useful to identify config items during parsing. + + """ + return self.id + + def update_config(self, config: Any): + """ + Replace the content of `self.config` with new `config`. + A typical usage is to modify the initial config content at runtime. + + Args: + config: content of a `ConfigItem`. 
+ + """ + self.config = config + + def get_config(self): + """ + Get the config content of current config item. + + """ + return self.config + + +class ConfigComponent(ConfigItem, Instantiable): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, this class uses a dictionary with string keys to + represent a component of `class` or `function` and supports instantiation. + + Currently, four special keys (strings surrounded by ``<>``) are defined and interpreted beyond the regular literals: + + - class or function identifier of the python module, specified by one of the two keys. + - ``""``: indicates build-in python classes or functions such as "LoadImageDict". + - ``""``: full module name, such as "monai.transforms.LoadImageDict". + - ``""``: input arguments to the python module. + - ``""``: a flag to indicate whether to skip the instantiation. + + .. code-block:: python + + locator = ComponentLocator(excludes=["modules_to_exclude"]) + config = { + "": "LoadImaged", + "": { + "keys": ["image", "label"] + } + } + + configer = ConfigComponent(config, id="test", locator=locator) + image_loader = configer.instantiate() + print(image_loader) # + + Args: + config: content of a config item. + id: name of the current config item, defaults to empty string. + locator: a ``ComponentLocator`` to convert a module name string into the actual python module. + if `None`, a ``ComponentLocator(excludes=excludes)`` will be used. + excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``. + See also: :py:class:`monai.apps.manifest.ComponentLocator`. + + """ + + def __init__( + self, + config: Any, + id: str = "", + locator: Optional[ComponentLocator] = None, + excludes: Optional[Union[Sequence[str], str]] = None, + ) -> None: + super().__init__(config=config, id=id) + self.locator = ComponentLocator(excludes=excludes) if locator is None else locator + + @staticmethod + def is_instantiable(config: Any) -> bool: + """ + Check whether this config represents a `class` or `function` that is to be instantiated. + + Args: + config: input config content to check. + + """ + return isinstance(config, Mapping) and ("" in config or "" in config) + + def resolve_module_name(self): + """ + Resolve the target module name from current config content. + The config content must have ``""`` or ``""``. + When both are specified, ``""`` will be used. + + """ + config = dict(self.get_config()) + path = config.get("") + if path is not None: + if not isinstance(path, str): + raise ValueError(f"'' must be a string, but got: {path}.") + if "" in config: + warnings.warn(f"both '' and '', default to use '': {path}.") + return path + + name = config.get("") + if not isinstance(name, str): + raise ValueError("must provide a string for `` or `` of target component to instantiate.") + + module = self.locator.get_component_module_name(name) + if module is None: + raise ModuleNotFoundError(f"can not find component '{name}' in {self.locator.MOD_START} modules.") + if isinstance(module, list): + warnings.warn( + f"there are more than 1 component have name `{name}`: {module}, use the first one `{module[0]}." + f" if want to use others, please set its module path in `` directly." + ) + module = module[0] + return f"{module}.{name}" + + def resolve_args(self): + """ + Utility function used in `instantiate()` to resolve the arguments from current config content. 
+ + """ + return self.get_config().get("", {}) + + def is_disabled(self) -> bool: # type: ignore + """ + Utility function used in `instantiate()` to check whether to skip the instantiation. + + """ + _is_disabled = self.get_config().get("", False) + return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled) + + def instantiate(self, **kwargs) -> object: # type: ignore + """ + Instantiate component based on ``self.config`` content. + The target component must be a `class` or a `function`, otherwise, return `None`. + + Args: + kwargs: args to override / add the config args when instantiation. + + """ + if not self.is_instantiable(self.get_config()) or self.is_disabled(): + # if not a class or function or marked as `disabled`, skip parsing and return `None` + return None + + modname = self.resolve_module_name() + args = self.resolve_args() + args.update(kwargs) + return instantiate(modname, **args) + + +class ConfigExpression(ConfigItem): + """ + Subclass of :py:class:`monai.apps.ConfigItem`, the `ConfigItem` represents an executable expression + (execute based on ``eval()``). + + See also: + + - https://docs.python.org/3/library/functions.html#eval. + + For example: + + .. code-block:: python + + import monai + from monai.apps.manifest import ConfigExpression + + config = "$monai.__version__" + expression = ConfigExpression(config, id="test", globals={"monai": monai}) + print(expression.execute()) + + Args: + config: content of a config item. + id: name of current config item, defaults to empty string. + globals: additional global context to evaluate the string. + + """ + + def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: + super().__init__(config=config, id=id) + self.globals = globals + + def evaluate(self, locals: Optional[Dict] = None): + """ + Execute the current config content and return the result if it is expression, based on Python `eval()`. + For more details: https://docs.python.org/3/library/functions.html#eval. + + Args: + locals: besides ``globals``, may also have some local symbols used in the expression at runtime. + + """ + value = self.get_config() + if not ConfigExpression.is_expression(value): + return None + return eval(value[1:], self.globals, locals) + + @staticmethod + def is_expression(config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an executable expression string. + Currently, a string starts with ``"$"`` character is interpreted as an expression. + + Args: + config: input config content to check. + + """ + return isinstance(config, str) and config.startswith("$")