From 5b1f258238f984def1729f61a739cea3f8000929 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 Mar 2020 10:07:50 +0000 Subject: [PATCH 1/2] add dictionary-based wrapper random spatial transforms (#166) --- monai/networks/layers/simplelayers.py | 3 + monai/transforms/composables.py | 234 +++++++++++++++++++++++++- monai/transforms/transforms.py | 100 ++++++----- tests/test_flip.py | 2 +- tests/test_random_affine.py | 4 +- tests/test_random_affined.py | 90 ++++++++++ tests/test_random_elastic_2d.py | 17 +- tests/test_random_elastic_3d.py | 3 +- tests/test_random_elasticd_2d.py | 88 ++++++++++ tests/test_random_elasticd_3d.py | 72 ++++++++ tests/test_random_flip.py | 8 +- 11 files changed, 552 insertions(+), 69 deletions(-) create mode 100644 tests/test_random_affined.py create mode 100644 tests/test_random_elasticd_2d.py create mode 100644 tests/test_random_elasticd_3d.py diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 5ed491354b..c41ff93a0f 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -40,8 +40,11 @@ class GaussianFilter: def __init__(self, spatial_dims, sigma, truncated=4., device=None): """ Args: + spatial_dims (int): number of spatial dimensions of the input image. + must have shape (Batch, channels, H[, W, ...]). sigma (float): std. truncated (float): spreads how many stds. + device (torch.device): device on which the tensor will be allocated. """ self.kernel = torch.nn.Parameter(torch.tensor(gaussian_1d(sigma, truncated)), False) self.spatial_dims = spatial_dims diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 0e404c5233..8ccd5a747a 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -13,15 +13,18 @@ defined in `monai.transforms.transforms`. """ +import torch from collections.abc import Hashable import monai from monai.data.utils import get_random_patch, get_valid_patch_size +from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation, - AddChannel, Spacing, Rotate90, SpatialCrop) + AddChannel, Spacing, Rotate90, SpatialCrop, + RandAffine, Rand2DElastic, Rand3DElastic) from monai.utils.misc import ensure_tuple -from monai.transforms.utils import generate_pos_neg_label_crop_centers +from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid from monai.utils.aliases import alias export = monai.utils.export("monai.transforms") @@ -36,15 +39,15 @@ class MapTransform(Transform): The ``keys`` parameter will be used to get and set the actual data item to transform. That is, the callable of this transform should follow the pattern: - ``` + .. code-block:: python + def __call__(self, data): for key in self.keys: if key in data: - update output data with some_transform_function(data[key]). + # update output data with some_transform_function(data[key]). else: - do nothing or some exceptions handling. + # do nothing or some exceptions handling. return data - ``` """ def __init__(self, keys): @@ -372,3 +375,222 @@ def __call__(self, data): results[i][key] = data[key] return results + + +@export +@alias('RandAffineD', 'RandAffineDict') +class RandAffined(Randomizable, MapTransform): + """ + A dictionary-based wrapper of ``monai.transforms.transforms.RandAffine``. + """ + + def __init__(self, keys, + spatial_size, prob=0.1, + rotate_range=None, shear_range=None, translate_range=None, scale_range=None, + mode='bilinear', padding_mode='zeros', as_tensor_output=True, device=None): + """ + Args: + keys (Hashable items): keys of the corresponding items to be transformed. + spatial_size (list or tuple of int): output image spatial size. + if ``data`` component has two spatial dimensions, ``spatial_size`` should have 2 elements [h, w]. + if ``data`` component has three spatial dimensions, ``spatial_size`` should have 3 elements [h, w, d]. + prob (float): probability of returning a randomized affine grid. + defaults to 0.1, with 10% chance returns a randomized grid. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``. + if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``. + this is useful to set different modes for different data items. + padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. + Defaults to ``'zeros'``. + as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies + whether to convert it back to numpy arrays. + device (torch.device): device on which the tensor will be allocated. + + See also: + - ``monai.transform.composables.MapTransform`` + - ``RandAffineGrid`` for the random affine paramters configurations. + """ + MapTransform.__init__(self, keys) + default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode + self.rand_affine = RandAffine(prob=prob, + rotate_range=rotate_range, shear_range=shear_range, + translate_range=translate_range, scale_range=scale_range, + spatial_size=spatial_size, + mode=default_mode, padding_mode=padding_mode, + as_tensor_output=as_tensor_output, device=device) + self.mode = mode + + def set_random_state(self, seed=None, state=None): + self.rand_affine.set_random_state(seed, state) + Randomizable.set_random_state(self, seed, state) + return self + + def randomize(self): + self.rand_affine.randomize() + + def __call__(self, data): + d = dict(data) + self.randomize() + + spatial_size = self.rand_affine.spatial_size + if self.rand_affine.do_transform: + grid = self.rand_affine.rand_affine_grid(spatial_size=spatial_size) + else: + grid = create_grid(spatial_size) + + if isinstance(self.mode, (tuple, list)): + for key, m in zip(self.keys, self.mode): + d[key] = self.rand_affine.resampler(d[key], grid, mode=m) + return d + + for key in self.keys: # same interpolation mode + d[key] = self.rand_affine.resampler(d[key], grid, self.rand_affine.mode) + return d + + +@export +@alias('Rand2DElasticD', 'Rand2DElasticDict') +class Rand2DElasticd(Randomizable, MapTransform): + """ + A dictionary-based wrapper of ``monai.transforms.transforms.Rand2DElastic``. + """ + + def __init__(self, keys, + spatial_size, spacing, magnitude_range, prob=0.1, + rotate_range=None, shear_range=None, translate_range=None, scale_range=None, + mode='bilinear', padding_mode='zeros', as_tensor_output=False, device=None): + """ + Args: + keys (Hashable items): keys of the corresponding items to be transformed. + spatial_size (2 ints): specifying output image spatial size [h, w]. + spacing (2 ints): distance in between the control points. + magnitude_range (2 ints): the random offsets will be generated from + ``uniform[magnitude[0], magnitude[1])``. + prob (float): probability of returning a randomized affine grid. + defaults to 0.1, with 10% chance returns a randomized grid, + otherwise returns a ``spatial_size`` centered area extracted from the input image. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``. + if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``. + this is useful to set different modes for different data items. + padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. + Defaults to ``'zeros'``. + as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies + whether to convert it back to numpy arrays. + device (torch.device): device on which the tensor will be allocated. + + See also: + - ``RandAffineGrid`` for the random affine paramters configurations. + - ``Affine`` for the affine transformation parameters configurations. + """ + MapTransform.__init__(self, keys) + default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode + self.rand_2d_elastic = Rand2DElastic(spacing=spacing, magnitude_range=magnitude_range, prob=prob, + rotate_range=rotate_range, shear_range=shear_range, + translate_range=translate_range, scale_range=scale_range, + spatial_size=spatial_size, + mode=default_mode, padding_mode=padding_mode, + as_tensor_output=as_tensor_output, device=device) + self.mode = mode + + def set_random_state(self, seed=None, state=None): + self.rand_2d_elastic.set_random_state(seed, state) + Randomizable.set_random_state(self, seed, state) + return self + + def randomize(self, spatial_size): + self.rand_2d_elastic.randomize(spatial_size) + + def __call__(self, data): + d = dict(data) + spatial_size = self.rand_2d_elastic.spatial_size + self.randomize(spatial_size) + + if self.rand_2d_elastic.do_transform: + grid = self.rand_2d_elastic.deform_grid(spatial_size) + grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) + grid = torch.nn.functional.interpolate(grid[None], spatial_size, mode='bicubic', align_corners=False)[0] + else: + grid = create_grid(spatial_size) + + if isinstance(self.mode, (tuple, list)): + for key, m in zip(self.keys, self.mode): + d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=m) + return d + + for key in self.keys: # same interpolation mode + d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=self.rand_2d_elastic.mode) + return d + + +@export +@alias('Rand3DElasticD', 'Rand3DElasticDict') +class Rand3DElasticd(Randomizable, MapTransform): + """ + A dictionary-based wrapper of ``monai.transforms.transforms.Rand3DElastic``. + """ + + def __init__(self, keys, + spatial_size, sigma_range, magnitude_range, prob=0.1, + rotate_range=None, shear_range=None, translate_range=None, scale_range=None, + mode='bilinear', padding_mode='zeros', as_tensor_output=False, device=None): + """ + Args: + keys (Hashable items): keys of the corresponding items to be transformed. + spatial_size (3 ints): specifying output image spatial size [h, w, d]. + sigma_range (2 ints): a Gaussian kernel with standard deviation sampled + from ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid. + magnitude_range (2 ints): the random offsets on the grid will be generated from + ``uniform[magnitude[0], magnitude[1])``. + prob (float): probability of returning a randomized affine grid. + defaults to 0.1, with 10% chance returns a randomized grid, + otherwise returns a ``spatial_size`` centered area extracted from the input image. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``. + if mode is a tuple of interpolation mode strings, each string corresponds to a key in ``keys``. + this is useful to set different modes for different data items. + padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. + Defaults to ``'zeros'``. + as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies + whether to convert it back to numpy arrays. + device (torch.device): device on which the tensor will be allocated. + + See also: + - ``RandAffineGrid`` for the random affine paramters configurations. + - ``Affine`` for the affine transformation parameters configurations. + """ + MapTransform.__init__(self, keys) + default_mode = 'bilinear' if isinstance(mode, (tuple, list)) else mode + self.rand_3d_elastic = Rand3DElastic(sigma_range=sigma_range, magnitude_range=magnitude_range, prob=prob, + rotate_range=rotate_range, shear_range=shear_range, + translate_range=translate_range, scale_range=scale_range, + spatial_size=spatial_size, + mode=default_mode, padding_mode=padding_mode, + as_tensor_output=as_tensor_output, device=device) + self.mode = mode + + def set_random_state(self, seed=None, state=None): + self.rand_3d_elastic.set_random_state(seed, state) + Randomizable.set_random_state(self, seed, state) + return self + + def randomize(self, grid_size): + self.rand_3d_elastic.randomize(grid_size) + + def __call__(self, data): + d = dict(data) + spatial_size = self.rand_3d_elastic.spatial_size + self.randomize(spatial_size) + grid = create_grid(spatial_size) + if self.rand_3d_elastic.do_transform: + device = self.rand_3d_elastic.device + grid = torch.tensor(grid).to(device) + gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3., device=device) + grid[:3] += gaussian(self.rand_3d_elastic.rand_offset[None])[0] * self.rand_3d_elastic.magnitude + grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + + if isinstance(self.mode, (tuple, list)): + for key, m in zip(self.keys, self.mode): + d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=m) + return d + + for key in self.keys: # same interpolation mode + d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode) + return d diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index a2352e5db8..370a7fb305 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -268,15 +268,14 @@ def __call__(self, img): @export class Flip: """Reverses the order of elements along the given axis. Preserves shape. - Uses np.flip in practice. See numpy.flip for additional details. + Uses ``np.flip`` in practice. See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: - axes (None, int or tuple of ints): Axes along which to flip over. Default is None. + axis (None, int or tuple of ints): Axes along which to flip over. Default is None. """ def __init__(self, axis=None): - assert axis is None or isinstance(axis, (int, list, tuple)), \ - "axis must be None, int or tuple of ints." self.axis = axis def __call__(self, img): @@ -638,19 +637,19 @@ class RandRotate(Randomizable): Args: prob (float): Probability of rotation. degrees (tuple of float or float): Range of rotation in degrees. If single number, - angle is picked from (-degrees, degrees). + angle is picked from (-degrees, degrees). axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two axis in spatial dimensions according to MONAI channel first shape assumption. reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. - mode (str): Points outside boundary filled according to this mode. Options are + mode (str): Points outside boundary filled according to this mode. Options are 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. cval (scalar): Value to fill outside boundary. Default: 0. prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.prob = prob self.degrees = degrees @@ -682,17 +681,19 @@ def __call__(self, img): @export -class RandomFlip(Randomizable): - """Randomly flips the image along axes. +class RandFlip(Randomizable): + """Randomly flips the image along axes. Preserves shape. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html Args: prob (float): Probability of flipping. - axes (None, int or tuple of ints): Axes along which to flip over. Default is None. + axis (None, int or tuple of ints): Axes along which to flip over. Default is None. """ def __init__(self, prob=0.1, axis=None): - self.axis = axis self.prob = prob + self.flipper = Flip(axis=axis) self._do_transform = False @@ -703,8 +704,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - flipper = Flip(axis=self.axis) - return flipper(img) + return self.flipper(img) @export @@ -802,8 +802,7 @@ def __call__(self, spatial_size=None, grid=None): affine = affine @ create_scale(spatial_dims, self.scale_params) affine = torch.tensor(affine, device=self.device) - if not torch.is_tensor(grid): - grid = torch.tensor(grid) + grid = torch.tensor(grid) if not torch.is_tensor(grid) else grid.clone().detach() if self.device: grid = grid.to(self.device) grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) @@ -874,8 +873,9 @@ def __call__(self, spatial_size=None, grid=None): a 2D (3xHxW) or 3D (4xHxWxD) grid. """ self.randomize() - affine_grid = AffineGrid(self.rotate_params, self.shear_params, self.translate_params, self.scale_params, - self.as_tensor_output, self.device) + affine_grid = AffineGrid(rotate_params=self.rotate_params, shear_params=self.shear_params, + translate_params=self.translate_params, scale_params=self.scale_params, + as_tensor_output=self.as_tensor_output, device=self.device) return affine_grid(spatial_size, grid) @@ -943,8 +943,7 @@ def __call__(self, img, grid, mode='bilinear'): """ if not torch.is_tensor(img): img = torch.tensor(img) - if not torch.is_tensor(grid): - grid = torch.tensor(grid) + grid = torch.tensor(grid) if not torch.is_tensor(grid) else grid.clone().detach() if self.device: img = img.to(self.device) grid = grid.to(self.device) @@ -1002,13 +1001,13 @@ def __init__(self, whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. """ - self.affine_grid = AffineGrid(rotate_params, - shear_params, - translate_params, - scale_params, + self.affine_grid = AffineGrid(rotate_params=rotate_params, + shear_params=shear_params, + translate_params=translate_params, + scale_params=scale_params, as_tensor_output=True, device=device) - self.resampler = Resample(padding_mode, as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(padding_mode=padding_mode, as_tensor_output=as_tensor_output, device=device) self.spatial_size = spatial_size self.mode = mode @@ -1023,8 +1022,8 @@ def __call__(self, img, spatial_size=None, mode=None): """ spatial_size = spatial_size or self.spatial_size mode = mode or self.mode - grid = self.affine_grid(spatial_size) - return self.resampler(img, grid, mode) + grid = self.affine_grid(spatial_size=spatial_size) + return self.resampler(img=img, grid=grid, mode=mode) @export @@ -1062,7 +1061,9 @@ def __init__(self, Affine for the affine transformation parameters configurations. """ - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) + self.rand_affine_grid = RandAffineGrid(rotate_range=rotate_range, shear_range=shear_range, + translate_range=translate_range, scale_range=scale_range, + as_tensor_output=True, device=device) self.resampler = Resample(padding_mode=padding_mode, as_tensor_output=as_tensor_output, device=device) self.spatial_size = spatial_size @@ -1078,6 +1079,7 @@ def set_random_state(self, seed=None, state=None): def randomize(self): self.do_transform = self.R.rand() < self.prob + self.rand_affine_grid.randomize() def __call__(self, img, spatial_size=None, mode=None): """ @@ -1095,7 +1097,7 @@ def __call__(self, img, spatial_size=None, mode=None): grid = self.rand_affine_grid(spatial_size=spatial_size) else: grid = create_grid(spatial_size) - return self.resampler(img, grid, mode) + return self.resampler(img=img, grid=grid, mode=mode) @export @@ -1121,13 +1123,14 @@ def __init__(self, Args: spacing (2 ints): distance in between the control points. magnitude_range (2 ints): the random offsets will be generated from - `uniform[magnitude[0], magnitude[1])`. + ``uniform[magnitude[0], magnitude[1])``. prob (float): probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid, - otherwise returns a `spatial_size` centered area centered extracted from the input image. + otherwise returns a ``spatial_size`` centered area extracted from the input image. spatial_size (2 ints): specifying output image spatial size [h, w]. - mode ('nearest'|'bilinear'): interpolation order. Defaults to 'bilinear'. - padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. Defaults to 'zeros'. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``. + padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. + Defaults to ``'zeros'``. as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. @@ -1136,8 +1139,11 @@ def __init__(self, RandAffineGrid for the random affine paramters configurations. Affine for the affine transformation parameters configurations. """ - self.deform_grid = RandDeformGrid(spacing, magnitude_range, as_tensor_output=True, device=device) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) + self.deform_grid = RandDeformGrid(spacing=spacing, magnitude_range=magnitude_range, + as_tensor_output=True, device=device) + self.rand_affine_grid = RandAffineGrid(rotate_range=rotate_range, shear_range=shear_range, + translate_range=translate_range, scale_range=scale_range, + as_tensor_output=True, device=device) self.resampler = Resample(padding_mode=padding_mode, as_tensor_output=as_tensor_output, device=device) self.spatial_size = spatial_size @@ -1151,21 +1157,23 @@ def set_random_state(self, seed=None, state=None): Randomizable.set_random_state(self, seed, state) return self - def randomize(self): + def randomize(self, spatial_size): self.do_transform = self.R.rand() < self.prob + self.deform_grid.randomize(spatial_size) + self.rand_affine_grid.randomize() def __call__(self, img, spatial_size=None, mode=None): """ Args: img (ndarray or tensor): shape must be (num_channels, H, W), spatial_size (2 ints): specifying output image spatial size [h, w]. - mode ('nearest'|'bilinear'): interpolation order. Defaults to 'self.mode'. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``self.mode``. """ - self.randomize() spatial_size = spatial_size or self.spatial_size + self.randomize(spatial_size) mode = mode or self.mode if self.do_transform: - grid = self.deform_grid(spatial_size) + grid = self.deform_grid(spatial_size=spatial_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate(grid[None], spatial_size, mode='bicubic', align_corners=False)[0] else: @@ -1195,15 +1203,16 @@ def __init__(self, """ Args: sigma_range (2 ints): a Gaussian kernel with standard deviation sampled - from `uniform[sigma_range[0], sigma_range[1])` will be used to smooth the random offset grid. + from ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid. magnitude_range (2 ints): the random offsets on the grid will be generated from - `uniform[magnitude[0], magnitude[1])`. + ``uniform[magnitude[0], magnitude[1])``. prob (float): probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid, - otherwise returns a `spatial_size` centered area centered extracted from the input image. - spatial_size (2 ints): specifying output image spatial size [h, w, d]. - mode ('nearest'|'bilinear'): interpolation order. Defaults to 'bilinear'. - padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. Defaults to 'zeros'. + otherwise returns a ``spatial_size`` centered area extracted from the input image. + spatial_size (3 ints): specifying output image spatial size [h, w, d]. + mode ('nearest'|'bilinear'): interpolation order. Defaults to ``'bilinear'``. + padding_mode ('zeros'|'border'|'reflection'): mode of handling out of range indices. + Defaults to ``'zeros'``. as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. @@ -1238,12 +1247,13 @@ def randomize(self, grid_size): self.rand_offset = self.R.uniform(-1., 1., [3] + list(grid_size)) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) + self.rand_affine_grid.randomize() def __call__(self, img, spatial_size=None, mode=None): """ Args: img (ndarray or tensor): shape must be (num_channels, H, W, D), - spatial_size (2 ints): specifying output image spatial size [h, w, d]. + spatial_size (3 ints): specifying spatial 3D output image spatial size [h, w, d]. mode ('nearest'|'bilinear'): interpolation order. Defaults to 'self.mode'. """ spatial_size = spatial_size or self.spatial_size diff --git a/tests/test_flip.py b/tests/test_flip.py index a70b9c92c5..3b027ec2c8 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -22,7 +22,7 @@ class FlipTest(NumpyImageTestCase2D): @parameterized.expand([ ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', AssertionError) + ("not_numbers", 's', TypeError) ]) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): diff --git a/tests/test_random_affine.py b/tests/test_random_affine.py index 5149a5a80d..60c436cc6d 100644 --- a/tests/test_random_affine.py +++ b/tests/test_random_affine.py @@ -34,7 +34,7 @@ as_tensor_output=True, spatial_size=(2, 2, 2), device=None), {'img': torch.ones((1, 3, 3, 3)), 'mode': 'bilinear'}, - torch.tensor([[[[1.0000, 0.7776], [0.4174, 0.0780]], [[0.0835, 1.0000], [0.3026, 0.5732]]]],) + torch.tensor([[[[0.0000, 0.6577], [0.9911, 1.0000]], [[0.7781, 1.0000], [1.0000, 0.4000]]]]) ], [ dict(prob=0.9, @@ -44,7 +44,7 @@ scale_range=[.1, .2], as_tensor_output=True, device=None), {'img': torch.arange(64).reshape((1, 8, 8)), 'spatial_size': (3, 3)}, - torch.tensor([[[27.3614, 18.0237, 8.6860], [40.0440, 30.7063, 21.3686], [52.7266, 43.3889, 34.0512]]]) + torch.tensor([[[16.9127, 13.3079, 9.7031], [26.8129, 23.2081, 19.6033], [36.7131, 33.1083, 29.5035]]]) ], ] diff --git a/tests/test_random_affined.py b/tests/test_random_affined.py new file mode 100644 index 0000000000..b07f7015e5 --- /dev/null +++ b/tests/test_random_affined.py @@ -0,0 +1,90 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.composables import RandAffined + +TEST_CASES = [ + [ + dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=('img', 'seg')), + {'img': torch.ones((3, 3, 3)), 'seg': torch.ones((3, 3, 3))}, + np.ones((3, 2, 2)) + ], + [ + dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=('img', 'seg')), + {'img': torch.ones((1, 3, 3, 3)), 'seg': torch.ones((1, 3, 3, 3))}, + torch.ones((1, 2, 2, 2)) + ], + [ + dict(prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + as_tensor_output=True, + spatial_size=(2, 2, 2), + device=None, + keys=('img', 'seg'), + mode='bilinear'), {'img': torch.ones((1, 3, 3, 3)), 'seg': torch.ones((1, 3, 3, 3))}, + torch.tensor([[[[0.0000, 0.6577], [0.9911, 1.0000]], [[0.7781, 1.0000], [1.0000, 0.4000]]]]) + ], + [ + dict(prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[.1, .2], + as_tensor_output=True, + spatial_size=(3, 3), + keys=('img', 'seg'), + device=None), {'img': torch.arange(64).reshape((1, 8, 8)), 'seg': torch.arange(64).reshape((1, 8, 8))}, + torch.tensor([[[16.9127, 13.3079, 9.7031], [26.8129, 23.2081, 19.6033], [36.7131, 33.1083, 29.5035]]]) + ], + [ + dict(prob=0.9, + mode=('bilinear', 'nearest'), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[.1, .2], + as_tensor_output=False, + spatial_size=(3, 3), + keys=('img', 'seg'), + device=torch.device('cpu:0')), + {'img': torch.arange(64).reshape((1, 8, 8)), 'seg': torch.arange(64).reshape((1, 8, 8))}, + {'img': np.array([[[16.9127, 13.3079, 9.7031], [26.8129, 23.2081, 19.6033], [36.7131, 33.1083, 29.5035]]]), + 'seg': np.array([[[19., 12., 12.], [27., 20., 21.], [35., 36., 29.]]])} + ], +] + + +class TestRandAffined(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_rand_affined(self, input_param, input_data, expected_val): + g = RandAffined(**input_param).set_random_state(123) + res = g(input_data) + for key in res: + result = res[key] + expected = expected_val[key] if isinstance(expected_val, dict) else expected_val + self.assertEqual(torch.is_tensor(result), torch.is_tensor(expected)) + if torch.is_tensor(result): + np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) + else: + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_elastic_2d.py b/tests/test_random_elastic_2d.py index 53f768bf36..d01fd5c556 100644 --- a/tests/test_random_elastic_2d.py +++ b/tests/test_random_elastic_2d.py @@ -24,8 +24,7 @@ [ {'spacing': (.3, .3), 'magnitude_range': (1., 2.), 'prob': 0.9, 'as_tensor_output': False, 'device': None}, {'img': torch.ones((3, 3, 3)), 'spatial_size': (2, 2), 'mode': 'bilinear'}, - np.array([[[0., 0.608901], [1., 0.5702355]], [[0., 0.608901], [1., 0.5702355]], [[0., 0.608901], - [1., 0.5702355]]]), + np.array([[[0., 0.], [0., 0.04970419]], [[0., 0.], [0., 0.04970419]], [[0., 0.], [0., 0.04970419]]]), ], [ { @@ -33,17 +32,17 @@ 'border', 'as_tensor_output': True, 'device': None, 'spatial_size': (2, 2) }, {'img': torch.arange(27).reshape((3, 3, 3))}, - torch.tensor([[[1.0849, 1.1180], [6.8100, 7.0265]], [[10.0849, 10.1180], [15.8100, 16.0265]], - [[19.0849, 19.1180], [24.8100, 25.0265]]]), + torch.tensor([[[1.6605, 1.0083], [6.0000, 6.2224]], [[10.6605, 10.0084], [15.0000, 15.2224]], + [[19.6605, 19.0083], [24.0000, 24.2224]]]), ], [ { - 'spacing': (.3, .3), 'magnitude_range': (1., 2.), 'translate_range': [-.2, .4], 'scale_range': [1.2, 2.2], - 'prob': 0.9, 'as_tensor_output': False, 'device': None + 'spacing': (.3, .3), 'magnitude_range': (.1, .2), 'translate_range': [-.01, .01], + 'scale_range': [0.01, 0.02], 'prob': 0.9, 'as_tensor_output': False, 'device': None, 'spatial_size': (2, 2), }, - {'img': torch.arange(27).reshape((3, 3, 3)), 'spatial_size': (2, 2)}, - np.array([[[0., 1.1731534], [3.8834658, 6.0565934]], [[0., 9.907095], [12.883466, 15.056594]], - [[0., 18.641037], [21.883465, 24.056593]]]), + {'img': torch.arange(27).reshape((3, 3, 3))}, + np.array([[[0.2001334, 1.2563337], [5.2274017, 7.90148]], [[8.675412, 6.9098353], [13.019891, 16.850012]], + [[17.15069, 12.563337], [20.81238, 25.798544]]]) ], ] diff --git a/tests/test_random_elastic_3d.py b/tests/test_random_elastic_3d.py index 5fb3a3130a..065d260de7 100644 --- a/tests/test_random_elastic_3d.py +++ b/tests/test_random_elastic_3d.py @@ -32,8 +32,7 @@ 'as_tensor_output': False, 'device': None, 'spatial_size': (2, 2, 2) }, {'img': torch.arange(27).reshape((1, 3, 3, 3)), 'mode': 'bilinear'}, - np.array([[[[6.016205, 2.3112855], [12.412318, 11.182229]], [[14.619441, 6.9230556], [17.23721, 16.506298]]]]), - ], + np.array([[[[1.6566806, 7.695548], [7.4342523, 13.580086]], [[11.776854, 18.669481], [18.396517, 21.551771]]]])], ] diff --git a/tests/test_random_elasticd_2d.py b/tests/test_random_elasticd_2d.py new file mode 100644 index 0000000000..1f560651ea --- /dev/null +++ b/tests/test_random_elasticd_2d.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.composables import Rand2DElasticd + +TEST_CASES = [ + [ + { + 'keys': ('img', 'seg'), 'spacing': (.3, .3), 'magnitude_range': (1., 2.), 'prob': 0.0, 'as_tensor_output': + False, 'device': None, 'spatial_size': (2, 2) + }, + {'img': torch.ones((3, 3, 3)), 'seg': torch.ones((3, 3, 3))}, + np.ones((3, 2, 2)), + ], + [ + { + 'keys': ('img', 'seg'), 'spacing': (.3, .3), 'magnitude_range': (1., 2.), 'prob': 0.9, 'as_tensor_output': + False, 'device': None, 'spatial_size': (2, 2), 'mode': 'bilinear' + }, + {'img': torch.ones((3, 3, 3)), 'seg': torch.ones((3, 3, 3))}, + np.array([[[0., 0.], [0., 0.04970419]], [[0., 0.], [0., 0.04970419]], [[0., 0.], [0., 0.04970419]]]), + ], + [ + { + 'keys': ('img', 'seg'), 'spacing': (1., 1.), 'magnitude_range': (1., 1.), 'scale_range': [1.2, 2.2], 'prob': + 0.9, 'padding_mode': 'border', 'as_tensor_output': True, 'device': None, 'spatial_size': (2, 2) + }, + {'img': torch.arange(27).reshape((3, 3, 3)), 'seg': torch.arange(27).reshape((3, 3, 3))}, + torch.tensor([[[1.6605, 1.0083], [6.0000, 6.2224]], [[10.6605, 10.0084], [15.0000, 15.2224]], + [[19.6605, 19.0083], [24.0000, 24.2224]]]), + ], + [ + { + 'keys': ('img', 'seg'), 'spacing': (.3, .3), 'magnitude_range': (.1, .2), 'translate_range': [-.01, .01], + 'scale_range': [0.01, 0.02], 'prob': 0.9, 'as_tensor_output': False, 'device': None, 'spatial_size': (2, 2), + }, + {'img': torch.arange(27).reshape((3, 3, 3)), 'seg': torch.arange(27).reshape((3, 3, 3))}, + np.array([[[0.2001334, 1.2563337], [5.2274017, 7.90148]], [[8.675412, 6.9098353], [13.019891, 16.850012]], + [[17.15069, 12.563337], [20.81238, 25.798544]]]) + ], + [ + { + 'keys': ('img', 'seg'), 'mode': ('bilinear', 'nearest'), 'spacing': (.3, .3), 'magnitude_range': (.1, .2), + 'translate_range': [-.01, .01], + 'scale_range': [0.01, 0.02], 'prob': 0.9, 'as_tensor_output': True, 'device': None, 'spatial_size': (2, 2), + }, + {'img': torch.arange(27).reshape((3, 3, 3)), 'seg': torch.arange(27).reshape((3, 3, 3))}, + {'img': torch.tensor([[[0.2001334, 1.2563337], [5.2274017, 7.90148]], + [[8.675412, 6.9098353], [13.019891, 16.850012]], + [[17.15069, 12.563337], [20.81238, 25.798544]]]), + 'seg': torch.tensor([[[0., 2.], [6., 8.]], [[9., 11.], [15., 17.]], [[18., 20.], [24., 26.]]])} + ], +] + + +class TestRand2DElasticd(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_rand_2d_elasticd(self, input_param, input_data, expected_val): + g = Rand2DElasticd(**input_param) + g.set_random_state(123) + res = g(input_data) + for key in res: + result = res[key] + expected = expected_val[key] if isinstance(expected_val, dict) else expected_val + self.assertEqual(torch.is_tensor(result), torch.is_tensor(expected)) + if torch.is_tensor(result): + np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) + else: + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_elasticd_3d.py b/tests/test_random_elasticd_3d.py new file mode 100644 index 0000000000..a72aa3bbb9 --- /dev/null +++ b/tests/test_random_elasticd_3d.py @@ -0,0 +1,72 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.composables import Rand3DElasticd + +TEST_CASES = [ + [{'keys': ('img', 'seg'), 'magnitude_range': (.3, 2.3), 'sigma_range': (1., 20.), + 'prob': 0.0, 'as_tensor_output': False, 'device': None, 'spatial_size': (2, 2, 2)}, + {'img': torch.ones((2, 3, 3, 3)), 'seg': torch.ones((2, 3, 3, 3))}, + np.ones((2, 2, 2, 2))], + [ + {'keys': ('img', 'seg'), 'magnitude_range': (.3, .3), 'sigma_range': (1., 2.), + 'prob': 0.9, 'as_tensor_output': False, 'device': None, 'spatial_size': (2, 2, 2)}, + {'img': torch.arange(27).reshape((1, 3, 3, 3)), 'seg': torch.arange(27).reshape((1, 3, 3, 3))}, + np.array([[[[3.2385552, 4.753422], [7.779232, 9.286472]], [[16.769115, 18.287868], [21.300673, 22.808704]]]]), + ], + [ + { + 'keys': ('img', 'seg'), 'magnitude_range': (.3, .3), 'sigma_range': (1., 2.), 'prob': 0.9, + 'rotate_range': [1, 1, 1], 'as_tensor_output': False, 'device': None, + 'spatial_size': (2, 2, 2), 'mode': 'bilinear' + }, + {'img': torch.arange(27).reshape((1, 3, 3, 3)), 'seg': torch.arange(27).reshape((1, 3, 3, 3))}, + np.array([[[[1.6566806, 7.695548], [7.4342523, 13.580086]], [[11.776854, 18.669481], [18.396517, 21.551771]]]]), + ], + [ + { + 'keys': ('img', 'seg'), 'mode': ('bilinear', 'nearest'), 'magnitude_range': (.3, .3), + 'sigma_range': (1., 2.), 'prob': 0.9, 'rotate_range': [1, 1, 1], + 'as_tensor_output': True, 'device': torch.device('cpu:0'), 'spatial_size': (2, 2, 2) + }, + {'img': torch.arange(27).reshape((1, 3, 3, 3)), 'seg': torch.arange(27).reshape((1, 3, 3, 3))}, + {'img': torch.tensor([[[[1.6566806, 7.695548], [7.4342523, 13.580086]], + [[11.776854, 18.669481], [18.396517, 21.551771]]]]), + 'seg': torch.tensor([[[[1., 11.], [7., 17.]], [[9., 19.], [15., 25.]]]])} + ], +] + + +class TestRand3DElasticd(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_rand_3d_elasticd(self, input_param, input_data, expected_val): + g = Rand3DElasticd(**input_param) + g.set_random_state(123) + res = g(input_data) + for key in res: + result = res[key] + expected = expected_val[key] if isinstance(expected_val, dict) else expected_val + self.assertEqual(torch.is_tensor(result), torch.is_tensor(expected)) + if torch.is_tensor(result): + np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) + else: + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_flip.py b/tests/test_random_flip.py index ec95485f20..ee89a133d9 100644 --- a/tests/test_random_flip.py +++ b/tests/test_random_flip.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandomFlip +from monai.transforms import RandFlip from tests.utils import NumpyImageTestCase2D @@ -22,11 +22,11 @@ class RandomFlipTest(NumpyImageTestCase2D): @parameterized.expand([ ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', AssertionError) + ("not_numbers", 's', TypeError) ]) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): - flip = RandomFlip(prob=1.0, axis=axis) + flip = RandFlip(prob=1.0, axis=axis) flip(self.imt) @parameterized.expand([ @@ -35,7 +35,7 @@ def test_invalid_inputs(self, _, axis, raises): ("many_axis", [0, 1, 2]) ]) def test_correct_results(self, _, axis): - flip = RandomFlip(prob=1.0, axis=axis) + flip = RandFlip(prob=1.0, axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) From 84e5c06a457e0a1e177994158051f4455cb5dbb1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 11 Mar 2020 22:30:28 +0800 Subject: [PATCH 2/2] [DLMED] change UNet to 1 output --- monai/networks/nets/unet.py | 3 +-- tests/integration_sliding_window.py | 2 +- tests/integration_unet2d.py | 2 +- tests/test_unet.py | 8 ++++---- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index b0d42612eb..6018cd06e8 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -13,7 +13,6 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.simplelayers import SkipConnection -from monai.networks.utils import predict_segmentation from monai.utils import export from monai.utils.aliases import alias @@ -98,4 +97,4 @@ def _get_up_layer(self, in_channels, out_channels, strides, is_top): def forward(self, x): x = self.model(x) - return x, predict_segmentation(x) + return x diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index db10d7cc49..8025a4c821 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -52,7 +52,7 @@ def _sliding_window_processor(_engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): - seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device) + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x), device) return predict_segmentation(seg_probs) infer_engine = Engine(_sliding_window_processor) diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index 1fd9074c66..ed1afa9872 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -46,7 +46,7 @@ def __len__(self): src = DataLoader(_TestBatch(), batch_size=batch_size) def loss_fn(pred, grnd): - return loss(pred[0], grnd) + return loss(pred, grnd) trainer = create_supervised_trainer(net, opt, loss_fn, device, False) diff --git a/tests/test_unet.py b/tests/test_unet.py index 98102375a6..d64407bb4b 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -26,7 +26,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 32), - (16, 32, 32), + (16, 3, 32, 32), ] TEST_CASE_2 = [ # single channel 3D, batch 16 @@ -39,7 +39,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 24, 48), - (16, 32, 24, 48), + (16, 3, 32, 24, 48), ] TEST_CASE_3 = [ # 4-channel 3D, batch 16 @@ -52,7 +52,7 @@ 'num_res_units': 1, }, torch.randn(16, 4, 32, 64, 48), - (16, 32, 64, 48), + (16, 3, 32, 64, 48), ] @@ -63,7 +63,7 @@ def test_shape(self, input_param, input_data, expected_shape): net = UNet(**input_param) net.eval() with torch.no_grad(): - result = net.forward(input_data)[1] + result = net.forward(input_data) self.assertEqual(result.shape, expected_shape)