diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 8ccd5a747a..c19c3f7df9 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -22,7 +22,8 @@ from monai.transforms.compose import Randomizable, Transform from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation, AddChannel, Spacing, Rotate90, SpatialCrop, - RandAffine, Rand2DElastic, Rand3DElastic) + RandAffine, Rand2DElastic, Rand3DElastic, + Flip, Rotate, Zoom) from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid from monai.utils.aliases import alias @@ -476,7 +477,6 @@ def __init__(self, keys, as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. - See also: - ``RandAffineGrid`` for the random affine paramters configurations. - ``Affine`` for the affine transformation parameters configurations. @@ -551,7 +551,6 @@ def __init__(self, keys, as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. - See also: - ``RandAffineGrid`` for the random affine paramters configurations. - ``Affine`` for the affine transformation parameters configurations. @@ -594,3 +593,228 @@ def __call__(self, data): for key in self.keys: # same interpolation mode d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode) return d + + +@export +@alias('FlipD', 'FlipDict') +class Flipd(MapTransform): + """Dictionary-based wrapper of Flip. + + Args: + keys (dict): Keys to pick data for transformation. + axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, keys, axis=None): + MapTransform.__init__(self, keys) + self.flipper = Flip(axis=axis) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.flipper(d[key]) + return d + + +@export +@alias('RandFlipD', 'RandFlipDict') +class RandFlipd(Randomizable, MapTransform): + """Dict-based wrapper of RandFlip. + + Args: + prob (float): Probability of flipping. + axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, keys, prob=0.1, axis=None): + MapTransform.__init__(self, keys) + self.axis = axis + self.prob = prob + + self._do_transform = False + self.flipper = Flip(axis=axis) + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + + def __call__(self, data): + self.randomize() + d = dict(data) + if not self._do_transform: + return d + for key in self.keys: + d[key] = self.flipper(d[key]) + return d + + +@export +@alias('RotateD', 'RotateDict') +class Rotated(MapTransform): + """Dictionary-based wrapper of Rotate. + + Args: + keys (dict): Keys to pick data for transformation. + angle (float): Rotation angle in degrees. + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Values to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. + """ + + def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + MapTransform.__init__(self, keys) + self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape, + order=order, mode=mode, cval=cval, prefilter=prefilter) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.rotator(d[key]) + return d + + +@export +@alias('RandRotateD', 'RandRotateDict') +class RandRotated(Randomizable, MapTransform): + """Randomly rotates the input arrays. + + Args: + prob (float): Probability of rotation. + degrees (tuple of float or float): Range of rotation in degrees. If single number, + angle is picked from (-degrees, degrees). + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Value to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. + """ + def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + MapTransform.__init__(self, keys) + self.prob = prob + self.degrees = degrees + self.reshape = reshape + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.axes = axes + + if not hasattr(self.degrees, '__iter__'): + self.degrees = (-self.degrees, self.degrees) + assert len(self.degrees) == 2, "degrees should be a number or pair of numbers." + + self._do_transform = False + self.angle = None + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1]) + + def __call__(self, data): + self.randomize() + d = dict(data) + if not self._do_transform: + return d + rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + self.mode, self.cval, self.prefilter) + for key in self.keys: + d[key] = self.flipper(d[key]) + return d + + +@export +@alias('ZoomD', 'ZoomDict') +class Zoomd(MapTransform): + """Dictionary-based wrapper of Zoom transform. + + Args: + zoom (float or sequence): The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + order (int): order of interpolation. Default=3. + mode (str): Determines how input is extended beyond boundaries. Default is 'constant'. + cval (scalar, optional): Value to fill past edges. Default is 0. + use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes + 'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found. + keep_size (bool): Should keep original size (pad if needed). + """ + + def __init__(self, keys, zoom, order=3, mode='constant', cval=0, + prefilter=True, use_gpu=False, keep_size=False): + MapTransform.__init__(self, keys) + self.zoomer = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, + prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.zoomer(d[key]) + return d + + +@export +@alias('RandZoomD', 'RandZoomDict') +class RandZoomd(Randomizable, MapTransform): + """Dict-based wrapper of RandZoom. + + Args: + keys (dict): Keys to pick data for transformation. + prob (float): Probability of zooming. + min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + order (int): order of interpolation. Default=3. + mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is + extended beyond boundaries. Default: 'constant'. + cval (scalar, optional): Value to fill past edges. Default is 0. + use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes + 'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found. + keep_size (bool): Should keep original size (pad if needed). + """ + + def __init__(self, keys, prob=0.1, min_zoom=0.9, + max_zoom=1.1, order=3, mode='constant', + cval=0, prefilter=True, use_gpu=False, keep_size=False): + MapTransform.__init__(self, keys) + if hasattr(min_zoom, '__iter__') and \ + hasattr(max_zoom, '__iter__'): + assert len(min_zoom) == len(max_zoom), "min_zoom and max_zoom must have same length." + self.min_zoom = min_zoom + self.max_zoom = max_zoom + self.prob = prob + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.use_gpu = use_gpu + self.keep_size = keep_size + + self._do_transform = False + self._zoom = None + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + if hasattr(self.min_zoom, '__iter__'): + self._zoom = (self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)) + else: + self._zoom = self.R.uniform(self.min_zoom, self.max_zoom) + + def __call__(self, data): + self.randomize() + d = dict(data) + if not self._do_transform: + return d + zoomer = Zoom(self._zoom, self.order, self.mode, self.cval, self.prefilter, self.use_gpu, self.keep_size) + for key in self.keys: + d[key] = zoomer(d[key]) + return d diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 370a7fb305..8f140972f6 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -434,7 +434,7 @@ def __call__(self, img): pad_vec[idx] = [half, diff - half] elif diff < 0: # need slicing slice_vec[idx] = slice(half, half + od) - zoomed = np.pad(zoomed, pad_vec) + zoomed = np.pad(zoomed, pad_vec, mode='constant') return zoomed[tuple(slice_vec)] @@ -696,6 +696,7 @@ def __init__(self, prob=0.1, axis=None): self.flipper = Flip(axis=axis) self._do_transform = False + self.flipper = Flip(axis=axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob diff --git a/tests/test_flip.py b/tests/test_flip.py index 3b027ec2c8..a261c315e2 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -14,31 +14,44 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Flip +from monai.transforms import Flip, Flipd from tests.utils import NumpyImageTestCase2D +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2])] + class FlipTest(NumpyImageTestCase2D): - @parameterized.expand([ - ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', TypeError) - ]) + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): flip = Flip(axis) flip(self.imt) - @parameterized.expand([ - ("no_axis", None), - ("one_axis", 1), - ("many_axis", [0, 1, 2]) - ]) + @parameterized.expand(INVALID_CASES) + def test_invalid_cases_dict(self, _, axis, raises): + with self.assertRaises(raises): + flip = Flipd(keys='img', axis=axis) + flip({'img': self.imt}) + + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, axis): flip = Flip(axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, _, axis): + flip = Flipd(keys='img', axis=axis) + expected = np.flip(self.imt, axis) + res = flip({'img': self.imt}) + assert np.allclose(expected, res['img']) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_random_flip.py b/tests/test_rand_flip.py similarity index 62% rename from tests/test_random_flip.py rename to tests/test_rand_flip.py index ee89a133d9..be03ff5a28 100644 --- a/tests/test_random_flip.py +++ b/tests/test_rand_flip.py @@ -14,31 +14,38 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandFlip +from monai.transforms import RandFlip, RandFlipd from tests.utils import NumpyImageTestCase2D +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] -class RandomFlipTest(NumpyImageTestCase2D): +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2])] - @parameterized.expand([ - ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', TypeError) - ]) +class RandFlipTest(NumpyImageTestCase2D): + + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): flip = RandFlip(prob=1.0, axis=axis) flip(self.imt) - @parameterized.expand([ - ("no_axis", None), - ("one_axis", 1), - ("many_axis", [0, 1, 2]) - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, axis): flip = RandFlip(prob=1.0, axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, _, axis): + flip = RandFlipd(keys='img', prob=1.0, axis=axis) + res = flip({'img': self.imt}) + + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, res['img'])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_random_rotate.py b/tests/test_rand_rotate.py similarity index 100% rename from tests/test_random_rotate.py rename to tests/test_rand_rotate.py diff --git a/tests/test_random_zoom.py b/tests/test_rand_zoom.py similarity index 76% rename from tests/test_random_zoom.py rename to tests/test_rand_zoom.py index d193a16dd2..530504b887 100644 --- a/tests/test_random_zoom.py +++ b/tests/test_rand_zoom.py @@ -17,15 +17,14 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import RandZoom +from monai.transforms import RandZoom, RandZoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)] class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (0.9, 1.1, 3, 'constant', 0, True, False, False), - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, order, mode, cval, prefilter, use_gpu, keep_size): random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, @@ -39,6 +38,21 @@ def test_correct_results(self, min_zoom, max_zoom, order, mode, self.assertTrue(np.allclose(expected, zoomed)) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, min_zoom, max_zoom, order, mode, + cval, prefilter, use_gpu, keep_size): + keys = 'img' + random_zoom = RandZoomd(keys, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, + mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, + keep_size=keep_size) + random_zoom.set_random_state(234) + + zoomed = random_zoom({keys: self.imt}) + expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, + order=order, cval=cval, prefilter=prefilter) + + self.assertTrue(np.allclose(expected, zoomed[keys])) + @parameterized.expand([ (0.8, 1.2, 1, 'constant', 0, True) ]) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 98e25f587f..0c34f5809e 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -15,17 +15,16 @@ import scipy.ndimage from parameterized import parameterized -from monai.transforms import Rotate +from monai.transforms import Rotate, Rotated from tests.utils import NumpyImageTestCase2D +TEST_CASES = [(90, (1, 2), True, 1, 'reflect', 0, True), + (-90, (2, 1), True, 3, 'constant', 0, True), + (180, (2, 3), False, 2, 'constant', 4, False)] class RotateTest(NumpyImageTestCase2D): - @parameterized.expand([ - (90, (1, 2), True, 1, 'reflect', 0, True), - (-90, (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False), - ]) + @parameterized.expand(TEST_CASES) def test_correct_results(self, angle, axes, reshape, order, mode, cval, prefilter): rotate_fn = Rotate(angle, axes, reshape, @@ -36,6 +35,18 @@ def test_correct_results(self, angle, axes, reshape, mode=mode, cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, rotated)) + @parameterized.expand(TEST_CASES) + def test_correct_results_dict(self, angle, axes, reshape, + order, mode, cval, prefilter): + key = 'img' + rotate_fn = Rotated(key, angle, axes, reshape, order, + mode, cval, prefilter) + rotated = rotate_fn({key: self.imt}) + + expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, rotated[key])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 874e587a98..83795542bc 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -17,17 +17,22 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import Zoom +from monai.transforms import Zoom, Zoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False), + (0.9, 3, 'constant', 0, True, False, False), + (0.8, 1, 'reflect', 0, False, False, False)] + +GPU_CASES = [("gpu_zoom", 0.6, 1, 'constant', 0, True)] + +INVALID_CASES = [("no_zoom", None, 1, TypeError), + ("invalid_order", 0.9, 's', AssertionError)] + class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (1.1, 3, 'constant', 0, True, False, False), - (0.9, 3, 'constant', 0, True, False, False), - (0.8, 1, 'reflect', 0, False, False, False) - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) @@ -36,9 +41,19 @@ def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, zoomed)) - @parameterized.expand([ - ("gpu_zoom", 0.6, 1, 'constant', 0, True) - ]) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): + key = 'img' + zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval, + prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) + zoomed = zoom_fn({key: self.imt[0]}) + + expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, zoomed[key])) + + + @parameterized.expand(GPU_CASES) def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter): if importlib.util.find_spec('cupy'): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, @@ -57,10 +72,7 @@ def test_keep_size(self): zoomed = zoom_fn(self.imt[0]) self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - @parameterized.expand([ - ("no_zoom", None, 1, TypeError), - ("invalid_order", 0.9, 's', AssertionError) - ]) + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, order, raises): with self.assertRaises(raises): zoom_fn = Zoom(zoom=zoom, order=order)