diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 846783953d..0a91805d80 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -624,6 +624,18 @@ Spatial :members: :special-members: __call__ +`GridDistortion` +"""""""""""""""" +.. autoclass:: GridDistortion + :members: + :special-members: __call__ + +`RandGridDistortion` +"""""""""""""""""""" +.. autoclass:: RandGridDistortion + :members: + :special-members: __call__ + `Rand2DElastic` """"""""""""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElastic.png @@ -1446,6 +1458,18 @@ Spatial (Dict) :members: :special-members: __call__ +`GridDistortiond` +""""""""""""""""" +.. autoclass:: GridDistortiond + :members: + :special-members: __call__ + +`RandGridDistortiond` +""""""""""""""""""""" +.. autoclass:: RandGridDistortiond + :members: + :special-members: __call__ + Utility (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 20e24b3958..1223254db5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -279,6 +279,7 @@ Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, @@ -287,6 +288,7 @@ RandAxisFlip, RandDeformGrid, RandFlip, + RandGridDistortion, RandRotate, RandRotate90, RandZoom, @@ -307,6 +309,9 @@ Flipd, FlipD, FlipDict, + GridDistortiond, + GridDistortionD, + GridDistortionDict, Orientationd, OrientationD, OrientationDict, @@ -325,6 +330,9 @@ RandFlipd, RandFlipD, RandFlipDict, + RandGridDistortiond, + RandGridDistortionD, + RandGridDistortionDict, RandRotate90d, RandRotate90D, RandRotate90Dict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ecba432f71..5a61d67f2b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -58,6 +58,7 @@ "Spacing", "Orientation", "Flip", + "GridDistortion", "Resize", "Rotate", "Zoom", @@ -65,6 +66,7 @@ "RandRotate90", "RandRotate", "RandFlip", + "RandGridDistortion", "RandAxisFlip", "RandZoom", "AffineGrid", @@ -2057,3 +2059,182 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # but user input is 1-based (because channel dim is 0) coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] return concatenate((img, coord_channels), axis=0) + + +class GridDistortion(Transform): + + backend = [TransformBackends.TORCH] + + def __init__( + self, + num_cells: int, + distort_steps: List[Tuple], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + ) -> None: + """ + Grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + + Args: + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + device: device on which the tensor will be allocated. + + """ + self.resampler = Resample( + mode=mode, + padding_mode=padding_mode, + device=device, + ) + for dim_steps in distort_steps: + if len(dim_steps) != num_cells + 1: + raise ValueError("the length of each tuple in `distort_steps` must equal to `num_cells + 1`.") + self.num_cells = num_cells + self.distort_steps = distort_steps + self.device = device + + def __call__( + self, + img: NdarrayOrTensor, + distort_steps: Optional[List[Tuple]] = None, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + ) -> NdarrayOrTensor: + """ + Args: + img: shape must be (num_channels, H, W[, D]). + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + + """ + distort_steps = self.distort_steps if distort_steps is None else distort_steps + if len(img.shape) != len(distort_steps) + 1: + raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") + + all_ranges = [] + for dim_idx, dim_size in enumerate(img.shape[1:]): + dim_distort_steps = distort_steps[dim_idx] + ranges = torch.zeros(dim_size, dtype=torch.float32) + cell_size = dim_size // self.num_cells + prev = 0 + for idx in range(self.num_cells + 1): + start = int(idx * cell_size) + end = start + cell_size + if end > dim_size: + end = dim_size + cur = dim_size + else: + cur = prev + cell_size * dim_distort_steps[idx] + ranges[start:end] = torch.linspace(prev, cur, end - start) + prev = cur + ranges = ranges - (dim_size - 1.0) / 2.0 + all_ranges.append(ranges) + + coords = torch.meshgrid(*all_ranges) + grid = torch.stack([*coords, torch.ones_like(coords[0])]) + + return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore + + +class RandGridDistortion(RandomizableTransform): + + backend = [TransformBackends.TORCH] + + def __init__( + self, + num_cells: int = 5, + prob: float = 0.1, + spatial_dims: int = 2, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + ) -> None: + """ + Random grid distortion transform. Refer to: + https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py + + Args: + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + spatial_dims: spatial dimension of input data. The value should be 2 or 3. Defaults to 2. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + device: device on which the tensor will be allocated. + + """ + RandomizableTransform.__init__(self, prob) + if num_cells <= 0: + raise ValueError("num_cells should be no less than 1.") + self.num_cells = num_cells + if spatial_dims not in [2, 3]: + raise ValueError("spatial_size should be 2 or 3.") + self.spatial_dims = spatial_dims + if isinstance(distort_limit, (int, float)): + self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit)) + else: + self.distort_limit = (min(distort_limit), max(distort_limit)) + self.distort_steps = [tuple([1 + self.distort_limit[0]] * (self.num_cells + 1)) for _ in range(spatial_dims)] + self.grid_distortion = GridDistortion( + num_cells=num_cells, + distort_steps=self.distort_steps, + mode=mode, + padding_mode=padding_mode, + device=device, + ) + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.distort_steps = [ + tuple( + 1 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1]) + for _ in range(self.num_cells + 1) + ) + for _dim in range(self.spatial_dims) + ] + + def __call__( + self, + img: NdarrayOrTensor, + mode: Optional[Union[GridSampleMode, str]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + ) -> NdarrayOrTensor: + """ + Args: + img: shape must be (num_channels, H, W[, D]). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + + """ + self.randomize() + if not self._do_transform: + return img + return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7bc5498960..d8c01fc12e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -17,7 +17,7 @@ from copy import deepcopy from enum import Enum -from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -33,12 +33,14 @@ Affine, AffineGrid, Flip, + GridDistortion, Orientation, Rand2DElastic, Rand3DElastic, RandAffine, RandAxisFlip, RandFlip, + RandGridDistortion, RandRotate, RandZoom, Resize, @@ -78,6 +80,8 @@ "Rand3DElasticd", "Flipd", "RandFlipd", + "GridDistortiond", + "RandGridDistortiond", "RandAxisFlipd", "Rotated", "RandRotated", @@ -105,6 +109,10 @@ "FlipDict", "RandFlipD", "RandFlipDict", + "GridDistortionD", + "GridDistortionDict", + "RandGridDistortionD", + "RandGridDistortionDict", "RandAxisFlipD", "RandAxisFlipDict", "RotateD", @@ -1769,6 +1777,129 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class GridDistortiond(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.GridDistortion`. + """ + + backend = GridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: int, + distort_steps: List[Tuple], + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + num_cells: number of grid cells on each dimension. + distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the + corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`. + Each value in the tuple represents the distort step of the related cell. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.grid_distortion = GridDistortion( + num_cells=num_cells, + distort_steps=distort_steps, + device=device, + ) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode) + return d + + +class RandGridDistortiond(RandomizableTransform, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandGridDistortion`. + """ + + backend = RandGridDistortion.backend + + def __init__( + self, + keys: KeysCollection, + num_cells: int = 5, + prob: float = 0.1, + spatial_dims: int = 2, + distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + num_cells: number of grid cells on each dimension. + prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. + spatial_dims: spatial dimension of input data. Defaults to 2. + distort_limit: range to randomly distort. + If single number, distort_limit is picked from (-distort_limit, distort_limit). + Defaults to (-0.03, 0.03). + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"reflection"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + device: device on which the tensor will be allocated. + allow_missing_keys: don't raise exception if key is missing. + + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.rand_grid_distortion = RandGridDistortion( + num_cells=num_cells, + prob=1.0, + spatial_dims=spatial_dims, + distort_limit=distort_limit, + device=device, + ) + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGridDistortiond": + super().set_random_state(seed, state) + self.rand_grid_distortion.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + self.rand_grid_distortion.randomize(None) + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode) + return d + + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd Rotate90D = Rotate90Dict = Rotate90d @@ -1780,6 +1911,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N Rand3DElasticD = Rand3DElasticDict = Rand3DElasticd FlipD = FlipDict = Flipd RandFlipD = RandFlipDict = RandFlipd +GridDistortionD = GridDistortionDict = GridDistortiond +RandGridDistortionD = RandGridDistortionDict = RandGridDistortiond RandAxisFlipD = RandAxisFlipDict = RandAxisFlipd RotateD = RotateDict = Rotated RandRotateD = RandRotateDict = RandRotated diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py new file mode 100644 index 0000000000..baed797c86 --- /dev/null +++ b/tests/test_grid_distortion.py @@ -0,0 +1,127 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + num_cells = 3 + distort_steps = [(1.5,) * (1 + num_cells)] * 2 + TESTS.append( + [ + dict( + num_cells=num_cells, + distort_steps=distort_steps, + mode="nearest", + padding_mode="zeros", + ), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + num_cells = 2 + distort_steps = [(1.5,) * (1 + num_cells), (1.0,) * (1 + num_cells)] + TESTS.append( + [ + dict( + num_cells=num_cells, + distort_steps=distort_steps, + mode="bilinear", + padding_mode="reflection", + ), + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + [0.0, 1.5, 3.0, 3.0, 4.5, 4.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + num_cells = 2 + distort_steps = [(1.25,) * (1 + num_cells)] * 3 + TESTS.append( + [ + dict( + num_cells=num_cells, + distort_steps=distort_steps, + mode="nearest", + padding_mode="zeros", + ), + p(np.indices([3, 3, 3])[:1].astype(np.float32)), + p( + np.array( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortion(self, input_param, input_data, expected_val): + g = GridDistortion(**input_param) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py new file mode 100644 index 0000000000..e216f16cd4 --- /dev/null +++ b/tests/test_grid_distortiond.py @@ -0,0 +1,85 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import GridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + num_cells = 2 + distort_steps = [(1.5,) * (1 + num_cells)] * 2 + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + distort_steps=distort_steps, + mode=["bilinear", "nearest"], + padding_mode=["reflection", "zeros"], + ), + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.25, 2.25, 2.25, 2.25, 2.25, 2.25], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [4.5, 4.5, 4.5, 4.5, 4.5, 4.5], + [3.25, 3.25, 3.25, 3.25, 3.25, 3.25], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + [0.0, 2.25, 4.5, 4.5, 3.25, 1.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ).astype(np.float32) + ), + ] + ) + + +class TestGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask): + g = GridDistortiond(**input_param) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py new file mode 100644 index 0000000000..f7e4969328 --- /dev/null +++ b/tests/test_rand_grid_distortion.py @@ -0,0 +1,110 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortion +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + num_cells = 2 + seed = 0 + TESTS.append( + [ + dict( + num_cells=num_cells, + prob=1.0, + spatial_dims=2, + distort_limit=0.5, + mode="nearest", + padding_mode="zeros", + ), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [4.0, 4.0, 4.0, 4.0, 4.0, 0.0], + [5.0, 5.0, 5.0, 5.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 1.0, 3.0, 3.0, 4.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + ] + ) + num_cells = 2 + seed = 1 + TESTS.append( + [ + dict( + num_cells=num_cells, + prob=1.0, + spatial_dims=2, + distort_limit=0.1, + mode="bilinear", + padding_mode="reflection", + ), + seed, + p(np.indices([6, 6]).astype(np.float32)), + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.5660975, 1.5660975, 1.5660975, 1.5660975, 1.5660974, 1.5660975], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195], + [4.482229, 4.482229, 4.482229, 4.482229, 4.482229, 4.482229], + [4.167737, 4.167737, 4.167737, 4.167737, 4.167737, 4.167737], + ], + [ + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 4.4565434], + [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 4.456543], + ], + ] + ).astype(np.float32) + ), + ] + ) + + +class TestRandGridDistortion(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val): + g = RandGridDistortion(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py new file mode 100644 index 0000000000..6c91f9ad02 --- /dev/null +++ b/tests/test_rand_grid_distortiond.py @@ -0,0 +1,89 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandGridDistortiond +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + num_cells = 2 + seed = 0 + img = np.indices([6, 6]).astype(np.float32) + TESTS.append( + [ + dict( + keys=["img", "mask"], + num_cells=num_cells, + prob=1.0, + spatial_dims=2, + distort_limit=(-0.1, 0.1), + mode=["bilinear", "nearest"], + padding_mode="zeros", + ), + seed, + {"img": p(img), "mask": p(np.ones_like(img[:1]))}, + p( + np.array( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.6390989, 1.6390989, 1.6390989, 1.6390989, 1.6390989, 0.0], + [3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0], + [3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0], + [4.74323, 4.74323, 4.74323, 4.74323, 4.74323, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ).astype(np.float32) + ), + p( + np.array( + [ + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + ] + ] + ) + ), + ] + ) + + +class TestRandGridDistortiond(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask): + g = RandGridDistortiond(**input_param) + g.set_random_state(seed=seed) + result = g(input_data) + assert_allclose(result["img"], expected_val_img, rtol=1e-4, atol=1e-4) + assert_allclose(result["mask"], expected_val_mask, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main()