diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index fb1bf468d6..0cef9dd783 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -441,10 +441,52 @@ def __call__(self, img): return data -# if __name__ == "__main__": -# img = np.array((1, 2, 3, 4)).reshape((1, 2, 2)) -# rotator = RandRotate90(prob=0.0, max_k=3, axes=(1, 2)) -# # rotator.set_random_state(1234) -# img_result = rotator(img) -# print(type(img)) -# print(img_result) +@export +class RandZoom(Randomizable): + """Randomly zooms input arrays with given probability within given zoom range. + + Args: + prob (float): Probability of zooming. + min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image. + max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image. + order (int): order of interpolation. Default=3. + mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is + extended beyond boundaries. Default: 'constant'. + cval (scalar, optional): Value to fill past edges. Default is 0. + use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes + 'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found. + keep_size (bool): Should keep original size (pad if needed). + """ + + def __init__(self, prob=0.1, min_zoom=0.9, max_zoom=1.1, order=3, + mode='constant', cval=0, prefilter=True, + use_gpu=False, keep_size=False): + if hasattr(min_zoom, '__iter__') and \ + hasattr(max_zoom, '__iter__'): + assert len(min_zoom) == len(max_zoom), "min_zoom and max_zoom must have same length." + self.min_zoom = min_zoom + self.max_zoom = max_zoom + self.prob = prob + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.use_gpu = use_gpu + self.keep_size = keep_size + + self._do_transform = False + self._zoom = None + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + if hasattr(self.min_zoom, '__iter__'): + self._zoom = (self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)) + else: + self._zoom = self.R.uniform(self.min_zoom, self.max_zoom) + + def __call__(self, img): + self.randomize() + if not self._do_transform: + return img + zoomer = Zoom(self._zoom, self.order, self.mode, self.cval, self.prefilter, self.use_gpu, self.keep_size) + return zoomer(img) diff --git a/tests/test_random_zoom.py b/tests/test_random_zoom.py new file mode 100644 index 0000000000..d193a16dd2 --- /dev/null +++ b/tests/test_random_zoom.py @@ -0,0 +1,76 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import importlib + +from scipy.ndimage import zoom as zoom_scipy +from parameterized import parameterized + +from monai.transforms import RandZoom +from tests.utils import NumpyImageTestCase2D + + +class ZoomTest(NumpyImageTestCase2D): + + @parameterized.expand([ + (0.9, 1.1, 3, 'constant', 0, True, False, False), + ]) + def test_correct_results(self, min_zoom, max_zoom, order, mode, + cval, prefilter, use_gpu, keep_size): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, + mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, + keep_size=keep_size) + random_zoom.set_random_state(234) + + zoomed = random_zoom(self.imt) + expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, + order=order, cval=cval, prefilter=prefilter) + + self.assertTrue(np.allclose(expected, zoomed)) + + @parameterized.expand([ + (0.8, 1.2, 1, 'constant', 0, True) + ]) + def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter): + if importlib.util.find_spec('cupy'): + random_zoom = RandZoom( + prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, + mode=mode, cval=cval, prefilter=prefilter, use_gpu=True, + keep_size=False) + random_zoom.set_random_state(234) + + zoomed = random_zoom(self.imt) + expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter) + + self.assertTrue(np.allclose(expected, zoomed)) + + def test_keep_size(self): + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, + max_zoom=0.7, keep_size=True) + zoomed = random_zoom(self.imt) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape)) + + @parameterized.expand([ + ("no_min_zoom", None, 1.1, 1, TypeError), + ("invalid_order", 0.9, 1.1 , 's', AssertionError) + ]) + def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises): + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order) + zoomed = random_zoom(self.imt) + + +if __name__ == '__main__': + unittest.main()