diff --git a/tests/test_lltm.py b/tests/test_lltm.py
index 5c666a8794..41a9ea55fd 100644
--- a/tests/test_lltm.py
+++ b/tests/test_lltm.py
@@ -15,6 +15,7 @@
 from parameterized import parameterized
 
 from monai.networks.layers import LLTM
+from tests.utils import SkipIfNoModule
 
 TEST_CASE_1 = [
     {"input_features": 32, "state_size": 2},
@@ -25,6 +26,7 @@ class TestLLTM(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1])
+    @SkipIfNoModule("monai._C")
     def test_value(self, input_param, expected_h, expected_c):
         torch.manual_seed(0)
         x = torch.randn(4, 32)
@@ -37,6 +39,7 @@ def test_value(self, input_param, expected_h, expected_c):
         torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
 
     @parameterized.expand([TEST_CASE_1])
+    @SkipIfNoModule("monai._C")
     def test_value_cuda(self, input_param, expected_h, expected_c):
         device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
         torch.manual_seed(0)
diff --git a/tests/utils.py b/tests/utils.py
index dbbd9ba4c3..0e1fb9aee4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -32,6 +32,18 @@ def skip_if_quick(obj):
     return unittest.skipIf(is_quick, "Skipping slow tests")(obj)
 
 
+class SkipIfNoModule(object):
+    """Decorator to be used if test should be skipped
+    when optional module is not present."""
+
+    def __init__(self, module_name):
+        self.module_name = module_name
+        self.module_missing = not optional_import(self.module_name)[1]
+
+    def __call__(self, obj):
+        return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj)
+
+
 def make_nifti_image(array, affine=None):
     """
     Create a temporary nifti image on the disk and return the image name.
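
For reference, a minimal usage sketch of the new decorator follows; the test class and the module name "some_optional_pkg" are hypothetical and only illustrate the pattern, they are not part of this change. optional_import returns a (module, success_flag) pair, so "not optional_import(name)[1]" is True exactly when the optional module cannot be imported, and unittest.skipIf then reports the decorated test as skipped instead of failed.

    # Illustrative sketch only; assumes it runs inside the MONAI repo so that
    # tests.utils (and the SkipIfNoModule decorator added above) is importable.
    import unittest

    from tests.utils import SkipIfNoModule


    class TestOptionalFeature(unittest.TestCase):
        # Skipped (not failed) when "some_optional_pkg" is not installed;
        # the skip reason names the missing module.
        @SkipIfNoModule("some_optional_pkg")
        def test_feature(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()

Because SkipIfNoModule ultimately just returns the result of unittest.skipIf, it can be applied to a whole TestCase class as well as to individual test methods.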