Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ Spatial

`RandGridDistortion`
""""""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortion.png
:alt: example of RandGridDistortion
.. autoclass:: RandGridDistortion
:members:
:special-members: __call__
Expand Down Expand Up @@ -1466,6 +1468,8 @@ Spatial (Dict)

`RandGridDistortiond`
"""""""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandGridDistortiond.png
:alt: example of RandGridDistortiond
.. autoclass:: RandGridDistortiond
:members:
:special-members: __call__
Expand Down
52 changes: 20 additions & 32 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,8 +2067,8 @@ class GridDistortion(Transform):

def __init__(
self,
num_cells: int,
distort_steps: List[Tuple],
num_cells: Union[Tuple[int], int],
distort_steps: Sequence[Sequence[float]],
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
device: Optional[torch.device] = None,
Expand All @@ -2091,22 +2091,15 @@ def __init__(
device: device on which the tensor will be allocated.

"""
self.resampler = Resample(
mode=mode,
padding_mode=padding_mode,
device=device,
)
for dim_steps in distort_steps:
if len(dim_steps) != num_cells + 1:
raise ValueError("the length of each tuple in `distort_steps` must equal to `num_cells + 1`.")
self.resampler = Resample(mode=mode, padding_mode=padding_mode, device=device)
self.num_cells = num_cells
self.distort_steps = distort_steps
self.device = device

def __call__(
self,
img: NdarrayOrTensor,
distort_steps: Optional[List[Tuple]] = None,
distort_steps: Optional[Sequence[Sequence]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> NdarrayOrTensor:
Expand All @@ -2129,12 +2122,13 @@ def __call__(
raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")

all_ranges = []
num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1)
for dim_idx, dim_size in enumerate(img.shape[1:]):
dim_distort_steps = distort_steps[dim_idx]
ranges = torch.zeros(dim_size, dtype=torch.float32)
cell_size = dim_size // self.num_cells
cell_size = dim_size // num_cells[dim_idx]
prev = 0
for idx in range(self.num_cells + 1):
for idx in range(num_cells[dim_idx] + 1):
start = int(idx * cell_size)
end = start + cell_size
if end > dim_size:
Expand All @@ -2159,9 +2153,8 @@ class RandGridDistortion(RandomizableTransform):

def __init__(
self,
num_cells: int = 5,
num_cells: Union[Tuple[int], int] = 5,
prob: float = 0.1,
spatial_dims: int = 2,
distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03),
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
Expand All @@ -2174,7 +2167,6 @@ def __init__(
Args:
num_cells: number of grid cells on each dimension.
prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.
spatial_dims: spatial dimension of input data. The value should be 2 or 3. Defaults to 2.
distort_limit: range to randomly distort.
If single number, distort_limit is picked from (-distort_limit, distort_limit).
Defaults to (-0.03, 0.03).
Expand All @@ -2188,17 +2180,12 @@ def __init__(

"""
RandomizableTransform.__init__(self, prob)
if num_cells <= 0:
raise ValueError("num_cells should be no less than 1.")
self.num_cells = num_cells
if spatial_dims not in [2, 3]:
raise ValueError("spatial_size should be 2 or 3.")
self.spatial_dims = spatial_dims
if isinstance(distort_limit, (int, float)):
self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit))
else:
self.distort_limit = (min(distort_limit), max(distort_limit))
self.distort_steps = [tuple([1 + self.distort_limit[0]] * (self.num_cells + 1)) for _ in range(spatial_dims)]
self.distort_steps: Sequence[Sequence[float]] = ((1.0,),)
self.grid_distortion = GridDistortion(
num_cells=num_cells,
distort_steps=self.distort_steps,
Expand All @@ -2207,21 +2194,21 @@ def __init__(
device=device,
)

def randomize(self, data: Optional[Any] = None) -> None:
def randomize(self, spatial_shape: Sequence[int]) -> None:
super().randomize(None)
self.distort_steps = [
tuple(
1 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1])
for _ in range(self.num_cells + 1)
)
for _dim in range(self.spatial_dims)
]
if not self._do_transform:
return
self.distort_steps = tuple(
tuple(1.0 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1], size=n_cells + 1))
for n_cells in ensure_tuple_rep(self.num_cells, len(spatial_shape))
)

