diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 0cef9dd783..1098c23fab 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -108,12 +108,12 @@ class Resize: Args: order (int): Order of spline interpolation. Default=1. - mode (str): Points outside boundaries are filled according to given mode. + mode (str): Points outside boundaries are filled according to given mode. Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. cval (float): Used with mode 'constant', the value outside image boundaries. clip (bool): Wheter to clip range of output values after interpolation. Default: True. preserve_range (bool): Whether to keep original range of values. Default is True. - If False, input is converted according to conventions of img_as_float. See + If False, input is converted according to conventions of img_as_float. See https://scikit-image.org/docs/dev/user_guide/data_types.html. anti_aliasing (bool): Whether to apply a gaussian filter to image before down-scaling. Default is True. @@ -121,7 +121,7 @@ class Resize: """ def __init__(self, output_shape, order=1, mode='reflect', cval=0, - clip=True, preserve_range=True, + clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None): assert isinstance(order, int), "order must be integer." self.output_shape = output_shape @@ -137,7 +137,7 @@ def __call__(self, img): return resize(img, self.output_shape, order=self.order, mode=self.mode, cval=self.cval, clip=self.clip, preserve_range=self.preserve_range, - anti_aliasing=self.anti_aliasing, + anti_aliasing=self.anti_aliasing, anti_aliasing_sigma=self.anti_aliasing_sigma) @@ -154,13 +154,13 @@ class Rotate: reshape (bool): If true, output shape is made same as input. Default: True. order (int): Order of spline interpolation. Range 0-5. Default: 1. This is different from scipy where default interpolation is 3. - mode (str): Points outside boundary filled according to this mode. Options are + mode (str): Points outside boundary filled according to this mode. Options are 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. cval (scalar): Values to fill outside boundary. Default: 0. prefiter (bool): Apply spline_filter before interpolation. Default: True. """ - def __init__(self, angle, axes=(1, 2), reshape=True, order=1, + def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', cval=0, prefilter=True): self.angle = angle self.reshape = reshape @@ -172,18 +172,18 @@ def __init__(self, angle, axes=(1, 2), reshape=True, order=1, def __call__(self, img): return scipy.ndimage.rotate(img, self.angle, self.axes, - reshape=self.reshape, order=self.order, - mode=self.mode, cval=self.cval, + reshape=self.reshape, order=self.order, + mode=self.mode, cval=self.cval, prefilter=self.prefilter) @export class Zoom: - """ Zooms a nd image. Uses scipy.ndimage.zoom or cupyx.scipy.ndimage.zoom in case of gpu. + """ Zooms a nd image. Uses scipy.ndimage.zoom or cupyx.scipy.ndimage.zoom in case of gpu. For details, please see https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.zoom.html. Args: - zoom (float or sequence): The zoom factor along the axes. If a float, zoom is the same for each axis. + zoom (float or sequence): The zoom factor along the axes. If a float, zoom is the same for each axis. If a sequence, zoom should contain one value for each axis. order (int): order of interpolation. Default=3. mode (str): Determines how input is extended beyond boundaries. Default is 'constant'. @@ -441,6 +441,32 @@ def __call__(self, img): return data +@export +class RandomFlip(Randomizable): + """Randomly flips the image along axes. + + Args: + prob (float): Probability of flipping. + axes (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, prob=0.1, axis=None): + self.axis = axis + self.prob = prob + + self._do_transform = False + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + + def __call__(self, img): + self.randomize() + if not self._do_transform: + return img + flipper = Flip(axis=self.axis) + return flipper(img) + + @export class RandZoom(Randomizable): """Randomly zooms input arrays with given probability within given zoom range. @@ -450,7 +476,7 @@ class RandZoom(Randomizable): min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. order (int): order of interpolation. Default=3. - mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is + mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is extended beyond boundaries. Default: 'constant'. cval (scalar, optional): Value to fill past edges. Default is 0. use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes diff --git a/tests/test_random_flip.py b/tests/test_random_flip.py new file mode 100644 index 0000000000..ec95485f20 --- /dev/null +++ b/tests/test_random_flip.py @@ -0,0 +1,44 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandomFlip +from tests.utils import NumpyImageTestCase2D + + +class RandomFlipTest(NumpyImageTestCase2D): + + @parameterized.expand([ + ("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', AssertionError) + ]) + def test_invalid_inputs(self, _, axis, raises): + with self.assertRaises(raises): + flip = RandomFlip(prob=1.0, axis=axis) + flip(self.imt) + + @parameterized.expand([ + ("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2]) + ]) + def test_correct_results(self, _, axis): + flip = RandomFlip(prob=1.0, axis=axis) + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, flip(self.imt))) + + +if __name__ == '__main__': + unittest.main()