diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 01b1cb00bb..1c3ee288a1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -53,6 +53,10 @@ Generic Interfaces .. autoclass:: Decollated :members: +`Fourier` +^^^^^^^^^^^^^ +.. autoclass:: Fourier + :members: Vanilla Transforms ------------------ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 487a995e5e..20e29d5aa9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -306,7 +306,15 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + Fourier, + MapTransform, + Randomizable, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index dfbac7465c..4533f333ce 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -23,7 +23,7 @@ from monai.config import DtypeLike from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import Fourier, RandomizableTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -1196,7 +1196,7 @@ def _randomize(self, _: Any) -> None: self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) -class GibbsNoise(Transform): +class GibbsNoise(Transform, Fourier): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts are one of the common type of type artifacts appearing in MRI scans. @@ -1204,15 +1204,17 @@ class GibbsNoise(Transform): The transform is applied to all the channels in the data. For general information on Gibbs artifacts, please refer to: - https://pubs.rsna.org/doi/full/10.1148/rg.313105115 - https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 + `An Image-based Approach to Understanding the Physics of MR Artifacts + `_. + + `The AAPM/RSNA Physics Tutorial for Residents + `_ Args: - alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - + as_tensor_output: if true return torch.Tensor, else return np.array. Default: True. """ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: @@ -1221,47 +1223,22 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: raise AssertionError("alpha must take values in the interval [0,1].") self.alpha = alpha self.as_tensor_output = as_tensor_output - self._device = torch.device("cpu") def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: n_dims = len(img.shape[1:]) - # convert to ndarray to work with np.fft - _device = None - if isinstance(img, torch.Tensor): - _device = img.device - img = img.cpu().detach().numpy() - + if isinstance(img, np.ndarray): + img = torch.Tensor(img) # FT - k = self._shift_fourier(img, n_dims) + k = self.shift_fourier(img, n_dims) # build and apply mask k = self._apply_mask(k) # map back - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img - - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies fourier transform and shifts its output. - Only the spatial dimensions get transformed. + img = self.inv_shift_fourier(k, n_dims) - Args: - x (np.ndarray): tensor to fourier transform. - """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - return out + return img if self.as_tensor_output else img.cpu().detach().numpy() - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies inverse shift and fourier transform. Only the spatial - dimensions are transformed. - """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out - - def _apply_mask(self, k: np.ndarray) -> np.ndarray: + def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: """Builds and applies a mask on the spatial dimensions. Args: @@ -1287,11 +1264,11 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray: mask = np.repeat(mask[None], k.shape[0], axis=0) # apply binary mask - k_masked: np.ndarray = k * mask + k_masked = k * torch.tensor(mask, device=k.device) return k_masked -class KSpaceSpikeNoise(Transform): +class KSpaceSpikeNoise(Transform, Fourier): """ Apply localized spikes in `k`-space at the given locations and intensities. Spike (Herringbone) artifact is a type of data acquisition artifact which @@ -1354,7 +1331,7 @@ def __init__( def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: """ Args: - img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D) + img: image with dimensions (C, H, W) or (C, H, W, D) """ # checking that tuples in loc are consistent with img size self._check_indices(img) @@ -1368,22 +1345,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, n_dims = len(img.shape[1:]) - # convert to ndarray to work with np.fft - if isinstance(img, torch.Tensor): - device = img.device - img = img.cpu().detach().numpy() - else: - device = torch.device("cpu") - + if isinstance(img, np.ndarray): + img = torch.Tensor(img) # FT - k = self._shift_fourier(img, n_dims) - log_abs = np.log(np.absolute(k) + 1e-10) - phase = np.angle(k) + k = self.shift_fourier(img, n_dims) + log_abs = torch.log(torch.absolute(k) + 1e-10) + phase = torch.angle(k) k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5) # highlight if isinstance(self.loc[0], Sequence): @@ -1392,9 +1364,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = np.exp(log_abs) * np.exp(1j * phase) - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img, device=device) if self.as_tensor_output else img + k = torch.exp(log_abs) * torch.exp(1j * phase) + img = self.inv_shift_fourier(k, n_dims) + + return img if self.as_tensor_output else img.cpu().detach().numpy() def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1414,14 +1387,14 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. Args: - k (np.array): intensity array to alter. - idx (tuple): index of location where to apply change. - val (float): value of intensity to write in. + k: intensity array to alter. + idx: index of location where to apply change. + val: value of intensity to write in. """ if len(k.shape) == len(idx): if isinstance(val, Sequence): @@ -1429,33 +1402,12 @@ def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], floa else: k[idx] = val elif len(k.shape) == 4 and len(idx) == 3: - k[:, idx[0], idx[1], idx[2]] = val + k[:, idx[0], idx[1], idx[2]] = val # type: ignore elif len(k.shape) == 3 and len(idx) == 2: - k[:, idx[0], idx[1]] = val - - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies fourier transform and shifts its output. - Only the spatial dimensions get transformed. - - Args: - x (np.ndarray): tensor to fourier transform. - """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - return out + k[:, idx[0], idx[1]] = val # type: ignore - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies inverse shift and fourier transform. Only the spatial - dimensions are transformed. - """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out - -class RandKSpaceSpikeNoise(RandomizableTransform): +class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): """ Naturalistic data augmentation via spike artifacts. The transform applies localized spikes in `k`-space, and it is the random version of @@ -1476,7 +1428,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform): channels at once, or channel-wise if ``channel_wise = True``. intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) - uniformly for all channels. Or pass sequence of intervals + uniformly for all channels. Or pass sequence of intevals ((a0, b0), (a1, b1), ...) to sample for each respective channel. In the second case, the number of 2-tuples must match the number of channels. @@ -1521,7 +1473,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, Apply transform to `img`. Assumes data is in channel-first form. Args: - img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D) + img: image with dimensions (C, H, W) or (C, H, W, D) """ if self.intensity_range is not None: if isinstance(self.intensity_range[0], Sequence) and len(self.intensity_range) != img.shape[0]: @@ -1532,19 +1484,20 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - # convert to ndarray to work with np.fft - x, device = self._to_numpy(img) - intensity_range = self._make_sequence(x) - self._randomize(x, intensity_range) + if not isinstance(img, torch.Tensor): + img = torch.Tensor(img) + + intensity_range = self._make_sequence(img) + self._randomize(img, intensity_range) - # build/apply transform only if there are spike locations + # build/appy transform only if there are spike locations if self.sampled_locs: transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(x) + return transform(img) - return torch.Tensor(x, device=device) if self.as_tensor_output else x + return img if self.as_tensor_output else img.detach().numpy() - def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None: + def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None: """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it use the random @@ -1568,11 +1521,11 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]] spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] # type: ignore + self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) # type: ignore + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore - def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ Formats the sequence of intensities ranges to Sequence[Sequence[float]]. """ @@ -1586,27 +1539,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: # set default range if one not provided return self._set_default_range(x) - def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. Args: - x (np.ndarray): tensor to fourier transform. + img: image to transform. """ - n_dims = len(x.shape[1:]) + n_dims = len(img.shape[1:]) - k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - log_abs = np.log(np.absolute(k) + 1e-10) - shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5 + k = self.shift_fourier(img, n_dims) + log_abs = torch.log(torch.absolute(k) + 1e-10) + shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means) return intensity_sequence - def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.device]: - if isinstance(img, torch.Tensor): - return img.cpu().detach().numpy(), img.device - else: - return img, torch.device("cpu") - class RandCoarseDropout(RandomizableTransform): """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 49f20ea419..c24f7b67ca 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1234,16 +1234,14 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): prob: probability to add spike artifact to each item in the dictionary provided it is realized that the noise will be applied to the dictionary. - img_intensity_range: Intensity - range to sample for ``"image"`` key. Pass a tuple `(a, b)` to sample - the log-intensity from the interval `(a, b)` uniformly for all - channels. Or pass sequence of intervals `((a0, b0), (a1, b1), ...)` - to sample for each respective channel. In the second case, the - number of 2-tuples must match the number of channels. - Default ranges is `(0.95x, 1.10x)` where `x` is the mean - log-intensity for each channel. - label_intensity_range: Intensity range to sample for ``"label"`` key. Same - as behavior as ``img_intensity_range`` but ``"label"`` key. + intensity_ranges: Dictionary with intensity + ranges to sample for each key. Given a dictionary value of `(a, b)` the + transform will sample the log-intensity from the interval `(a, b)` uniformly for all + channels of the respective key. If a sequence of intevals `((a0, b0), (a1, b1), ...)` + is given, then the transform will sample from each interval for each + respective channel. In the second case, the number of 2-tuples must + match the number of channels. Default ranges is `(0.95x, 1.10x)` + where `x` is the mean log-intensity for each channel. channel_wise: treat each channel independently. True by default. common_sampling: If ``True`` same values for location and log-intensity @@ -1257,7 +1255,7 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): To apply `k`-space spikes randomly on the image only, with probability 0.5, and log-intensity sampled from the interval [13, 15] for each channel independently, one uses - ``RandKSpaceSpikeNoised("image", prob=0.5, img_intensity_range=(13,15), channel_wise=True)``. + ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges={"image":(13,15)}, channel_wise=True)``. """ def __init__( @@ -1265,8 +1263,7 @@ def __init__( keys: KeysCollection, global_prob: float = 1.0, prob: float = 0.1, - img_intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, - label_intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, + intensity_ranges: Optional[Mapping[Hashable, Sequence[Union[Sequence[float], float]]]] = None, channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, @@ -1281,8 +1278,15 @@ def __init__( self.common_seed = common_seed self.as_tensor_output = as_tensor_output # the spikes artifact is amplitude dependent so we instantiate one per key - self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise, self.as_tensor_output) - self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise, self.as_tensor_output) + self.transforms = {} + if isinstance(intensity_ranges, Mapping): + for k in self.keys: + self.transforms[k] = RandKSpaceSpikeNoise( + prob, intensity_ranges[k], channel_wise, self.as_tensor_output + ) + else: + for k in self.keys: + self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise, self.as_tensor_output) def __call__( self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] @@ -1297,13 +1301,12 @@ def __call__( # In case the same spikes are desired for both image and label. if self.common_sampling: - self.t_img.set_random_state(self.common_seed) - self.t_label.set_random_state(self.common_seed) + for k in self.keys: + self.transforms[k].set_random_state(self.common_seed) - for key in self.key_iterator(d): + for key, t in self.key_iterator(d, self.transforms): if self._do_transform: - transform = self.t_img if key == "image" else self.t_label - d[key] = transform(d[key]) + d[key] = self.transforms[t](d[key]) else: if isinstance(d[key], np.ndarray) and self.as_tensor_output: d[key] = torch.Tensor(d[key]) @@ -1321,8 +1324,8 @@ def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.R state: set the random state with a `np.random.RandomState` object.""" self.set_random_state(seed, state) - self.t_img.set_random_state(seed, state) - self.t_label.set_random_state(seed, state) + for key in self.keys: + self.transforms[key].set_random_state(seed, state) def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: if isinstance(d, torch.Tensor): diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 681c0ba9ec..97cc2f21fc 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -23,7 +23,15 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "Randomizable", + "RandomizableTransform", + "Transform", + "MapTransform", + "Fourier", +] ReturnType = TypeVar("ReturnType") @@ -365,3 +373,43 @@ def key_iterator( yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + + +class Fourier: + """ + Helper class storing Fourier mappings + """ + + @staticmethod + def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor: + """ + Applies fourier transform and shifts the zero-frequency component to the + center of the spectrum. Only the spatial dimensions get transformed. + + Args: + x: Image to transform. + n_dims: Number of spatial dimensions. + Returns + k: K-space data. + """ + k: torch.Tensor = torch.fft.fftshift( + torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ) + return k + + @staticmethod + def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor: + """ + Applies inverse shift and fourier transform. Only the spatial + dimensions are transformed. + + Args: + k: K-space data. + n_dims: Number of spatial dimensions. + Returns: + x: Tensor in image space. + """ + x: torch.Tensor = torch.fft.ifftn( + torch.fft.ifftshift(k, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ).real + return x diff --git a/tests/test_fourier.py b/tests/test_fourier.py new file mode 100644 index 0000000000..488bf0cbf9 --- /dev/null +++ b/tests/test_fourier.py @@ -0,0 +1,70 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d, create_test_image_3d +from monai.transforms import Fourier +from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule + +TEST_CASES = [((128, 64),), ((64, 48, 80),)] +# for shape in ((128, 64), (64, 48, 80)): +# TEST_CASES.append(shape) + + +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") +class TestFourier(unittest.TestCase): + def setUp(self): + set_determinism(0) + super().setUp() + + def tearDown(self): + set_determinism(None) + + @staticmethod + def get_data(img_shape): + create_test_image = create_test_image_2d if len(img_shape) == 2 else create_test_image_3d + im = create_test_image(*img_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] + return torch.Tensor(im) + + @parameterized.expand(TEST_CASES) + def test_forward(self, img_shape): + n_dims = len(img_shape[1:]) + x = self.get_data(img_shape) + t = Fourier() + out = t.shift_fourier(x, n_dims) + + expect = torch.fft.fftshift(torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0))) + + np.testing.assert_allclose(out, expect) + + @parameterized.expand(TEST_CASES) + def test_backward(self, img_shape): + n_dims = len(img_shape[1:]) + x = self.get_data(img_shape) + t = Fourier() + out = t.inv_shift_fourier(x, n_dims) + + expect = torch.fft.ifftn( + torch.fft.ifftshift(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ).real + + np.testing.assert_allclose(out, expect) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 83cba56938..264e2e630a 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -27,6 +28,8 @@ TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 0e02feb341..8ad4839338 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -29,6 +30,8 @@ KEYS = ["im", "label"] +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 53661d5fcb..bb6d05e676 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -20,6 +20,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -28,6 +29,8 @@ TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index e5d2dfb6f8..e891bd4568 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -20,6 +20,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -30,6 +31,8 @@ KEYS = ["image", "label"] +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 94948c5a0d..a0701d09c3 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -27,6 +28,8 @@ TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 986f4c02ae..72188a93b5 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -29,6 +30,8 @@ KEYS = ["im", "label"] +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index ba9156c5b2..71f7e36d9b 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -28,6 +29,8 @@ TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 3cb49f1c08..d61b83e2d5 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -29,6 +30,8 @@ KEYS = ["image", "label"] +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -50,13 +53,12 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_range = (13, 15) + intensity_ranges = {"image": (13, 15), "label": (13, 15)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -73,13 +75,12 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): @parameterized.expand(TEST_CASES) def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_range = (13, 15) + intensity_ranges = {"image": (13, 15), "label": (13, 15)} t1 = RandKSpaceSpikeNoised( KEYS, global_prob=0.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -88,8 +89,7 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): KEYS, global_prob=0.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -104,23 +104,21 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - image_range = (15, 15.1) - label_range = (14, 14.1) + intensity_ranges = {"image": (13, 13.1), "label": (13, 13.1)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=image_range, - label_intensity_range=label_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=True, ) _ = t(data) - self.assertGreaterEqual(t.t_img.sampled_k_intensity[0], 15) - self.assertLessEqual(t.t_img.sampled_k_intensity[0], 15.1) - self.assertGreaterEqual(t.t_label.sampled_k_intensity[0], 14) - self.assertLessEqual(t.t_label.sampled_k_intensity[0], 14.1) + self.assertGreaterEqual(t.transforms["image"].sampled_k_intensity[0], 13) + self.assertLessEqual(t.transforms["image"].sampled_k_intensity[0], 13.1) + self.assertGreaterEqual(t.transforms["label"].sampled_k_intensity[0], 13) + self.assertLessEqual(t.transforms["label"].sampled_k_intensity[0], 13.1) @parameterized.expand(TEST_CASES) def test_same_transformation(self, im_shape, _, as_tensor_input): @@ -128,14 +126,13 @@ def test_same_transformation(self, im_shape, _, as_tensor_input): # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} - image_range = label_range = (15, 15.1) + intensity_ranges = {"image": (13, 15), "label": (13, 15)} # use common_sampling = True to ask for the same transformation t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=image_range, - label_intensity_range=label_range, + intensity_ranges=intensity_ranges, channel_wise=True, common_sampling=True, as_tensor_output=True,