From 280bf36cff872c6569847d8c8911fd091a4e2f6a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 18 Oct 2021 13:07:13 +0100 Subject: [PATCH 1/5] update rand grid distortion Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 4 ++ monai/transforms/spatial/array.py | 40 ++++++++----------- monai/transforms/spatial/dictionary.py | 7 +--- .../transforms/utils_create_transform_ims.py | 13 +++++- tests/test_rand_grid_distortion.py | 2 - tests/test_rand_grid_distortiond.py | 25 ++++++------ 6 files changed, 47 insertions(+), 44 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 0a91805d80..804346b290 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -632,6 +632,8 @@ Spatial `RandGridDistortion` """""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortion.png + :alt: example of RandGridDistortion .. autoclass:: RandGridDistortion :members: :special-members: __call__ @@ -1466,6 +1468,8 @@ Spatial (Dict) `RandGridDistortiond` """"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortiond.png + :alt: example of RandGridDistortiond .. autoclass:: RandGridDistortiond :members: :special-members: __call__ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 5a61d67f2b..4a06976ee7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2067,8 +2067,8 @@ class GridDistortion(Transform): def __init__( self, - num_cells: int, - distort_steps: List[Tuple], + num_cells: Union[Tuple[int], int], + distort_steps: Sequence[Sequence[float]], mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, device: Optional[torch.device] = None, @@ -2096,9 +2096,6 @@ def __init__( padding_mode=padding_mode, device=device, ) - for dim_steps in distort_steps: - if len(dim_steps) != num_cells + 1: - raise ValueError("the length of each tuple in `distort_steps` must equal to `num_cells + 1`.") self.num_cells = num_cells self.distort_steps = distort_steps self.device = device @@ -2106,7 +2103,7 @@ def __init__( def __call__( self, img: NdarrayOrTensor, - distort_steps: Optional[List[Tuple]] = None, + distort_steps: Optional[Sequence[Sequence]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, ) -> NdarrayOrTensor: @@ -2129,12 +2126,13 @@ def __call__( raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") all_ranges = [] + num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) for dim_idx, dim_size in enumerate(img.shape[1:]): dim_distort_steps = distort_steps[dim_idx] ranges = torch.zeros(dim_size, dtype=torch.float32) - cell_size = dim_size // self.num_cells + cell_size = dim_size // num_cells[dim_idx] prev = 0 - for idx in range(self.num_cells + 1): + for idx in range(num_cells[dim_idx] + 1): start = int(idx * cell_size) end = start + cell_size if end > dim_size: @@ -2161,7 +2159,6 @@ def __init__( self, num_cells: int = 5, prob: float = 0.1, - spatial_dims: int = 2, distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -2174,7 +2171,6 @@ def __init__( Args: num_cells: number of grid cells on each dimension. prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. - spatial_dims: spatial dimension of input data. The value should be 2 or 3. Defaults to 2. distort_limit: range to randomly distort. If single number, distort_limit is picked from (-distort_limit, distort_limit). Defaults to (-0.03, 0.03). @@ -2191,14 +2187,11 @@ def __init__( if num_cells <= 0: raise ValueError("num_cells should be no less than 1.") self.num_cells = num_cells - if spatial_dims not in [2, 3]: - raise ValueError("spatial_size should be 2 or 3.") - self.spatial_dims = spatial_dims if isinstance(distort_limit, (int, float)): self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit)) else: self.distort_limit = (min(distort_limit), max(distort_limit)) - self.distort_steps = [tuple([1 + self.distort_limit[0]] * (self.num_cells + 1)) for _ in range(spatial_dims)] + self.distort_steps: Sequence[Sequence[float]] = ((1.0,),) self.grid_distortion = GridDistortion( num_cells=num_cells, distort_steps=self.distort_steps, @@ -2207,21 +2200,21 @@ def __init__( device=device, ) - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, spatial_shape: Sequence[int]) -> None: super().randomize(None) - self.distort_steps = [ - tuple( - 1 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1]) - for _ in range(self.num_cells + 1) - ) - for _dim in range(self.spatial_dims) - ] + if not self._do_transform: + return + self.distort_steps = tuple( + tuple(1.0 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1], size=n_cells + 1)) + for n_cells in ensure_tuple_rep(self.num_cells, len(spatial_shape)) + ) def __call__( self, img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + randomize: bool = True, ) -> NdarrayOrTensor: """ Args: @@ -2234,7 +2227,8 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - self.randomize() + if randomize: + self.randomize(img.shape[1:]) if not self._do_transform: return img return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d8c01fc12e..a8fd23f4ba 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1841,7 +1841,6 @@ def __init__( keys: KeysCollection, num_cells: int = 5, prob: float = 0.1, - spatial_dims: int = 2, distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -1853,7 +1852,6 @@ def __init__( keys: keys of the corresponding items to be transformed. num_cells: number of grid cells on each dimension. prob: probability of returning a randomized grid distortion transform. Defaults to 0.1. - spatial_dims: spatial dimension of input data. Defaults to 2. distort_limit: range to randomly distort. If single number, distort_limit is picked from (-distort_limit, distort_limit). Defaults to (-0.03, 0.03). @@ -1874,7 +1872,6 @@ def __init__( self.rand_grid_distortion = RandGridDistortion( num_cells=num_cells, prob=1.0, - spatial_dims=spatial_dims, distort_limit=distort_limit, device=device, ) @@ -1894,9 +1891,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - self.rand_grid_distortion.randomize(None) + self.rand_grid_distortion.randomize(d[self.keys[0]].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode) + d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 369eb02729..dca9615fbf 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -141,11 +141,20 @@ ) from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd -from monai.transforms.spatial.array import Rand2DElastic, RandAffine, RandAxisFlip, RandRotate90, Resize, Spacing +from monai.transforms.spatial.array import ( + Rand2DElastic, + RandAffine, + RandAxisFlip, + RandGridDistortion, + RandRotate90, + Resize, + Spacing, +) from monai.transforms.spatial.dictionary import ( Rand2DElasticd, RandAffined, RandAxisFlipd, + RandGridDistortiond, RandRotate90d, Resized, Spacingd, @@ -672,3 +681,5 @@ def create_transform_im( create_transform_im( KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2 ) + create_transform_im(RandGridDistortion, dict(num_cells=3, prob=1.0, distort_limit=(-0.1, 0.1)), data) + create_transform_im(RandGridDistortiond, dict(keys=keys, num_cells=5, prob=1.0, distort_limit=(-0.1, 0.1)), data) diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py index f7e4969328..73e37db5bf 100644 --- a/tests/test_rand_grid_distortion.py +++ b/tests/test_rand_grid_distortion.py @@ -26,7 +26,6 @@ dict( num_cells=num_cells, prob=1.0, - spatial_dims=2, distort_limit=0.5, mode="nearest", padding_mode="zeros", @@ -64,7 +63,6 @@ dict( num_cells=num_cells, prob=1.0, - spatial_dims=2, distort_limit=0.1, mode="bilinear", padding_mode="reflection", diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py index 6c91f9ad02..835f38743c 100644 --- a/tests/test_rand_grid_distortiond.py +++ b/tests/test_rand_grid_distortiond.py @@ -18,9 +18,9 @@ from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] +num_cells = 2 +seed = 0 for p in TEST_NDARRAYS: - num_cells = 2 - seed = 0 img = np.indices([6, 6]).astype(np.float32) TESTS.append( [ @@ -28,7 +28,6 @@ keys=["img", "mask"], num_cells=num_cells, prob=1.0, - spatial_dims=2, distort_limit=(-0.1, 0.1), mode=["bilinear", "nearest"], padding_mode="zeros", @@ -40,18 +39,18 @@ [ [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.6390989, 1.6390989, 1.6390989, 1.6390989, 1.6390989, 0.0], - [3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0], - [3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0], - [4.74323, 4.74323, 4.74323, 4.74323, 4.74323, 0.0], + [1.5645568, 1.5645568, 1.5645568, 1.5645568, 1.5645568, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0], + [4.6599426, 4.6599426, 4.6599426, 4.6599426, 4.6599426, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], - [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], - [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], - [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], - [0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], + [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ] @@ -66,7 +65,7 @@ [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ] ] ) From 89a5ba127ef52527dc8536d6331b897908fb71e9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 18 Oct 2021 13:29:54 +0100 Subject: [PATCH 2/5] update docs Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 4a06976ee7..dc612a807b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2091,11 +2091,7 @@ def __init__( device: device on which the tensor will be allocated. """ - self.resampler = Resample( - mode=mode, - padding_mode=padding_mode, - device=device, - ) + self.resampler = Resample(mode=mode, padding_mode=padding_mode, device=device) self.num_cells = num_cells self.distort_steps = distort_steps self.device = device @@ -2225,7 +2221,7 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - + randomize: whether to shuffle the random factors using `randomize()`, default to True. """ if randomize: self.randomize(img.shape[1:]) From 6523bf6654fe81368acb629cefbcf1b3a74b09f7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 18 Oct 2021 16:44:12 +0100 Subject: [PATCH 3/5] update based on comments Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 4 ++-- tests/test_grid_distortion.py | 4 ++-- tests/test_grid_distortiond.py | 4 ++-- tests/test_rand_grid_distortion.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index dc612a807b..557ca0afce 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2153,7 +2153,7 @@ class RandGridDistortion(RandomizableTransform): def __init__( self, - num_cells: int = 5, + num_cells: Union[Tuple[int], int] = 5, prob: float = 0.1, distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a8fd23f4ba..b0a5b6b608 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1787,7 +1787,7 @@ class GridDistortiond(MapTransform): def __init__( self, keys: KeysCollection, - num_cells: int, + num_cells: Union[Tuple[int], int], distort_steps: List[Tuple], mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -1839,7 +1839,7 @@ class RandGridDistortiond(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - num_cells: int = 5, + num_cells: Union[Tuple[int], int] = 5, prob: float = 0.1, distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03), mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py index baed797c86..d5f4e14477 100644 --- a/tests/test_grid_distortion.py +++ b/tests/test_grid_distortion.py @@ -54,8 +54,8 @@ ), ] ) - num_cells = 2 - distort_steps = [(1.5,) * (1 + num_cells), (1.0,) * (1 + num_cells)] + num_cells = (2, 2) + distort_steps = [(1.5,) * (1 + num_cells[0]), (1.0,) * (1 + num_cells[1])] TESTS.append( [ dict( diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py index e216f16cd4..55e2e6ad1d 100644 --- a/tests/test_grid_distortiond.py +++ b/tests/test_grid_distortiond.py @@ -18,9 +18,9 @@ from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] +num_cells = (2, 2) +distort_steps = [(1.5,) * (1 + n_c) for n_c in num_cells] for p in TEST_NDARRAYS: - num_cells = 2 - distort_steps = [(1.5,) * (1 + num_cells)] * 2 img = np.indices([6, 6]).astype(np.float32) TESTS.append( [ diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py index 73e37db5bf..48b1622d83 100644 --- a/tests/test_rand_grid_distortion.py +++ b/tests/test_rand_grid_distortion.py @@ -56,7 +56,7 @@ ), ] ) - num_cells = 2 + num_cells = (2, 2) seed = 1 TESTS.append( [ From c816034e1c375f5fafe7d673ab42da15ae454a79 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 18 Oct 2021 17:50:04 +0100 Subject: [PATCH 4/5] fixes formatting and errors Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 -- tests/test_grid_distortion.py | 12 ++++-------- tests/test_rand_grid_distortion.py | 6 ++---- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 557ca0afce..4e3f513864 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2180,8 +2180,6 @@ def __init__( """ RandomizableTransform.__init__(self, prob) - if num_cells <= 0: - raise ValueError("num_cells should be no less than 1.") self.num_cells = num_cells if isinstance(distort_limit, (int, float)): self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit)) diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py index d5f4e14477..75f8bb06bc 100644 --- a/tests/test_grid_distortion.py +++ b/tests/test_grid_distortion.py @@ -19,13 +19,11 @@ TESTS = [] for p in TEST_NDARRAYS: - num_cells = 3 - distort_steps = [(1.5,) * (1 + num_cells)] * 2 TESTS.append( [ dict( - num_cells=num_cells, - distort_steps=distort_steps, + num_cells=3, + distort_steps=[(1.5,) * 4] * 2, mode="nearest", padding_mode="zeros", ), @@ -89,13 +87,11 @@ ), ] ) - num_cells = 2 - distort_steps = [(1.25,) * (1 + num_cells)] * 3 TESTS.append( [ dict( - num_cells=num_cells, - distort_steps=distort_steps, + num_cells=2, + distort_steps=[(1.25,) * 3] * 3, mode="nearest", padding_mode="zeros", ), diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py index 48b1622d83..c01a36e73d 100644 --- a/tests/test_rand_grid_distortion.py +++ b/tests/test_rand_grid_distortion.py @@ -19,12 +19,11 @@ TESTS = [] for p in TEST_NDARRAYS: - num_cells = 2 seed = 0 TESTS.append( [ dict( - num_cells=num_cells, + num_cells=2, prob=1.0, distort_limit=0.5, mode="nearest", @@ -56,12 +55,11 @@ ), ] ) - num_cells = (2, 2) seed = 1 TESTS.append( [ dict( - num_cells=num_cells, + num_cells=(2, 2), prob=1.0, distort_limit=0.1, mode="bilinear", From 11f56b406ac20ab0b4822adbc3c22a7c6e41c437 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 18 Oct 2021 18:32:23 +0100 Subject: [PATCH 5/5] skip weights downloading Signed-off-by: Wenqi Li --- tests/test_efficientnet.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 20c7123d7f..667d3cd09b 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -382,8 +382,12 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - # initialize model - net = EfficientNetBNFeatures(**input_param).to(device) + try: + # initialize model + net = EfficientNetBNFeatures(**input_param).to(device) + except (ContentTooShortError, HTTPError, RuntimeError) as e: + print(str(e)) + return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net):