diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 779a42f9cb..2224b42a74 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -179,6 +179,11 @@ Layers .. autoclass:: GaussianFilter :members: +`HilbertTransform` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: HilbertTransform + :members: + `Affine Transform` ~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.networks.layers.AffineTransform diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 65fceebe8d..f7e075f376 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -216,6 +216,12 @@ Intensity :members: :special-members: __call__ +`DetectEnvelope` +""""""""""""""""""""" +.. autoclass:: DetectEnvelope + :members: + :special-members: __call__ + IO ^^ diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index a726975138..bd800a7c91 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -17,12 +17,15 @@ from torch import nn from torch.autograd import Function +from monai.config import get_torch_version_tuple from monai.networks.layers.convutils import gaussian_1d, same_padding -from monai.utils import SkipMode, ensure_tuple_rep, optional_import +from monai.utils import InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import _C, _ = optional_import("monai._C") +if tuple(int(s) for s in torch.__version__.split(".")[0:2]) >= (1, 7): + fft, _ = optional_import("torch.fft") -__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering"] +__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "HilbertTransform"] class SkipConnection(nn.Module): @@ -130,6 +133,72 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: return _conv(x, spatial_dims - 1) +class HilbertTransform(nn.Module): + """ + Determine the analytical signal of a Tensor along a particular axis. + Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). + + Args: + axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension). + N: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. + """ + + def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: + + if get_torch_version_tuple() < (1, 7): + raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) + + super().__init__() + self.axis = axis + self.n = n + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``. + Returns: + torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using + FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``. + """ + + # Make input a real tensor + x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) + if torch.is_complex(x): + raise ValueError("x must be real.") + else: + x = x.to(dtype=torch.float) + + if (self.axis < 0) or (self.axis > len(x.shape) - 1): + raise ValueError("Invalid axis for shape of x.") + + n = x.shape[self.axis] if self.n is None else self.n + if n <= 0: + raise ValueError("N must be positive.") + x = torch.as_tensor(x, dtype=torch.complex64) + # Create frequency axis + f = torch.cat( + [ + torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)), + torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)), + ] + ) + xf = fft.fft(x, n=n, dim=self.axis) + # Create step function + u = torch.heaviside(f, torch.tensor([0.5], device=f.device)) + u = torch.as_tensor(u, dtype=x.dtype, device=u.device) + new_dims_before = self.axis + new_dims_after = len(xf.shape) - self.axis - 1 + for _ in range(new_dims_before): + u.unsqueeze_(0) + for _ in range(new_dims_after): + u.unsqueeze_(-1) + + ht = fft.ifft(xf * 2 * u, dim=self.axis) + + # Apply transform + return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype) + + class GaussianFilter(nn.Module): def __init__( self, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a464109417..b1b11fe6f4 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -20,10 +20,11 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter +from monai.config import get_torch_version_tuple +from monai.networks.layers import GaussianFilter, HilbertTransform from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array -from monai.utils import dtype_torch_to_numpy, ensure_tuple_size +from monai.utils import InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size class RandGaussianNoise(Randomizable, Transform): @@ -509,6 +510,46 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ +class DetectEnvelope(Transform): + """ + Find the envelope of the input data along the requested axis using a Hilbert transform. + Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). + + Args: + axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension. + N: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension + ``axis``. + + """ + + def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: + + if get_torch_version_tuple() < (1, 7): + raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) + + if axis < 0: + raise ValueError("axis must be zero or positive.") + + self.axis = axis + self.n = n + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + + Args: + img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + + Returns: + np.ndarray containing envelope of data in img along the specified axis. + + """ + # add one to transform axis because a batch axis will be added at dimension 0 + hilbert_transform = HilbertTransform(self.axis + 1, self.n) + # convert to Tensor and add Batch axis expected by HilbertTransform + input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) + return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + + class GaussianSmooth(Transform): """ Apply Gaussian smooth to the input data based on specified `sigma` parameter. diff --git a/monai/utils/module.py b/monai/utils/module.py index 0edf9047ac..039c4f9fb5 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -21,6 +21,7 @@ OPTIONAL_IMPORT_MSG_FMT = "{}" __all__ = [ + "InvalidPyTorchVersionError", "OptionalImportError", "exact_version", "export", @@ -105,6 +106,17 @@ def exact_version(the_module, version_str: str = "") -> bool: return bool(the_module.__version__ == version_str) +class InvalidPyTorchVersionError(Exception): + """ + Raised when called function or method requires a more recent + PyTorch version than that installed. + """ + + def __init__(self, required_version, name): + message = f"{name} requires PyTorch version {required_version} or later" + super().__init__(message) + + class OptionalImportError(ImportError): """ Could not import APIs from an optional dependency. diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py new file mode 100644 index 0000000000..aec014731b --- /dev/null +++ b/tests/test_detect_envelope.py @@ -0,0 +1,165 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import DetectEnvelope +from monai.utils import InvalidPyTorchVersionError, OptionalImportError +from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, SkipIfModule, SkipIfNoModule + +n_samples = 500 +hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) + +# SINGLE-CHANNEL VALUE TESTS +# using np.expand_dims() to add length 1 channel dimension at dimension 0 + +TEST_CASE_1D_SINE = [ + {}, # args (empty, so use default) + np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave + np.expand_dims(np.hanning(n_samples), 0), # Expected output: the Hann window + 1e-4, # absolute tolerance +] + +TEST_CASE_2D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 10 identical windowed sine waves as a 2D numpy array + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), + # Expected output: Set of 10 identical Hann windows + np.expand_dims(np.stack([np.hanning(n_samples)] * 10, axis=1), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_3D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array + np.expand_dims(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), 0), + # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array + np.expand_dims(np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_2D_SINE_AXIS_1 = [ + {"axis": 2}, # set axis argument to 1 + # Create 10 identical windowed sine waves as a 2D numpy array + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), + # Expected output: absolute value of each sample of the waveform, repeated (i.e. flat envelopes) + np.expand_dims(np.abs(np.repeat(hann_windowed_sine, 10).reshape((n_samples, 10))), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_1D_SINE_PADDING_N = [ + {"n": 512}, # args (empty, so use default) + np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave + np.expand_dims(np.concatenate([np.hanning(500), np.zeros(12)]), 0), # Expected output: the Hann window + 1e-3, # absolute tolerance +] + +# MULTI-CHANNEL VALUE TEST + +TEST_CASE_2_CHAN_3D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels) + np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels) + np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + 1e-4, # absolute tolerance +] + +# EXCEPTION TESTS + +TEST_CASE_INVALID_AXIS_1 = [ + {"axis": 3}, # set axis argument to 3 when only 3 dimensions (1 channel + 2 spatial) + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_AXIS_2 = [ + {"axis": -1}, # set axis argument negative + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__init__", # method expected to raise exception +] + +TEST_CASE_INVALID_N = [ + {"n": 0}, # set FFT length to zero + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_DTYPE = [ + {}, + np.expand_dims(np.array(hann_windowed_sine, dtype=np.complex), 0), # complex numbers are invalid + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_IMG_LEN = [ + {}, + np.expand_dims(np.array([]), 0), # empty array is invalid + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_OBJ = [{}, "a string", "__call__"] # method expected to raise exception + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestDetectEnvelope(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1D_SINE, + TEST_CASE_2D_SINE, + TEST_CASE_3D_SINE, + TEST_CASE_2D_SINE_AXIS_1, + TEST_CASE_1D_SINE_PADDING_N, + TEST_CASE_2_CHAN_3D_SINE, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = DetectEnvelope(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + @parameterized.expand( + [ + TEST_CASE_INVALID_AXIS_1, + TEST_CASE_INVALID_AXIS_2, + TEST_CASE_INVALID_N, + TEST_CASE_INVALID_DTYPE, + TEST_CASE_INVALID_IMG_LEN, + ] + ) + def test_value_error(self, arguments, image, method): + if method == "__init__": + self.assertRaises(ValueError, DetectEnvelope, **arguments) + elif method == "__call__": + self.assertRaises(ValueError, DetectEnvelope(**arguments), image) + else: + raise ValueError("Expected raising method invalid. Should be __init__ or __call__.") + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfModule("torch.fft") +class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): + self.assertRaises(OptionalImportError, DetectEnvelope(), np.random.rand(1, 10)) + + +@SkipIfAtLeastPyTorchVersion((1, 7)) +class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaises(InvalidPyTorchVersionError) as cm: + DetectEnvelope() + self.assertEqual("DetectEnvelope requires PyTorch version 1.7.0 or later", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py new file mode 100644 index 0000000000..1e9e0e660f --- /dev/null +++ b/tests/test_hilbert_transform.py @@ -0,0 +1,226 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import HilbertTransform +from monai.utils import InvalidPyTorchVersionError, OptionalImportError +from tests.utils import ( + SkipIfAtLeastPyTorchVersion, + SkipIfBeforePyTorchVersion, + SkipIfModule, + SkipIfNoModule, + skip_if_no_cuda, +) + + +def create_expected_numpy_output(input_datum, **kwargs): + + x = np.fft.fft( + input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), + **kwargs, + ) + f = np.fft.fftfreq(x.shape[kwargs["axis"]]) + u = np.heaviside(f, 0.5) + new_dims_before = kwargs["axis"] + new_dims_after = len(x.shape) - kwargs["axis"] - 1 + for _ in range(new_dims_before): + u = np.expand_dims(u, 0) + for _ in range(new_dims_after): + u = np.expand_dims(u, -1) + ht = np.fft.ifft(x * 2 * u, axis=kwargs["axis"]) + + return ht + + +cpu = torch.device("cpu") +n_samples = 500 +hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) + +# CPU TEST DATA + +cpu_input_data = dict() +cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) +cpu_input_data["2D"] = ( + torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) +) +cpu_input_data["3D"] = ( + torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu) + .unsqueeze(0) + .unsqueeze(0) +) +cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0) +cpu_input_data["2D 2CH"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu +).unsqueeze(0) + +# SINGLE-CHANNEL CPU VALUE TESTS + +TEST_CASE_1D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["1D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["1D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["2D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["2D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance +] + +TEST_CASE_3D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["3D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["3D"], axis=2), + 1e-5, # absolute tolerance +] + +# MULTICHANNEL CPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS + +TEST_CASE_1D_2CH_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["1D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["1D 2CH"], axis=2), + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_2CH_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["2D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["2D 2CH"], axis=2), + 1e-5, # absolute tolerance +] + +# GPU TEST DATA + +if torch.cuda.is_available(): + gpu = torch.device("cuda") + + gpu_input_data = dict() + gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) + gpu_input_data["2D"] = ( + torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) + ) + gpu_input_data["3D"] = ( + torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu) + .unsqueeze(0) + .unsqueeze(0) + ) + gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0) + gpu_input_data["2D 2CH"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu + ).unsqueeze(0) + + # SINGLE CHANNEL GPU VALUE TESTS + + TEST_CASE_1D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["1D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + TEST_CASE_2D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["2D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + TEST_CASE_3D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["3D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS + + TEST_CASE_1D_2CH_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["1D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2), + 1e-5, # absolute tolerance + ] + + TEST_CASE_2D_2CH_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["2D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2), + 1e-5, # absolute tolerance + ] + +# TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestHilbertTransformCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1D_SINE_CPU, + TEST_CASE_2D_SINE_CPU, + TEST_CASE_3D_SINE_CPU, + TEST_CASE_1D_2CH_SINE_CPU, + TEST_CASE_2D_2CH_SINE_CPU, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = HilbertTransform(**arguments)(image) + result = result.squeeze(0).squeeze(0).numpy() + np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) + + +@skip_if_no_cuda +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestHilbertTransformGPU(unittest.TestCase): + @parameterized.expand( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ], + skip_on_empty=True, + ) + def test_value(self, arguments, image, expected_data, atol): + result = HilbertTransform(**arguments)(image) + result = result.squeeze(0).squeeze(0).cpu().numpy() + np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfModule("torch.fft") +class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): + self.assertRaises(OptionalImportError, HilbertTransform(), torch.randn(1, 1, 10)) + + +@SkipIfAtLeastPyTorchVersion((1, 7)) +class TestHilbertTransformInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaises(InvalidPyTorchVersionError) as cm: + HilbertTransform() + self.assertEqual("HilbertTransform requires PyTorch version 1.7.0 or later", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 6c717264ac..763ddaa9b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,6 +28,7 @@ import torch import torch.distributed as dist +from monai.config import get_torch_version_tuple from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism @@ -66,6 +67,18 @@ def __call__(self, obj): return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj) +class SkipIfModule(object): + """Decorator to be used if test should be skipped + when optional module is present.""" + + def __init__(self, module_name): + self.module_name = module_name + self.module_avail = optional_import(self.module_name)[1] + + def __call__(self, obj): + return unittest.skipIf(self.module_avail, f"Skipping because optional module present: {self.module_name}")(obj) + + def skip_if_no_cuda(obj): """ Skip the unit tests if torch.cuda.is_available is False @@ -80,6 +93,34 @@ def skip_if_windows(obj): return unittest.skipIf(sys.platform == "win32", "Skipping tests on Windows")(obj) +class SkipIfBeforePyTorchVersion(object): + """Decorator to be used if test should be skipped + with PyTorch versions older than that given.""" + + def __init__(self, pytorch_version_tuple): + self.min_version = pytorch_version_tuple + self.version_too_old = get_torch_version_tuple() < self.min_version + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_old, f"Skipping tests that fail on PyTorch versions before: {self.min_version}" + )(obj) + + +class SkipIfAtLeastPyTorchVersion(object): + """Decorator to be used if test should be skipped + with PyTorch versions older than that given.""" + + def __init__(self, pytorch_version_tuple): + self.max_version = pytorch_version_tuple + self.version_too_new = get_torch_version_tuple() >= self.max_version + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_new, f"Skipping tests that fail on PyTorch versions at least: {self.max_version}" + )(obj) + + def make_nifti_image(array, affine=None): """ Create a temporary nifti image on the disk and return the image name.