Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg

from .utils import is_supported_format

Expand Down Expand Up @@ -132,6 +132,7 @@ def _stack_images(image_list: List, meta_dict: Dict):
return np.stack(image_list, axis=0)


@require_pkg(pkg_name="itk")
class ITKReader(ImageReader):
"""
Load medical images based on ITK library.
Expand Down Expand Up @@ -317,6 +318,7 @@ def _get_array_data(self, img):
return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti`


@require_pkg(pkg_name="nibabel")
class NibabelReader(ImageReader):
"""
Load NIfTI format images based on Nibabel library.
Expand Down Expand Up @@ -564,6 +566,7 @@ def get_data(self, img):
return _stack_images(img_array, compatible_meta), compatible_meta


@require_pkg(pkg_name="PIL")
class PILReader(ImageReader):
"""
Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path.
Expand Down
11 changes: 9 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from monai.transforms.transform import Transform
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, ensure_tuple, optional_import
from monai.utils.module import look_up_option
from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
Expand Down Expand Up @@ -126,6 +125,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.
for r in SUPPORTED_READERS: # set predefined readers as default
try:
self.register(SUPPORTED_READERS[r](*args, **kwargs))
except OptionalImportError:
logging.getLogger(self.__class__.__name__).debug(
f"required package for reader {r} is not installed, or the version doesn't match requirement."
)
except TypeError: # the reader doesn't have the corresponding args/kwargs
logging.getLogger(self.__class__.__name__).debug(
f"{r} is not supported with the given parameters {args} {kwargs}."
Expand All @@ -139,6 +142,10 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.
the_reader = look_up_option(_r.lower(), SUPPORTED_READERS)
try:
self.register(the_reader(*args, **kwargs))
except OptionalImportError:
warnings.warn(
f"required package for reader {r} is not installed, or the version doesn't match requirement."
)
except TypeError: # the reader doesn't have the corresponding args/kwargs
warnings.warn(f"{r} is not supported with the given parameters {args} {kwargs}.")
self.register(the_reader())
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
look_up_option,
min_version,
optional_import,
require_pkg,
version_leq,
)
from .nvtx import Range
Expand Down
42 changes: 42 additions & 0 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import enum
import sys
import warnings
from functools import wraps
from importlib import import_module
from pkgutil import walk_packages
from re import match
from types import FunctionType
from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast

import torch
Expand All @@ -29,6 +31,7 @@
"look_up_option",
"min_version",
"optional_import",
"require_pkg",
"load_submodules",
"get_full_type_name",
"get_package_version",
Expand Down Expand Up @@ -347,6 +350,45 @@ def __call__(self, *_args, **_kwargs):
return _LazyRaise(), False


def require_pkg(
pkg_name: str, version: str = "", version_checker: Callable[..., bool] = min_version, raise_error: bool = True
):
"""
Decorator function to check the required package installation.

Args:
pkg_name: required package name, like: "itk", "nibabel", etc.
version: required version string used by the version_checker.
version_checker: a callable to check the module version, defaults to `monai.utils.min_version`.
raise_error: if True, raise `OptionalImportError` error if the required package is not installed
or the version doesn't match requirement, if False, print the error in a warning.

"""

def _decorator(obj):
is_func = isinstance(obj, FunctionType)
call_obj = obj if is_func else obj.__init__
_, has = optional_import(module=pkg_name, version=version, version_checker=version_checker)

@wraps(call_obj)
def _wrapper(*args, **kwargs):
if not has:
err_msg = f"required package `{pkg_name}` is not installed or the version doesn't match requirement."
if raise_error:
raise OptionalImportError(err_msg)
else:
warnings.warn(err_msg)

return call_obj(*args, **kwargs)

if is_func:
return _wrapper
obj.__init__ = _wrapper
return obj

return _decorator


def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."):
"""
Try to load package and get version. If not found, return `default`.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_init_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader
from monai.transforms import LoadImage, LoadImaged
from tests.utils import SkipIfNoModule


class TestInitLoadImage(unittest.TestCase):
Expand All @@ -26,6 +27,9 @@ def test_load_image(self):
inst = LoadImaged("image", reader=r)
self.assertIsInstance(inst, LoadImaged)

@SkipIfNoModule("itk")
@SkipIfNoModule("nibabel")
@SkipIfNoModule("PIL")
def test_readers(self):
inst = ITKReader()
self.assertIsInstance(inst, ITKReader)
Expand Down
77 changes: 77 additions & 0 deletions tests/test_require_pkg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from monai.utils import OptionalImportError, min_version, require_pkg


class TestRequirePkg(unittest.TestCase):
def test_class(self):
@require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
class TestClass:
pass

TestClass()

def test_function(self):
@require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
def test_func(x):
return x

test_func(x=None)

def test_warning(self):
@require_pkg(pkg_name="test123", raise_error=False)
def test_func(x):
return x

test_func(x=None)

def test_class_exception(self):
with self.assertRaises(OptionalImportError):

@require_pkg(pkg_name="test123")
class TestClass:
pass

TestClass()

def test_class_version_exception(self):
with self.assertRaises(OptionalImportError):

@require_pkg(pkg_name="torch", version="10000", version_checker=min_version)
class TestClass:
pass

TestClass()

def test_func_exception(self):
with self.assertRaises(OptionalImportError):

@require_pkg(pkg_name="test123")
def test_func(x):
return x

test_func(x=None)

def test_func_versions_exception(self):
with self.assertRaises(OptionalImportError):

@require_pkg(pkg_name="torch", version="10000", version_checker=min_version)
def test_func(x):
return x

test_func(x=None)


if __name__ == "__main__":
unittest.main()