Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ class GaussianFilter:
def __init__(self, spatial_dims, sigma, truncated=4., device=None):
"""
Args:
spatial_dims (int): number of spatial dimensions of the input image.
must have shape (Batch, channels, H[, W, ...]).
sigma (float): std.
truncated (float): spreads how many stds.
device (torch.device): device on which the tensor will be allocated.
"""
self.kernel = torch.nn.Parameter(torch.tensor(gaussian_1d(sigma, truncated)), False)
self.spatial_dims = spatial_dims
Expand Down
234 changes: 228 additions & 6 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
defined in `monai.transforms.transforms`.
"""

import torch
from collections.abc import Hashable

import monai
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
AddChannel, Spacing, Rotate90, SpatialCrop)
AddChannel, Spacing, Rotate90, SpatialCrop,
RandAffine, Rand2DElastic, Rand3DElastic)
from monai.utils.misc import ensure_tuple
from monai.transforms.utils import generate_pos_neg_label_crop_centers
from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid
from monai.utils.aliases import alias

export = monai.utils.export("monai.transforms")
Expand All @@ -36,15 +39,15 @@ class MapTransform(Transform):
The ``keys`` parameter will be used to get and set the actual data
item to transform. That is, the callable of this transform should
follow the pattern:
```
.. code-block:: python

def __call__(self, data):
for key in self.keys:
if key in data:
update output data with some_transform_function(data[key]).
# update output data with some_transform_function(data[key]).
else:
do nothing or some exceptions handling.
# do nothing or some exceptions handling.
return data
```
"""

def __init__(self, keys):
Expand Down Expand Up @@ -372,3 +375,222 @@ def __call__(self, data):
results[i][key] = data[key]

return results


@export
@alias('RandAffineD', 'RandAffineDict')
class RandAffined(Randomizable, MapTransform):
"""
A dictionary-based wrapper of ``monai.transforms.transforms.RandAffine``.
"""

def __init__(self, keys,
spatial_size, prob=0.1,
rotate_range=None, shear_range=None, translate_range=None, scale_range=None,
mode='bilinear', padding_mode='zeros', as_tensor_output=True, device=None):
"""
Args:
keys (Hashable items): keys of the corresponding items to be transformed.
spatial_size (list or tuple of int): output image spatial size.
if ``data`` component has two spatial dimensions, ``spatial_size`` should have 2 elements [h, w].
if ``data`` component has three spatial dimensions, ``spatial_size`` should have 3 elements [h, w, d].
prob (float): probability of returning a randomized affine grid.
defaults to 0.1, with 10% chance returns a randomized grid.
mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``.
if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``.
this is useful to set different modes for different data items.
padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices.
Defaults to ``'zeros'``.
as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device (torch.device): device on which the tensor will be allocated.

See also:
- ``monai.transform.composables.MapTransform``
- ``RandAffineGrid`` for the random affine paramters configurations.
"""
MapTransform.__init__(self, keys)
default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode
self.rand_affine = RandAffine(prob=prob,
rotate_range=rotate_range, shear_range=shear_range,
translate_range=translate_range, scale_range=scale_range,
spatial_size=spatial_size,
mode=default_mode, padding_mode=padding_mode,
as_tensor_output=as_tensor_output, device=device)
self.mode = mode

def set_random_state(self, seed=None, state=None):
self.rand_affine.set_random_state(seed, state)
Randomizable.set_random_state(self, seed, state)
return self

def randomize(self):
self.rand_affine.randomize()

def __call__(self, data):
d = dict(data)
self.randomize()

spatial_size = self.rand_affine.spatial_size
if self.rand_affine.do_transform:
grid = self.rand_affine.rand_affine_grid(spatial_size=spatial_size)
else:
grid = create_grid(spatial_size)

if isinstance(self.mode, (tuple, list)):
for key, m in zip(self.keys, self.mode):
d[key] = self.rand_affine.resampler(d[key], grid, mode=m)
return d

for key in self.keys: # same interpolation mode
d[key] = self.rand_affine.resampler(d[key], grid, self.rand_affine.mode)
return d


@export
@alias('Rand2DElasticD', 'Rand2DElasticDict')
class Rand2DElasticd(Randomizable, MapTransform):
"""
A dictionary-based wrapper of ``monai.transforms.transforms.Rand2DElastic``.
"""

