From 346188c93bbae07ef976dc48754c9cecfed7cc17 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Wed, 28 Jul 2021 21:58:41 +0300 Subject: [PATCH 1/8] Moved fourier functions to their own class. Modified RandKSpaceSpikeNoised. 1. Allow RandKSpaceSpikeNoised to work with arbitrary keys. 2. Introduced Fourier transform to keep the forward/backward fourier mappings. Signed-off-by: Yaniel Cabrera --- monai/transforms/__init__.py | 2 +- monai/transforms/intensity/array.py | 159 ++++++++--------------- monai/transforms/intensity/dictionary.py | 117 ++++++++++++++++- monai/transforms/transform.py | 36 ++++- tests/test_rand_k_space_spike_noised.py | 32 ++--- 5 files changed, 218 insertions(+), 128 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 45eecd266c..d0c6987424 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -306,7 +306,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform, Fourier from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index dfbac7465c..3a60ecc7bc 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -23,7 +23,7 @@ from monai.config import DtypeLike from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import RandomizableTransform, Transform, Fourier from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -1196,7 +1196,7 @@ def _randomize(self, _: Any) -> None: self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) -class GibbsNoise(Transform): +class GibbsNoise(Transform, Fourier): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts are one of the common type of type artifacts appearing in MRI scans. @@ -1204,15 +1204,17 @@ class GibbsNoise(Transform): The transform is applied to all the channels in the data. For general information on Gibbs artifacts, please refer to: - https://pubs.rsna.org/doi/full/10.1148/rg.313105115 - https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 + `An Image-based Approach to Understanding the Physics of MR Artifacts + `_. + + `The AAPM/RSNA Physics Tutorial for Residents + `_ Args: - alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - + as_tensor_output: if true return torch.Tensor, else return np.array. Default: True. """ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: @@ -1221,47 +1223,22 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: raise AssertionError("alpha must take values in the interval [0,1].") self.alpha = alpha self.as_tensor_output = as_tensor_output - self._device = torch.device("cpu") def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: n_dims = len(img.shape[1:]) - # convert to ndarray to work with np.fft - _device = None - if isinstance(img, torch.Tensor): - _device = img.device - img = img.cpu().detach().numpy() - + if isinstance(img, np.ndarray): + img = torch.Tensor(img) # FT - k = self._shift_fourier(img, n_dims) + k = self.shift_fourier(img, n_dims) # build and apply mask k = self._apply_mask(k) # map back - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img - - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies fourier transform and shifts its output. - Only the spatial dimensions get transformed. + img = self.inv_shift_fourier(k, n_dims) - Args: - x (np.ndarray): tensor to fourier transform. - """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - return out + return img if self.as_tensor_output else img.cpu().detach().numpy() - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies inverse shift and fourier transform. Only the spatial - dimensions are transformed. - """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out - - def _apply_mask(self, k: np.ndarray) -> np.ndarray: + def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: """Builds and applies a mask on the spatial dimensions. Args: @@ -1287,11 +1264,11 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray: mask = np.repeat(mask[None], k.shape[0], axis=0) # apply binary mask - k_masked: np.ndarray = k * mask + k_masked = k * torch.tensor(mask, device=k.device) return k_masked -class KSpaceSpikeNoise(Transform): +class KSpaceSpikeNoise(Transform, Fourier): """ Apply localized spikes in `k`-space at the given locations and intensities. Spike (Herringbone) artifact is a type of data acquisition artifact which @@ -1354,7 +1331,7 @@ def __init__( def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: """ Args: - img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D) + img: image with dimensions (C, H, W) or (C, H, W, D) """ # checking that tuples in loc are consistent with img size self._check_indices(img) @@ -1368,22 +1345,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, n_dims = len(img.shape[1:]) - # convert to ndarray to work with np.fft - if isinstance(img, torch.Tensor): - device = img.device - img = img.cpu().detach().numpy() - else: - device = torch.device("cpu") - + if isinstance(img, np.ndarray): + img = torch.Tensor(img) # FT - k = self._shift_fourier(img, n_dims) - log_abs = np.log(np.absolute(k) + 1e-10) - phase = np.angle(k) + k = self.shift_fourier(img, n_dims) + log_abs = torch.log(torch.absolute(k) + 1e-10) + phase = torch.angle(k) k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5) # highlight if isinstance(self.loc[0], Sequence): @@ -1392,9 +1364,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = np.exp(log_abs) * np.exp(1j * phase) - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img, device=device) if self.as_tensor_output else img + k = torch.exp(log_abs) * torch.exp(1j * phase) + img = self.inv_shift_fourier(k, n_dims) + + return img if self.as_tensor_output else img.cpu().detach().numpy() def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1414,14 +1387,14 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. Args: - k (np.array): intensity array to alter. - idx (tuple): index of location where to apply change. - val (float): value of intensity to write in. + k: intensity array to alter. + idx: index of location where to apply change. + val: value of intensity to write in. """ if len(k.shape) == len(idx): if isinstance(val, Sequence): @@ -1433,29 +1406,8 @@ def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], floa elif len(k.shape) == 3 and len(idx) == 2: k[:, idx[0], idx[1]] = val - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies fourier transform and shifts its output. - Only the spatial dimensions get transformed. - Args: - x (np.ndarray): tensor to fourier transform. - """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - return out - - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: - """ - Applies inverse shift and fourier transform. Only the spatial - dimensions are transformed. - """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out - - -class RandKSpaceSpikeNoise(RandomizableTransform): +class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): """ Naturalistic data augmentation via spike artifacts. The transform applies localized spikes in `k`-space, and it is the random version of @@ -1476,7 +1428,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform): channels at once, or channel-wise if ``channel_wise = True``. intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) - uniformly for all channels. Or pass sequence of intervals + uniformly for all channels. Or pass sequence of intevals ((a0, b0), (a1, b1), ...) to sample for each respective channel. In the second case, the number of 2-tuples must match the number of channels. @@ -1521,7 +1473,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, Apply transform to `img`. Assumes data is in channel-first form. Args: - img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D) + img: image with dimensions (C, H, W) or (C, H, W, D) """ if self.intensity_range is not None: if isinstance(self.intensity_range[0], Sequence) and len(self.intensity_range) != img.shape[0]: @@ -1532,19 +1484,20 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - # convert to ndarray to work with np.fft - x, device = self._to_numpy(img) - intensity_range = self._make_sequence(x) - self._randomize(x, intensity_range) + if not isinstance(img, torch.Tensor): + img = torch.Tensor(img) + + intensity_range = self._make_sequence(img) + self._randomize(img, intensity_range) - # build/apply transform only if there are spike locations + # build/appy transform only if there are spike locations if self.sampled_locs: transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(x) + return transform(img) - return torch.Tensor(x, device=device) if self.as_tensor_output else x + return img if self.as_tensor_output else img.detach().numpy() - def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None: + def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None: """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it use the random @@ -1568,11 +1521,11 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]] spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] # type: ignore + self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) # type: ignore + self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) - def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ Formats the sequence of intensities ranges to Sequence[Sequence[float]]. """ @@ -1586,27 +1539,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: # set default range if one not provided return self._set_default_range(x) - def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. Args: - x (np.ndarray): tensor to fourier transform. + img: image to transform. """ - n_dims = len(x.shape[1:]) + n_dims = len(img.shape[1:]) - k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - log_abs = np.log(np.absolute(k) + 1e-10) - shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5 + k = self.shift_fourier(img, n_dims) + log_abs = torch.log(torch.absolute(k) + 1e-10) + shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means) return intensity_sequence - def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.device]: - if isinstance(img, torch.Tensor): - return img.cpu().detach().numpy(), img.device - else: - return img, torch.device("cpu") - class RandCoarseDropout(RandomizableTransform): """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 49f20ea419..2c03c4d2de 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1212,7 +1212,7 @@ def __call__( return d -class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): +class RandKSpaceSpikeNoised_(RandomizableTransform, MapTransform): """ Dictionary-based version of :py:class:`monai.transforms.RandKSpaceSpikeNoise`. @@ -1329,6 +1329,121 @@ def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: d_numpy: np.ndarray = d.cpu().detach().numpy() return d_numpy +class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): + """ + Dictionary-based version of :py:class:`monai.transforms.RandKSpaceSpikeNoise`. + + Naturalistic data augmentation via spike artifacts. The transform applies + localized spikes in `k`-space. + + For general information on spike artifacts, please refer to: + + `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging + `_. + + `Body MRI artifacts in clinical practice: A physicist's and radiologist's + perspective `_. + + Args: + keys: "image", "label", or ["image", "label"] depending + on which data you need to transform. + global_prob: probability of applying transform to the dictionary. + prob: probability to add spike artifact to each item in the + dictionary provided it is realized that the noise will be applied + to the dictionary. + intensity_ranges: Dictionary with intensity + ranges to sample for each key. Given a dictionary value of `(a, b)` the + transform will sample the log-intensity from the interval `(a, b)` uniformly for all + channels of the respective key. If a sequence of intevals `((a0, b0), (a1, b1), ...)` + is given, then the transform will sample from each interval for each + respective channel. In the second case, the number of 2-tuples must + match the number of channels. Default ranges is `(0.95x, 1.10x)` + where `x` is the mean log-intensity for each channel. + channel_wise: treat each channel independently. True by + default. + common_sampling: If ``True`` same values for location and log-intensity + will be sampled for the image and label. + common_seed: Seed to be used in case ``common_sampling = True``. + as_tensor_output: if ``True`` return torch.Tensor, else return + np.array. Default: ``True``. + allow_missing_keys: do not raise exception if key is missing. + + Example: + To apply `k`-space spikes randomly on the image only, with probability + 0.5, and log-intensity sampled from the interval [13, 15] for each + channel independently, one uses + ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges={"image":(13,15)}, channel_wise=True)``. + """ + + def __init__( + self, + keys: KeysCollection, + global_prob: float = 1.0, + prob: float = 0.1, + intensity_ranges: Optional[Mapping[Hashable, Sequence[Union[Sequence[float], float]]]] = None, + channel_wise: bool = True, + common_sampling: bool = False, + common_seed: int = 42, + as_tensor_output: bool = True, + allow_missing_keys: bool = False, + ): + + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, global_prob) + + self.common_sampling = common_sampling + self.common_seed = common_seed + self.as_tensor_output = as_tensor_output + # the spikes artifact is amplitude dependent so we instantiate one per key + self.transforms = {} + for k in self.keys: + self.transforms[k] = RandKSpaceSpikeNoise(prob, intensity_ranges[k], channel_wise, self.as_tensor_output) + + def __call__( + self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] + ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + """ + Args: + data: Expects image/label to have dimensions (C, H, W) or + (C, H, W, D), where C is the channel. + """ + d = dict(data) + super().randomize(None) + + # In case the same spikes are desired for both image and label. + if self.common_sampling: + for k in self.keys: + self.transforms[k].set_random_state(self.common_seed) + + for key, t in self.key_iterator(d, self.transforms): + if self._do_transform: + d[key] = self.transforms[t](d[key]) + else: + if isinstance(d[key], np.ndarray) and self.as_tensor_output: + d[key] = torch.Tensor(d[key]) + elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: + d[key] = self._to_numpy(d[key]) + return d + + def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: + """ + Set the random state locally to control the randomness. + User should use this method instead of ``set_random_state``. + + Args: + seed: set the random state with an integer seed. + state: set the random state with a `np.random.RandomState` object.""" + + self.set_random_state(seed, state) + for key in self.keys: + self.transforms[key].set_random_state(seed, state) + + + def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + if isinstance(d, torch.Tensor): + d_numpy: np.ndarray = d.cpu().detach().numpy() + return d_numpy + class RandCoarseDropoutd(RandomizableTransform, MapTransform): """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 681c0ba9ec..5807d9dd02 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -23,7 +23,7 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform", "Fourier"] ReturnType = TypeVar("ReturnType") @@ -365,3 +365,37 @@ def key_iterator( yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + +class Fourier: + """ + Helper class storing Fourier mappings + """ + + @staticmethod + def shift_fourier(x: Union[torch.Tensor, np.ndarray], n_dims: int) -> Union[torch.Tensor, np.ndarray]: + """ + Applies fourier transform and shifts its output. + Only the spatial dimensions get transformed. + + Args: + x (np.ndarray): tensor to fourier transform. + """ + # argument is dim if torch, else axes + mod, arg = (torch, "dim") if type(x) is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + out = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) + return out + + @staticmethod + def inv_shift_fourier(k: Union[torch.Tensor, np.ndarray], n_dims: int) -> Union[torch.Tensor, np.ndarray]: + """ + Applies inverse shift and fourier transform. Only the spatial + dimensions are transformed. + """ + dims = tuple(range(-n_dims, 0)) + + if type(k) is torch.Tensor: + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims) + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real \ No newline at end of file diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 3cb49f1c08..00f2bea4fe 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -50,13 +50,12 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_range = (13, 15) + intensity_ranges = {"image":(13, 15), "label":(13,15)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -73,13 +72,12 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): @parameterized.expand(TEST_CASES) def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_range = (13, 15) + intensity_ranges = {"image":(13, 15), "label":(13,15)} t1 = RandKSpaceSpikeNoised( KEYS, global_prob=0.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -88,8 +86,7 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): KEYS, global_prob=0.0, prob=1.0, - img_intensity_range=intensity_range, - label_intensity_range=intensity_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=as_tensor_output, ) @@ -104,23 +101,21 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - image_range = (15, 15.1) - label_range = (14, 14.1) + intensity_ranges = {"image":(13, 13.1), "label":(13,13.1)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=image_range, - label_intensity_range=label_range, + intensity_ranges=intensity_ranges, channel_wise=True, as_tensor_output=True, ) _ = t(data) - self.assertGreaterEqual(t.t_img.sampled_k_intensity[0], 15) - self.assertLessEqual(t.t_img.sampled_k_intensity[0], 15.1) - self.assertGreaterEqual(t.t_label.sampled_k_intensity[0], 14) - self.assertLessEqual(t.t_label.sampled_k_intensity[0], 14.1) + self.assertGreaterEqual(t.transforms["image"].sampled_k_intensity[0], 13) + self.assertLessEqual(t.transforms["image"].sampled_k_intensity[0], 13.1) + self.assertGreaterEqual(t.transforms["label"].sampled_k_intensity[0], 13) + self.assertLessEqual(t.transforms["label"].sampled_k_intensity[0], 13.1) @parameterized.expand(TEST_CASES) def test_same_transformation(self, im_shape, _, as_tensor_input): @@ -128,14 +123,13 @@ def test_same_transformation(self, im_shape, _, as_tensor_input): # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} - image_range = label_range = (15, 15.1) + intensity_ranges = {"image":(13, 15), "label":(13,15)} # use common_sampling = True to ask for the same transformation t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, prob=1.0, - img_intensity_range=image_range, - label_intensity_range=label_range, + intensity_ranges=intensity_ranges, channel_wise=True, common_sampling=True, as_tensor_output=True, From 7c5ecebcd47750eac52650ec3577a28954c82607 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Wed, 28 Jul 2021 22:14:36 +0300 Subject: [PATCH 2/8] removed old code Signed-off-by: Yaniel Cabrera --- monai/transforms/intensity/dictionary.py | 117 ----------------------- 1 file changed, 117 deletions(-) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 2c03c4d2de..945d6697f3 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1212,123 +1212,6 @@ def __call__( return d -class RandKSpaceSpikeNoised_(RandomizableTransform, MapTransform): - """ - Dictionary-based version of :py:class:`monai.transforms.RandKSpaceSpikeNoise`. - - Naturalistic data augmentation via spike artifacts. The transform applies - localized spikes in `k`-space. - - For general information on spike artifacts, please refer to: - - `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging - `_. - - `Body MRI artifacts in clinical practice: A physicist's and radiologist's - perspective `_. - - Args: - keys: "image", "label", or ["image", "label"] depending - on which data you need to transform. - global_prob: probability of applying transform to the dictionary. - prob: probability to add spike artifact to each item in the - dictionary provided it is realized that the noise will be applied - to the dictionary. - img_intensity_range: Intensity - range to sample for ``"image"`` key. Pass a tuple `(a, b)` to sample - the log-intensity from the interval `(a, b)` uniformly for all - channels. Or pass sequence of intervals `((a0, b0), (a1, b1), ...)` - to sample for each respective channel. In the second case, the - number of 2-tuples must match the number of channels. - Default ranges is `(0.95x, 1.10x)` where `x` is the mean - log-intensity for each channel. - label_intensity_range: Intensity range to sample for ``"label"`` key. Same - as behavior as ``img_intensity_range`` but ``"label"`` key. - channel_wise: treat each channel independently. True by - default. - common_sampling: If ``True`` same values for location and log-intensity - will be sampled for the image and label. - common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. - allow_missing_keys: do not raise exception if key is missing. - - Example: - To apply `k`-space spikes randomly on the image only, with probability - 0.5, and log-intensity sampled from the interval [13, 15] for each - channel independently, one uses - ``RandKSpaceSpikeNoised("image", prob=0.5, img_intensity_range=(13,15), channel_wise=True)``. - """ - - def __init__( - self, - keys: KeysCollection, - global_prob: float = 1.0, - prob: float = 0.1, - img_intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, - label_intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, - channel_wise: bool = True, - common_sampling: bool = False, - common_seed: int = 42, - as_tensor_output: bool = True, - allow_missing_keys: bool = False, - ): - - MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, global_prob) - - self.common_sampling = common_sampling - self.common_seed = common_seed - self.as_tensor_output = as_tensor_output - # the spikes artifact is amplitude dependent so we instantiate one per key - self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise, self.as_tensor_output) - self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise, self.as_tensor_output) - - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: - """ - Args: - data: Expects image/label to have dimensions (C, H, W) or - (C, H, W, D), where C is the channel. - """ - d = dict(data) - super().randomize(None) - - # In case the same spikes are desired for both image and label. - if self.common_sampling: - self.t_img.set_random_state(self.common_seed) - self.t_label.set_random_state(self.common_seed) - - for key in self.key_iterator(d): - if self._do_transform: - transform = self.t_img if key == "image" else self.t_label - d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) - return d - - def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: - """ - Set the random state locally to control the randomness. - User should use this method instead of ``set_random_state``. - - Args: - seed: set the random state with an integer seed. - state: set the random state with a `np.random.RandomState` object.""" - - self.set_random_state(seed, state) - self.t_img.set_random_state(seed, state) - self.t_label.set_random_state(seed, state) - - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy - class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): """ Dictionary-based version of :py:class:`monai.transforms.RandKSpaceSpikeNoise`. From 6a57969f0e1a5e92e735013f23477c11d79abae1 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Wed, 28 Jul 2021 23:41:57 +0300 Subject: [PATCH 3/8] Ignore torch.fft tests if not present Ignore tests with versions of Pytorch which lack the module fft. Signed-off-by: Yaniel Cabrera --- monai/transforms/__init__.py | 10 ++++- monai/transforms/intensity/array.py | 12 +++--- monai/transforms/intensity/dictionary.py | 17 +++++--- monai/transforms/transform.py | 52 +++++++++++++++--------- tests/test_gibbs_noise.py | 4 +- tests/test_gibbs_noised.py | 4 +- tests/test_rand_gibbs_noise.py | 4 +- tests/test_rand_gibbs_noised.py | 4 +- tests/test_rand_k_space_spike_noise.py | 4 +- tests/test_rand_k_space_spike_noised.py | 11 +++-- 10 files changed, 81 insertions(+), 41 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d0c6987424..2b45cfae6e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -306,7 +306,15 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform, Fourier +from .transform import ( + Fourier, + MapTransform, + Randomizable, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 3a60ecc7bc..d40ecfb89a 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -23,7 +23,7 @@ from monai.config import DtypeLike from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, Transform, Fourier +from monai.transforms.transform import Fourier, RandomizableTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -1402,9 +1402,9 @@ def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], fl else: k[idx] = val elif len(k.shape) == 4 and len(idx) == 3: - k[:, idx[0], idx[1], idx[2]] = val + k[:, idx[0], idx[1], idx[2]] = val # type: ignore elif len(k.shape) == 3 and len(idx) == 2: - k[:, idx[0], idx[1]] = val + k[:, idx[0], idx[1]] = val # type: ignore class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): @@ -1486,7 +1486,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, if not isinstance(img, torch.Tensor): img = torch.Tensor(img) - + intensity_range = self._make_sequence(img) self._randomize(img, intensity_range) @@ -1521,9 +1521,9 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] + self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) + self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 945d6697f3..c24f7b67ca 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1235,12 +1235,12 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): dictionary provided it is realized that the noise will be applied to the dictionary. intensity_ranges: Dictionary with intensity - ranges to sample for each key. Given a dictionary value of `(a, b)` the + ranges to sample for each key. Given a dictionary value of `(a, b)` the transform will sample the log-intensity from the interval `(a, b)` uniformly for all channels of the respective key. If a sequence of intevals `((a0, b0), (a1, b1), ...)` is given, then the transform will sample from each interval for each - respective channel. In the second case, the number of 2-tuples must - match the number of channels. Default ranges is `(0.95x, 1.10x)` + respective channel. In the second case, the number of 2-tuples must + match the number of channels. Default ranges is `(0.95x, 1.10x)` where `x` is the mean log-intensity for each channel. channel_wise: treat each channel independently. True by default. @@ -1279,8 +1279,14 @@ def __init__( self.as_tensor_output = as_tensor_output # the spikes artifact is amplitude dependent so we instantiate one per key self.transforms = {} - for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise(prob, intensity_ranges[k], channel_wise, self.as_tensor_output) + if isinstance(intensity_ranges, Mapping): + for k in self.keys: + self.transforms[k] = RandKSpaceSpikeNoise( + prob, intensity_ranges[k], channel_wise, self.as_tensor_output + ) + else: + for k in self.keys: + self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise, self.as_tensor_output) def __call__( self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] @@ -1321,7 +1327,6 @@ def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.R for key in self.keys: self.transforms[key].set_random_state(seed, state) - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: if isinstance(d, torch.Tensor): d_numpy: np.ndarray = d.cpu().detach().numpy() diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 5807d9dd02..d5efb35262 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -23,7 +23,15 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform", "Fourier"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "Randomizable", + "RandomizableTransform", + "Transform", + "MapTransform", + "Fourier", +] ReturnType = TypeVar("ReturnType") @@ -366,36 +374,42 @@ def key_iterator( elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + class Fourier: """ Helper class storing Fourier mappings """ - + @staticmethod - def shift_fourier(x: Union[torch.Tensor, np.ndarray], n_dims: int) -> Union[torch.Tensor, np.ndarray]: + def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor: """ - Applies fourier transform and shifts its output. - Only the spatial dimensions get transformed. + Applies fourier transform and shifts the zero-frequency component to the + center of the spectrum. Only the spatial dimensions get transformed. Args: - x (np.ndarray): tensor to fourier transform. + x: image to transform. + n_dims: number of spatial dimensions. + Returns + k: k-space data. """ - # argument is dim if torch, else axes - mod, arg = (torch, "dim") if type(x) is torch.Tensor else (np, "axes") - arg_dict = {arg: tuple(range(-n_dims, 0))} - out = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) - return out + k: torch.Tensor = torch.fft.fftshift( + torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ) + return k @staticmethod - def inv_shift_fourier(k: Union[torch.Tensor, np.ndarray], n_dims: int) -> Union[torch.Tensor, np.ndarray]: + def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. - """ - dims = tuple(range(-n_dims, 0)) - if type(k) is torch.Tensor: - out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims) - else: - out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) - return out.real \ No newline at end of file + Args: + k: k-space data. + n_dims: number of spatial dimensions. + Returns: + x: tensor in image space. + """ + x: torch.Tensor = torch.fft.ifftn( + torch.fft.ifftshift(k, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ).real + return x diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 83cba56938..a9d326857b 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -26,7 +27,8 @@ for as_tensor_input in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 0e02feb341..c3c19869ba 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -28,7 +29,8 @@ KEYS = ["im", "label"] - +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 94948c5a0d..b551653065 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -26,7 +27,8 @@ for as_tensor_input in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 986f4c02ae..4d4d389d21 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -28,7 +29,8 @@ KEYS = ["im", "label"] - +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index ba9156c5b2..22c4780c13 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -27,7 +28,8 @@ for channel_wise in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) - +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 00f2bea4fe..1f8199eec6 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -19,6 +19,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -29,6 +30,8 @@ KEYS = ["image", "label"] +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -50,7 +53,7 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image":(13, 15), "label":(13,15)} + intensity_ranges = {"image": (13, 15), "label": (13, 15)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, @@ -72,7 +75,7 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): @parameterized.expand(TEST_CASES) def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image":(13, 15), "label":(13,15)} + intensity_ranges = {"image": (13, 15), "label": (13, 15)} t1 = RandKSpaceSpikeNoised( KEYS, global_prob=0.0, @@ -101,7 +104,7 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): data = self.get_data(im_shape, as_tensor_input) - intensity_ranges = {"image":(13, 13.1), "label":(13,13.1)} + intensity_ranges = {"image": (13, 13.1), "label": (13, 13.1)} t = RandKSpaceSpikeNoised( KEYS, global_prob=1.0, @@ -123,7 +126,7 @@ def test_same_transformation(self, im_shape, _, as_tensor_input): # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} - intensity_ranges = {"image":(13, 15), "label":(13,15)} + intensity_ranges = {"image": (13, 15), "label": (13, 15)} # use common_sampling = True to ask for the same transformation t = RandKSpaceSpikeNoised( KEYS, From b555520a7a7a3651f462450a198fbf62a4eb2eb2 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Thu, 29 Jul 2021 01:07:17 +0300 Subject: [PATCH 4/8] update Signed-off-by: Yaniel Cabrera --- tests/test_gibbs_noise.py | 3 ++- tests/test_gibbs_noised.py | 3 ++- tests/test_k_space_spike_noise.py | 3 +++ tests/test_k_space_spike_noised.py | 3 +++ tests/test_rand_gibbs_noise.py | 3 ++- tests/test_rand_gibbs_noised.py | 3 ++- tests/test_rand_k_space_spike_noise.py | 3 ++- tests/test_rand_k_space_spike_noised.py | 2 +- 8 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index a9d326857b..264e2e630a 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -27,7 +27,8 @@ for as_tensor_input in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) -@SkipIfBeforePyTorchVersion((1, 7)) + +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index c3c19869ba..8ad4839338 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -29,7 +29,8 @@ KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 7)) + +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 53661d5fcb..bb6d05e676 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -20,6 +20,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -28,6 +29,8 @@ TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index e5d2dfb6f8..e891bd4568 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -20,6 +20,7 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -30,6 +31,8 @@ KEYS = ["image", "label"] +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index b551653065..a0701d09c3 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -27,7 +27,8 @@ for as_tensor_input in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) -@SkipIfBeforePyTorchVersion((1, 7)) + +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 4d4d389d21..72188a93b5 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -29,7 +29,8 @@ KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 7)) + +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 22c4780c13..71f7e36d9b 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -28,7 +28,8 @@ for channel_wise in (True, False): TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) -@SkipIfBeforePyTorchVersion((1, 7)) + +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 1f8199eec6..d61b83e2d5 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -30,7 +30,7 @@ KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfBeforePyTorchVersion((1, 8)) @SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): From 46fb79232e04104d807ad53a03e4eb6ca683314f Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Thu, 29 Jul 2021 01:54:59 +0300 Subject: [PATCH 5/8] typing update Signed-off-by: Yaniel Cabrera --- monai/transforms/intensity/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d40ecfb89a..2123e46c78 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1521,9 +1521,9 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float spatial = tuple(self.R.randint(0, k) for k in img.shape[1:]) self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])] if isinstance(intensity_range[0], Sequence): - self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] + self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ From f8fa16eac4e90dcbc0ab19932028fd2ff80e08c8 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Mon, 2 Aug 2021 12:33:35 +0300 Subject: [PATCH 6/8] added unit test for Fourier Signed-off-by: Yaniel Cabrera --- docs/source/transforms.rst | 4 ++++ monai/transforms/intensity/array.py | 2 +- monai/transforms/transform.py | 12 ++++++------ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 01b1cb00bb..1c3ee288a1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -53,6 +53,10 @@ Generic Interfaces .. autoclass:: Decollated :members: +`Fourier` +^^^^^^^^^^^^^ +.. autoclass:: Fourier + :members: Vanilla Transforms ------------------ diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 2123e46c78..6e7270859e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1523,7 +1523,7 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float if isinstance(intensity_range[0], Sequence): self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d5efb35262..97cc2f21fc 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -387,10 +387,10 @@ def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor: center of the spectrum. Only the spatial dimensions get transformed. Args: - x: image to transform. - n_dims: number of spatial dimensions. + x: Image to transform. + n_dims: Number of spatial dimensions. Returns - k: k-space data. + k: K-space data. """ k: torch.Tensor = torch.fft.fftshift( torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) @@ -404,10 +404,10 @@ def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor: dimensions are transformed. Args: - k: k-space data. - n_dims: number of spatial dimensions. + k: K-space data. + n_dims: Number of spatial dimensions. Returns: - x: tensor in image space. + x: Tensor in image space. """ x: torch.Tensor = torch.fft.ifftn( torch.fft.ifftshift(k, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) From 84a186d6671720faa15d36081949fe09fe48a740 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Mon, 2 Aug 2021 12:36:14 +0300 Subject: [PATCH 7/8] added unit test for Fourier Signed-off-by: Yaniel Cabrera --- tests/test_fourier.py | 70 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test_fourier.py diff --git a/tests/test_fourier.py b/tests/test_fourier.py new file mode 100644 index 0000000000..488bf0cbf9 --- /dev/null +++ b/tests/test_fourier.py @@ -0,0 +1,70 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d, create_test_image_3d +from monai.transforms import Fourier +from monai.utils.misc import set_determinism +from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule + +TEST_CASES = [((128, 64),), ((64, 48, 80),)] +# for shape in ((128, 64), (64, 48, 80)): +# TEST_CASES.append(shape) + + +@SkipIfBeforePyTorchVersion((1, 8)) +@SkipIfNoModule("torch.fft") +class TestFourier(unittest.TestCase): + def setUp(self): + set_determinism(0) + super().setUp() + + def tearDown(self): + set_determinism(None) + + @staticmethod + def get_data(img_shape): + create_test_image = create_test_image_2d if len(img_shape) == 2 else create_test_image_3d + im = create_test_image(*img_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] + return torch.Tensor(im) + + @parameterized.expand(TEST_CASES) + def test_forward(self, img_shape): + n_dims = len(img_shape[1:]) + x = self.get_data(img_shape) + t = Fourier() + out = t.shift_fourier(x, n_dims) + + expect = torch.fft.fftshift(torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0))) + + np.testing.assert_allclose(out, expect) + + @parameterized.expand(TEST_CASES) + def test_backward(self, img_shape): + n_dims = len(img_shape[1:]) + x = self.get_data(img_shape) + t = Fourier() + out = t.inv_shift_fourier(x, n_dims) + + expect = torch.fft.ifftn( + torch.fft.ifftshift(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)) + ).real + + np.testing.assert_allclose(out, expect) + + +if __name__ == "__main__": + unittest.main() From 9e8519afbf3e48f57f0c06e7bb17c2e44b4c8c88 Mon Sep 17 00:00:00 2001 From: Yaniel Cabrera Date: Mon, 2 Aug 2021 13:21:41 +0300 Subject: [PATCH 8/8] fixing black Signed-off-by: Yaniel Cabrera --- monai/transforms/intensity/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6e7270859e..4533f333ce 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1523,7 +1523,7 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float if isinstance(intensity_range[0], Sequence): self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """