diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 4d05d6fc9a..0e404c5233 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -85,7 +85,6 @@ def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, o interp_order = ensure_tuple(interp_order) self.interp_order = interp_order \ if len(interp_order) == len(self.keys) else interp_order * len(self.keys) - print(self.interp_order) self.output_key = output_key def __call__(self, data): diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index e7ec89af5a..a2352e5db8 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -631,6 +631,56 @@ def __call__(self, img): return data +@export +class RandRotate(Randomizable): + """Randomly rotates the input arrays. + + Args: + prob (float): Probability of rotation. + degrees (tuple of float or float): Range of rotation in degrees. If single number, + angle is picked from (-degrees, degrees). + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Value to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. + """ + + def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + self.prob = prob + self.degrees = degrees + self.reshape = reshape + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.axes = axes + + if not hasattr(self.degrees, '__iter__'): + self.degrees = (-self.degrees, self.degrees) + assert len(self.degrees) == 2, "degrees should be a number or pair of numbers." + + self._do_transform = False + self.angle = None + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1]) + + def __call__(self, img): + self.randomize() + if not self._do_transform: + return img + rotator = Rotate(self.angle, self.axes, self.reshape, self.order, + self.mode, self.cval, self.prefilter) + return rotator(img) + + @export class RandomFlip(Randomizable): """Randomly flips the image along axes. diff --git a/tests/test_random_rotate.py b/tests/test_random_rotate.py new file mode 100644 index 0000000000..29036663af --- /dev/null +++ b/tests/test_random_rotate.py @@ -0,0 +1,43 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import scipy.ndimage +from parameterized import parameterized + +from monai.transforms import RandRotate +from tests.utils import NumpyImageTestCase2D + + +class RandomRotateTest(NumpyImageTestCase2D): + + @parameterized.expand([ + (90, (1, 2), True, 1, 'reflect', 0, True), + ((-45, 45), (2, 1), True, 3, 'constant', 0, True), + (180, (2, 3), False, 2, 'constant', 4, False), + ]) + def test_correct_results(self, degrees, axes, reshape, + order, mode, cval, prefilter): + rotate_fn = RandRotate(degrees, prob=1.0, axes=axes, reshape=reshape, + order=order, mode=mode, cval=cval, prefilter=prefilter) + rotate_fn.set_random_state(243) + rotated = rotate_fn(self.imt) + + angle = rotate_fn.angle + expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, rotated)) + + +if __name__ == '__main__': + unittest.main()