def __init__(self, keys,
spatial_size, spacing, magnitude_range, prob=0.1,
rotate_range=None, shear_range=None, translate_range=None, scale_range=None,
mode='bilinear', padding_mode='zeros', as_tensor_output=False, device=None):
"""
Args:
keys (Hashable items): keys of the corresponding items to be transformed.
spatial_size (2 ints): specifying output image spatial size [h, w].
spacing (2 ints): distance in between the control points.
magnitude_range (2 ints): the random offsets will be generated from
``uniform[magnitude[0], magnitude[1])``.
prob (float): probability of returning a randomized affine grid.
defaults to 0.1, with 10% chance returns a randomized grid,
otherwise returns a ``spatial_size`` centered area extracted from the input image.
mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``.
if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``.
this is useful to set different modes for different data items.
padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices.
Defaults to ``'zeros'``.
as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device (torch.device): device on which the tensor will be allocated.

See also:
- ``RandAffineGrid`` for the random affine paramters configurations.
- ``Affine`` for the affine transformation parameters configurations.
"""
MapTransform.__init__(self, keys)
default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode
self.rand_2d_elastic = Rand2DElastic(spacing=spacing, magnitude_range=magnitude_range, prob=prob,
rotate_range=rotate_range, shear_range=shear_range,
translate_range=translate_range, scale_range=scale_range,
spatial_size=spatial_size,
mode=default_mode, padding_mode=padding_mode,
as_tensor_output=as_tensor_output, device=device)
self.mode = mode

def set_random_state(self, seed=None, state=None):
self.rand_2d_elastic.set_random_state(seed, state)
Randomizable.set_random_state(self, seed, state)
return self

def randomize(self, spatial_size):
self.rand_2d_elastic.randomize(spatial_size)

def __call__(self, data):
d = dict(data)
spatial_size = self.rand_2d_elastic.spatial_size
self.randomize(spatial_size)

if self.rand_2d_elastic.do_transform:
grid = self.rand_2d_elastic.deform_grid(spatial_size)
grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
grid = torch.nn.functional.interpolate(grid[None], spatial_size, mode='bicubic', align_corners=False)[0]
else:
grid = create_grid(spatial_size)

if isinstance(self.mode, (tuple, list)):
for key, m in zip(self.keys, self.mode):
d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=m)
return d

for key in self.keys: # same interpolation mode
d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=self.rand_2d_elastic.mode)
return d


@export
@alias('Rand3DElasticD', 'Rand3DElasticDict')
class Rand3DElasticd(Randomizable, MapTransform):
"""
A dictionary-based wrapper of ``monai.transforms.transforms.Rand3DElastic``.
"""

def __init__(self, keys,
spatial_size, sigma_range, magnitude_range, prob=0.1,
rotate_range=None, shear_range=None, translate_range=None, scale_range=None,
mode='bilinear', padding_mode='zeros', as_tensor_output=False, device=None):
"""
Args:
keys (Hashable items): keys of the corresponding items to be transformed.
spatial_size (3 ints): specifying output image spatial size [h, w, d].
sigma_range (2 ints): a Gaussian kernel with standard deviation sampled
from ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid.
magnitude_range (2 ints): the random offsets on the grid will be generated from
``uniform[magnitude[0], magnitude[1])``.
prob (float): probability of returning a randomized affine grid.
defaults to 0.1, with 10% chance returns a randomized grid,
otherwise returns a ``spatial_size`` centered area extracted from the input image.
mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``.
if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``.
this is useful to set different modes for different data items.
padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices.
Defaults to ``'zeros'``.
as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device (torch.device): device on which the tensor will be allocated.

See also:
- ``RandAffineGrid`` for the random affine paramters configurations.
- ``Affine`` for the affine transformation parameters configurations.
"""
MapTransform.__init__(self, keys)
default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode
self.rand_3d_elastic = Rand3DElastic(sigma_range=sigma_range, magnitude_range=magnitude_range, prob=prob,
rotate_range=rotate_range, shear_range=shear_range,
translate_range=translate_range, scale_range=scale_range,
spatial_size=spatial_size,
mode=default_mode, padding_mode=padding_mode,
as_tensor_output=as_tensor_output, device=device)
self.mode = mode

def set_random_state(self, seed=None, state=None):
self.rand_3d_elastic.set_random_state(seed, state)
Randomizable.set_random_state(self, seed, state)
return self

def randomize(self, grid_size):
self.rand_3d_elastic.randomize(grid_size)

def __call__(self, data):
d = dict(data)
spatial_size = self.rand_3d_elastic.spatial_size
self.randomize(spatial_size)
grid = create_grid(spatial_size)
if self.rand_3d_elastic.do_transform:
device = self.rand_3d_elastic.device
grid = torch.tensor(grid).to(device)
gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3., device=device)
grid[:3] += gaussian(self.rand_3d_elastic.rand_offset[None])[0] * self.rand_3d_elastic.magnitude
grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)

if isinstance(self.mode, (tuple, list)):
for key, m in zip(self.keys, self.mode):
d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=m)
return d

for key in self.keys: # same interpolation mode
d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode)
return d
Loading