Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,18 @@ Spatial
:members:
:special-members: __call__

`GridDistortion`
""""""""""""""""
.. autoclass:: GridDistortion
:members:
:special-members: __call__

`RandGridDistortion`
""""""""""""""""""""
.. autoclass:: RandGridDistortion
:members:
:special-members: __call__

`Rand2DElastic`
"""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Rand2DElastic.png
Expand Down Expand Up @@ -1446,6 +1458,18 @@ Spatial (Dict)
:members:
:special-members: __call__

`GridDistortiond`
"""""""""""""""""
.. autoclass:: GridDistortiond
:members:
:special-members: __call__

`RandGridDistortiond`
"""""""""""""""""""""
.. autoclass:: RandGridDistortiond
:members:
:special-members: __call__

Utility (Dict)
^^^^^^^^^^^^^^

Expand Down
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
Affine,
AffineGrid,
Flip,
GridDistortion,
Orientation,
Rand2DElastic,
Rand3DElastic,
Expand All @@ -287,6 +288,7 @@
RandAxisFlip,
RandDeformGrid,
RandFlip,
RandGridDistortion,
RandRotate,
RandRotate90,
RandZoom,
Expand All @@ -307,6 +309,9 @@
Flipd,
FlipD,
FlipDict,
GridDistortiond,
GridDistortionD,
GridDistortionDict,
Orientationd,
OrientationD,
OrientationDict,
Expand All @@ -325,6 +330,9 @@
RandFlipd,
RandFlipD,
RandFlipDict,
RandGridDistortiond,
RandGridDistortionD,
RandGridDistortionDict,
RandRotate90d,
RandRotate90D,
RandRotate90Dict,
Expand Down
181 changes: 181 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@
"Spacing",
"Orientation",
"Flip",
"GridDistortion",
"Resize",
"Rotate",
"Zoom",
"Rotate90",
"RandRotate90",
"RandRotate",
"RandFlip",
"RandGridDistortion",
"RandAxisFlip",
"RandZoom",
"AffineGrid",
Expand Down Expand Up @@ -2057,3 +2059,182 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# but user input is 1-based (because channel dim is 0)
coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]]
return concatenate((img, coord_channels), axis=0)


class GridDistortion(Transform):

backend = [TransformBackends.TORCH]

def __init__(
self,
num_cells: int,
distort_steps: List[Tuple],
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
device: Optional[torch.device] = None,
) -> None:
"""
Grid distortion transform. Refer to:
https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py

Args:
num_cells: number of grid cells on each dimension.
distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
Each value in the tuple represents the distort step of the related cell.
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
device: device on which the tensor will be allocated.

"""
self.resampler = Resample(
mode=mode,
padding_mode=padding_mode,
device=device,
)
for dim_steps in distort_steps:
if len(dim_steps) != num_cells + 1:
raise ValueError("the length of each tuple in `distort_steps` must equal to `num_cells + 1`.")
self.num_cells = num_cells
self.distort_steps = distort_steps
self.device = device

def __call__(
self,
img: NdarrayOrTensor,
distort_steps: Optional[List[Tuple]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]).
distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
Each value in the tuple represents the distort step of the related cell.
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

"""
distort_steps = self.distort_steps if distort_steps is None else distort_steps
if len(img.shape) != len(distort_steps) + 1:
raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")

all_ranges = []
for dim_idx, dim_size in enumerate(img.shape[1:]):
dim_distort_steps = distort_steps[dim_idx]
ranges = torch.zeros(dim_size, dtype=torch.float32)
cell_size = dim_size // self.num_cells
prev = 0
for idx in range(self.num_cells + 1):
start = int(idx * cell_size)
end = start + cell_size
if end > dim_size:
end = dim_size
cur = dim_size
else:
cur = prev + cell_size * dim_distort_steps[idx]
ranges[start:end] = torch.linspace(prev, cur, end - start)
prev = cur
ranges = ranges - (dim_size - 1.0) / 2.0
all_ranges.append(ranges)

coords = torch.meshgrid(*all_ranges)
grid = torch.stack([*coords, torch.ones_like(coords[0])])

return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore


class RandGridDistortion(RandomizableTransform):

backend = [TransformBackends.TORCH]

def __init__(
self,
num_cells: int = 5,
prob: float = 0.1,
spatial_dims: int = 2,
distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03),
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
device: Optional[torch.device] = None,
) -> None:
"""
Random grid distortion transform. Refer to:
https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py

Args:
num_cells: number of grid cells on each dimension.
prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.
spatial_dims: spatial dimension of input data. The value should be 2 or 3. Defaults to 2.
distort_limit: range to randomly distort.
If single number, distort_limit is picked from (-distort_limit, distort_limit).
Defaults to (-0.03, 0.03).
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
device: device on which the tensor will be allocated.

"""
RandomizableTransform.__init__(self, prob)
if num_cells <= 0:
raise ValueError("num_cells should be no less than 1.")
self.num_cells = num_cells
if spatial_dims not in [2, 3]:
raise ValueError("spatial_size should be 2 or 3.")
self.spatial_dims = spatial_dims
if isinstance(distort_limit, (int, float)):
self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit))
else:
self.distort_limit = (min(distort_limit), max(distort_limit))
self.distort_steps = [tuple([1 + self.distort_limit[0]] * (self.num_cells + 1)) for _ in range(spatial_dims)]
self.grid_distortion = GridDistortion(
num_cells=num_cells,
distort_steps=self.distort_steps,
mode=mode,
padding_mode=padding_mode,
device=device,
)

def randomize(self, data: Optional[Any] = None) -> None:
super().randomize(None)
self.distort_steps = [
tuple(
1 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1])
for _ in range(self.num_cells + 1)
)
for _dim in range(self.spatial_dims)
]

def __call__(
self,
img: NdarrayOrTensor,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]).
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

"""
self.randomize()
if not self._do_transform:
return img
return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)
Loading