def __call__(
self,
img: NdarrayOrTensor,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
randomize: bool = True,
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -2232,9 +2219,10 @@ def __call__(
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

randomize: whether to shuffle the random factors using `randomize()`, default to True.
"""
self.randomize()
if randomize:
self.randomize(img.shape[1:])
if not self._do_transform:
return img
return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)
11 changes: 4 additions & 7 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,7 @@ class GridDistortiond(MapTransform):
def __init__(
self,
keys: KeysCollection,
num_cells: int,
num_cells: Union[Tuple[int], int],
distort_steps: List[Tuple],
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
Expand Down Expand Up @@ -1839,9 +1839,8 @@ class RandGridDistortiond(RandomizableTransform, MapTransform):
def __init__(
self,
keys: KeysCollection,
num_cells: int = 5,
num_cells: Union[Tuple[int], int] = 5,
prob: float = 0.1,
spatial_dims: int = 2,
distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03),
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
Expand All @@ -1853,7 +1852,6 @@ def __init__(
keys: keys of the corresponding items to be transformed.
num_cells: number of grid cells on each dimension.
prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.
spatial_dims: spatial dimension of input data. Defaults to 2.
distort_limit: range to randomly distort.
If single number, distort_limit is picked from (-distort_limit, distort_limit).
Defaults to (-0.03, 0.03).
Expand All @@ -1874,7 +1872,6 @@ def __init__(
self.rand_grid_distortion = RandGridDistortion(
num_cells=num_cells,
prob=1.0,
spatial_dims=spatial_dims,
distort_limit=distort_limit,
device=device,
)
Expand All @@ -1894,9 +1891,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
if not self._do_transform:
return d

self.rand_grid_distortion.randomize(None)
self.rand_grid_distortion.randomize(d[self.keys[0]].shape[1:])
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode)
d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False)
return d


Expand Down
13 changes: 12 additions & 1 deletion monai/transforms/utils_create_transform_ims.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,20 @@
)
from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour
from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd
from monai.transforms.spatial.array import Rand2DElastic, RandAffine, RandAxisFlip, RandRotate90, Resize, Spacing
from monai.transforms.spatial.array import (
Rand2DElastic,
RandAffine,
RandAxisFlip,
RandGridDistortion,
RandRotate90,
Resize,
Spacing,
)
from monai.transforms.spatial.dictionary import (
Rand2DElasticd,
RandAffined,
RandAxisFlipd,
RandGridDistortiond,
RandRotate90d,
Resized,
Spacingd,
Expand Down Expand Up @@ -672,3 +681,5 @@ def create_transform_im(
create_transform_im(
KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2
)
create_transform_im(RandGridDistortion, dict(num_cells=3, prob=1.0, distort_limit=(-0.1, 0.1)), data)
create_transform_im(RandGridDistortiond, dict(keys=keys, num_cells=5, prob=1.0, distort_limit=(-0.1, 0.1)), data)
8 changes: 6 additions & 2 deletions tests/test_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,12 @@ class TestExtractFeatures(unittest.TestCase):
def test_shape(self, input_param, input_shape, expected_shapes):
device = "cuda" if torch.cuda.is_available() else "cpu"

# initialize model
net = EfficientNetBNFeatures(**input_param).to(device)
try:
# initialize model
net = EfficientNetBNFeatures(**input_param).to(device)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
return # skipping the tests because of http errors

# run inference with random tensor
with eval_mode(net):
Expand Down
16 changes: 6 additions & 10 deletions tests/test_grid_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@

TESTS = []
for p in TEST_NDARRAYS:
num_cells = 3
distort_steps = [(1.5,) * (1 + num_cells)] * 2
TESTS.append(
[
dict(
num_cells=num_cells,
distort_steps=distort_steps,
num_cells=3,
distort_steps=[(1.5,) * 4] * 2,
mode="nearest",
padding_mode="zeros",
),
Expand Down Expand Up @@ -54,8 +52,8 @@
),
]
)
num_cells = 2
distort_steps = [(1.5,) * (1 + num_cells), (1.0,) * (1 + num_cells)]
num_cells = (2, 2)
distort_steps = [(1.5,) * (1 + num_cells[0]), (1.0,) * (1 + num_cells[1])]
TESTS.append(
[
dict(
Expand Down Expand Up @@ -89,13 +87,11 @@
),
]
)
num_cells = 2
distort_steps = [(1.25,) * (1 + num_cells)] * 3
TESTS.append(
[
dict(
num_cells=num_cells,
distort_steps=distort_steps,
num_cells=2,
distort_steps=[(1.25,) * 3] * 3,
mode="nearest",
padding_mode="zeros",
),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_grid_distortiond.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
num_cells = (2, 2)
distort_steps = [(1.5,) * (1 + n_c) for n_c in num_cells]
for p in TEST_NDARRAYS:
num_cells = 2
distort_steps = [(1.5,) * (1 + num_cells)] * 2
img = np.indices([6, 6]).astype(np.float32)
TESTS.append(
[
Expand Down
8 changes: 2 additions & 6 deletions tests/test_rand_grid_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@

TESTS = []
for p in TEST_NDARRAYS:
num_cells = 2
seed = 0
TESTS.append(
[
dict(
num_cells=num_cells,
num_cells=2,
prob=1.0,
spatial_dims=2,
distort_limit=0.5,
mode="nearest",
padding_mode="zeros",
Expand Down Expand Up @@ -57,14 +55,12 @@
),
]
)
num_cells = 2
seed = 1
TESTS.append(
[
dict(
num_cells=num_cells,
num_cells=(2, 2),
prob=1.0,
spatial_dims=2,
distort_limit=0.1,
mode="bilinear",
padding_mode="reflection",
Expand Down
25 changes: 12 additions & 13 deletions tests/test_rand_grid_distortiond.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
num_cells = 2
seed = 0
for p in TEST_NDARRAYS:
num_cells = 2
seed = 0
img = np.indices([6, 6]).astype(np.float32)
TESTS.append(
[
dict(
keys=["img", "mask"],
num_cells=num_cells,
prob=1.0,
spatial_dims=2,
distort_limit=(-0.1, 0.1),
mode=["bilinear", "nearest"],
padding_mode="zeros",
Expand All @@ -40,18 +39,18 @@
[
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.6390989, 1.6390989, 1.6390989, 1.6390989, 1.6390989, 0.0],
[3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0],
[3.2781978, 3.2781978, 3.2781978, 3.2781978, 3.2781978, 0.0],
[4.74323, 4.74323, 4.74323, 4.74323, 4.74323, 0.0],
[1.5645568, 1.5645568, 1.5645568, 1.5645568, 1.5645568, 0.0],
[3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0],
[3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0],
[4.6599426, 4.6599426, 4.6599426, 4.6599426, 4.6599426, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0],
[0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0],
[0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0],
[0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0],
[0.0, 1.5086684, 3.0173368, 3.0173368, 4.5377502, 0.0],
[0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],
[0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],
[0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],
[0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],
[0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
]
Expand All @@ -66,7 +65,7 @@
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]
]
)
Expand Down