From 25fae0f5202d09229d8065ace6518a8818774edd Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Thu, 27 Feb 2020 17:51:35 -0800 Subject: [PATCH] Adding rotation transform. --- monai/transforms/transforms.py | 37 ++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_rotate.py | 41 ++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 tests/test_rotate.py diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index dc6f571106..454c3aa7f6 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -15,6 +15,7 @@ import numpy as np import torch +import scipy.ndimage import monai from monai.data.utils import get_random_patch, get_valid_patch_size @@ -80,6 +81,42 @@ def __call__(self, img): return np.flip(img, self.axis) +@export +class Rotate: + """ + Rotates an input image by given angle. Uses scipy.ndimage.rotate. For more details, see + http://lagrange.univ-lyon1.fr/docs/scipy/0.17.1/generated/scipy.ndimage.rotate.html. + + Args: + angle (float): Rotation angle in degrees. + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Values to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. + """ + + def __init__(self, angle, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + self.angle = angle + self.reshape = reshape + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.axes = axes + + def __call__(self, img): + return scipy.ndimage.rotate(img, self.angle, self.axes, + reshape=self.reshape, order=self.order, + mode=self.mode, cval=self.cval, + prefilter=self.prefilter) + + @export class ToTensor: """ diff --git a/requirements.txt b/requirements.txt index 8325300968..91985396c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ coverage nibabel parameterized tensorboard +scipy \ No newline at end of file diff --git a/tests/test_rotate.py b/tests/test_rotate.py new file mode 100644 index 0000000000..98e25f587f --- /dev/null +++ b/tests/test_rotate.py @@ -0,0 +1,41 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import scipy.ndimage +from parameterized import parameterized + +from monai.transforms import Rotate +from tests.utils import NumpyImageTestCase2D + + +class RotateTest(NumpyImageTestCase2D): + + @parameterized.expand([ + (90, (1, 2), True, 1, 'reflect', 0, True), + (-90, (2, 1), True, 3, 'constant', 0, True), + (180, (2, 3), False, 2, 'constant', 4, False), + ]) + def test_correct_results(self, angle, axes, reshape, + order, mode, cval, prefilter): + rotate_fn = Rotate(angle, axes, reshape, + order, mode, cval, prefilter) + rotated = rotate_fn(self.imt) + + expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, rotated)) + + +if __name__ == '__main__': + unittest.main()