From 7d9375f1262613a7f49559a2ae709812f926fd16 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 31 Oct 2021 21:19:41 +0800 Subject: [PATCH 1/3] [DLMED] add decorator Signed-off-by: Nic Ma --- monai/data/image_reader.py | 5 ++++- monai/transforms/io/array.py | 11 ++++++++-- monai/utils/__init__.py | 1 + monai/utils/module.py | 42 ++++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9c212784ca..b510ee0559 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -20,7 +20,7 @@ from monai.config import DtypeLike, KeysCollection from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg from .utils import is_supported_format @@ -132,6 +132,7 @@ def _stack_images(image_list: List, meta_dict: Dict): return np.stack(image_list, axis=0) +@require_pkg(pgk_name="itk") class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -317,6 +318,7 @@ def _get_array_data(self, img): return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` +@require_pkg(pgk_name="nibabel") class NibabelReader(ImageReader): """ Load NIfTI format images based on Nibabel library. @@ -564,6 +566,7 @@ def get_data(self, img): return _stack_images(img_array, compatible_meta), compatible_meta +@require_pkg(pgk_name="PIL") class PILReader(ImageReader): """ Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 38a2861c8b..b0f4c02cf6 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -30,8 +30,7 @@ from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, ensure_tuple, optional_import -from monai.utils.module import look_up_option +from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -126,6 +125,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. for r in SUPPORTED_READERS: # set predefined readers as default try: self.register(SUPPORTED_READERS[r](*args, **kwargs)) + except OptionalImportError: + logging.getLogger(self.__class__.__name__).debug( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs logging.getLogger(self.__class__.__name__).debug( f"{r} is not supported with the given parameters {args} {kwargs}." @@ -139,6 +142,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) try: self.register(the_reader(*args, **kwargs)) + except OptionalImportError: + warnings.warn( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.") self.register(the_reader()) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 57bbb0dd5b..42eba2e67f 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -70,6 +70,7 @@ look_up_option, min_version, optional_import, + require_pkg, version_leq, ) from .nvtx import Range diff --git a/monai/utils/module.py b/monai/utils/module.py index 130b89493e..437990c4c7 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -11,9 +11,11 @@ import enum import sys import warnings +from functools import wraps from importlib import import_module from pkgutil import walk_packages from re import match +from types import FunctionType from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast import torch @@ -29,6 +31,7 @@ "look_up_option", "min_version", "optional_import", + "require_pkg", "load_submodules", "get_full_type_name", "get_package_version", @@ -347,6 +350,45 @@ def __call__(self, *_args, **_kwargs): return _LazyRaise(), False +def require_pkg( + pgk_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True +): + """ + Decorator function to check the required package installation. + + Args: + pgk_name: required package name, like: "itk", "nibabel", etc. + version: required version string used by the version_checker. + version_checker: a callable to check the module version, defaults to `monai.utils.min_version`. + raise_error: if True, raise `OptionalImportError` error if the required package is not installed + or the version doesn't match requirement, if False, print the error in a warning. + + """ + + def _decorator(obj): + is_func = isinstance(obj, FunctionType) + call_obj = obj if is_func else obj.__init__ + + @wraps(call_obj) + def _wrapper(*args, **kwargs): + _, has = optional_import(module=pgk_name, version=version, version_checker=version_checker) + if not has: + err_msg = f"required package `{pgk_name}` is not installed or the version doesn't match requirement." + if raise_error: + raise OptionalImportError(err_msg) + else: + warnings.warn(err_msg) + + return call_obj(*args, **kwargs) + + if is_func: + return _wrapper + obj.__init__ = _wrapper + return obj + + return _decorator + + def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): """ Try to load package and get version. If not found, return `default`. From d4050766088ccb509ab7efb1fd8047f2f6e0c37f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 31 Oct 2021 21:54:32 +0800 Subject: [PATCH 2/3] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/utils/module.py | 2 +- tests/test_init_reader.py | 4 ++ tests/test_require_pkg.py | 77 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/test_require_pkg.py diff --git a/monai/utils/module.py b/monai/utils/module.py index 437990c4c7..8bd4fb842a 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -368,10 +368,10 @@ def require_pkg( def _decorator(obj): is_func = isinstance(obj, FunctionType) call_obj = obj if is_func else obj.__init__ + _, has = optional_import(module=pgk_name, version=version, version_checker=version_checker) @wraps(call_obj) def _wrapper(*args, **kwargs): - _, has = optional_import(module=pgk_name, version=version, version_checker=version_checker) if not has: err_msg = f"required package `{pgk_name}` is not installed or the version doesn't match requirement." if raise_error: diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index d6737c26ca..0017c6acee 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -13,6 +13,7 @@ from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms import LoadImage, LoadImaged +from tests.utils import SkipIfNoModule class TestInitLoadImage(unittest.TestCase): @@ -26,6 +27,9 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("itk") + @SkipIfNoModule("nibabel") + @SkipIfNoModule("PIL") def test_readers(self): inst = ITKReader() self.assertIsInstance(inst, ITKReader) diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py new file mode 100644 index 0000000000..6d66f0a2ae --- /dev/null +++ b/tests/test_require_pkg.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.utils import OptionalImportError, min_version, require_pkg + + +class TestRequirePkg(unittest.TestCase): + def test_class(self): + @require_pkg(pgk_name="torch", version="1.4", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_function(self): + @require_pkg(pgk_name="torch", version="1.4", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + def test_warning(self): + @require_pkg(pgk_name="test123", raise_error=False) + def test_func(x): + return x + + test_func(x=None) + + def test_class_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pgk_name="test123") + class TestClass: + pass + + TestClass() + + def test_class_version_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pgk_name="torch", version="10000", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_func_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pgk_name="test123") + def test_func(x): + return x + + test_func(x=None) + + def test_func_versions_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pgk_name="torch", version="10000", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + +if __name__ == "__main__": + unittest.main() From 9592c4a3a1d65109a9fa8f2102d296168d194f04 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 00:10:41 +0800 Subject: [PATCH 3/3] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/data/image_reader.py | 6 +++--- monai/utils/module.py | 8 ++++---- tests/test_require_pkg.py | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b510ee0559..622a4865d1 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -132,7 +132,7 @@ def _stack_images(image_list: List, meta_dict: Dict): return np.stack(image_list, axis=0) -@require_pkg(pgk_name="itk") +@require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -318,7 +318,7 @@ def _get_array_data(self, img): return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` -@require_pkg(pgk_name="nibabel") +@require_pkg(pkg_name="nibabel") class NibabelReader(ImageReader): """ Load NIfTI format images based on Nibabel library. @@ -566,7 +566,7 @@ def get_data(self, img): return _stack_images(img_array, compatible_meta), compatible_meta -@require_pkg(pgk_name="PIL") +@require_pkg(pkg_name="PIL") class PILReader(ImageReader): """ Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. diff --git a/monai/utils/module.py b/monai/utils/module.py index 8bd4fb842a..12ffb27b82 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -351,13 +351,13 @@ def __call__(self, *_args, **_kwargs): def require_pkg( - pgk_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True + pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True ): """ Decorator function to check the required package installation. Args: - pgk_name: required package name, like: "itk", "nibabel", etc. + pkg_name: required package name, like: "itk", "nibabel", etc. version: required version string used by the version_checker. version_checker: a callable to check the module version, defaults to `monai.utils.min_version`. raise_error: if True, raise `OptionalImportError` error if the required package is not installed @@ -368,12 +368,12 @@ def require_pkg( def _decorator(obj): is_func = isinstance(obj, FunctionType) call_obj = obj if is_func else obj.__init__ - _, has = optional_import(module=pgk_name, version=version, version_checker=version_checker) + _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) @wraps(call_obj) def _wrapper(*args, **kwargs): if not has: - err_msg = f"required package `{pgk_name}` is not installed or the version doesn't match requirement." + err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement." if raise_error: raise OptionalImportError(err_msg) else: diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py index 6d66f0a2ae..ff32a322bb 100644 --- a/tests/test_require_pkg.py +++ b/tests/test_require_pkg.py @@ -16,21 +16,21 @@ class TestRequirePkg(unittest.TestCase): def test_class(self): - @require_pkg(pgk_name="torch", version="1.4", version_checker=min_version) + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) class TestClass: pass TestClass() def test_function(self): - @require_pkg(pgk_name="torch", version="1.4", version_checker=min_version) + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) def test_func(x): return x test_func(x=None) def test_warning(self): - @require_pkg(pgk_name="test123", raise_error=False) + @require_pkg(pkg_name="test123", raise_error=False) def test_func(x): return x @@ -39,7 +39,7 @@ def test_func(x): def test_class_exception(self): with self.assertRaises(OptionalImportError): - @require_pkg(pgk_name="test123") + @require_pkg(pkg_name="test123") class TestClass: pass @@ -48,7 +48,7 @@ class TestClass: def test_class_version_exception(self): with self.assertRaises(OptionalImportError): - @require_pkg(pgk_name="torch", version="10000", version_checker=min_version) + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) class TestClass: pass @@ -57,7 +57,7 @@ class TestClass: def test_func_exception(self): with self.assertRaises(OptionalImportError): - @require_pkg(pgk_name="test123") + @require_pkg(pkg_name="test123") def test_func(x): return x @@ -66,7 +66,7 @@ def test_func(x): def test_func_versions_exception(self): with self.assertRaises(OptionalImportError): - @require_pkg(pgk_name="torch", version="10000", version_checker=min_version) + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) def test_func(x): return x