diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9c212784ca..622a4865d1 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -20,7 +20,7 @@ from monai.config import DtypeLike, KeysCollection from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg from .utils import is_supported_format @@ -132,6 +132,7 @@ def _stack_images(image_list: List, meta_dict: Dict): return np.stack(image_list, axis=0) +@require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -317,6 +318,7 @@ def _get_array_data(self, img): return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` +@require_pkg(pkg_name="nibabel") class NibabelReader(ImageReader): """ Load NIfTI format images based on Nibabel library. @@ -564,6 +566,7 @@ def get_data(self, img): return _stack_images(img_array, compatible_meta), compatible_meta +@require_pkg(pkg_name="PIL") class PILReader(ImageReader): """ Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 38a2861c8b..b0f4c02cf6 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -30,8 +30,7 @@ from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, ensure_tuple, optional_import -from monai.utils.module import look_up_option +from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -126,6 +125,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. for r in SUPPORTED_READERS: # set predefined readers as default try: self.register(SUPPORTED_READERS[r](*args, **kwargs)) + except OptionalImportError: + logging.getLogger(self.__class__.__name__).debug( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs logging.getLogger(self.__class__.__name__).debug( f"{r} is not supported with the given parameters {args} {kwargs}." @@ -139,6 +142,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. the_reader = look_up_option(_r.lower(), SUPPORTED_READERS) try: self.register(the_reader(*args, **kwargs)) + except OptionalImportError: + warnings.warn( + f"required package for reader {r} is not installed, or the version doesn't match requirement." + ) except TypeError: # the reader doesn't have the corresponding args/kwargs warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.") self.register(the_reader()) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 57bbb0dd5b..42eba2e67f 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -70,6 +70,7 @@ look_up_option, min_version, optional_import, + require_pkg, version_leq, ) from .nvtx import Range diff --git a/monai/utils/module.py b/monai/utils/module.py index 130b89493e..12ffb27b82 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -11,9 +11,11 @@ import enum import sys import warnings +from functools import wraps from importlib import import_module from pkgutil import walk_packages from re import match +from types import FunctionType from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast import torch @@ -29,6 +31,7 @@ "look_up_option", "min_version", "optional_import", + "require_pkg", "load_submodules", "get_full_type_name", "get_package_version", @@ -347,6 +350,45 @@ def __call__(self, *_args, **_kwargs): return _LazyRaise(), False +def require_pkg( + pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True +): + """ + Decorator function to check the required package installation. + + Args: + pkg_name: required package name, like: "itk", "nibabel", etc. + version: required version string used by the version_checker. + version_checker: a callable to check the module version, defaults to `monai.utils.min_version`. + raise_error: if True, raise `OptionalImportError` error if the required package is not installed + or the version doesn't match requirement, if False, print the error in a warning. + + """ + + def _decorator(obj): + is_func = isinstance(obj, FunctionType) + call_obj = obj if is_func else obj.__init__ + _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker) + + @wraps(call_obj) + def _wrapper(*args, **kwargs): + if not has: + err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement." + if raise_error: + raise OptionalImportError(err_msg) + else: + warnings.warn(err_msg) + + return call_obj(*args, **kwargs) + + if is_func: + return _wrapper + obj.__init__ = _wrapper + return obj + + return _decorator + + def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): """ Try to load package and get version. If not found, return `default`. diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index d6737c26ca..0017c6acee 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -13,6 +13,7 @@ from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms import LoadImage, LoadImaged +from tests.utils import SkipIfNoModule class TestInitLoadImage(unittest.TestCase): @@ -26,6 +27,9 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("itk") + @SkipIfNoModule("nibabel") + @SkipIfNoModule("PIL") def test_readers(self): inst = ITKReader() self.assertIsInstance(inst, ITKReader) diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py new file mode 100644 index 0000000000..ff32a322bb --- /dev/null +++ b/tests/test_require_pkg.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.utils import OptionalImportError, min_version, require_pkg + + +class TestRequirePkg(unittest.TestCase): + def test_class(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_function(self): + @require_pkg(pkg_name="torch", version="1.4", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + def test_warning(self): + @require_pkg(pkg_name="test123", raise_error=False) + def test_func(x): + return x + + test_func(x=None) + + def test_class_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + class TestClass: + pass + + TestClass() + + def test_class_version_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + class TestClass: + pass + + TestClass() + + def test_func_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="test123") + def test_func(x): + return x + + test_func(x=None) + + def test_func_versions_exception(self): + with self.assertRaises(OptionalImportError): + + @require_pkg(pkg_name="torch", version="10000", version_checker=min_version) + def test_func(x): + return x + + test_func(x=None) + + +if __name__ == "__main__": + unittest.main()