From f531ab86df58028da59b33be409684b2eb69cdd7 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Fri, 6 Mar 2020 15:34:12 -0800 Subject: [PATCH] Adding RandomFlip. --- monai/transforms/transforms.py | 31 ++++++++++++++++++------ tests/test_random_flip.py | 44 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 tests/test_random_flip.py diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index fb1bf468d6..f35a072c29 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -441,10 +441,27 @@ def __call__(self, img): return data -# if __name__ == "__main__": -# img = np.array((1, 2, 3, 4)).reshape((1, 2, 2)) -# rotator = RandRotate90(prob=0.0, max_k=3, axes=(1, 2)) -# # rotator.set_random_state(1234) -# img_result = rotator(img) -# print(type(img)) -# print(img_result) +@export +class RandomFlip(Randomizable): + """Randomly flips the image along axes. + + Args: + prob (float): Probability of flipping. + axes (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, prob=0.1, axis=None): + self.axis = axis + self.prob = prob + + self._do_transform = False + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + + def __call__(self, img): + self.randomize() + if not self._do_transform: + return img + flipper = Flip(axis=self.axis) + return flipper(img) diff --git a/tests/test_random_flip.py b/tests/test_random_flip.py new file mode 100644 index 0000000000..ec95485f20 --- /dev/null +++ b/tests/test_random_flip.py @@ -0,0 +1,44 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandomFlip +from tests.utils import NumpyImageTestCase2D + + +class RandomFlipTest(NumpyImageTestCase2D): + + @parameterized.expand([ + ("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', AssertionError) + ]) + def test_invalid_inputs(self, _, axis, raises): + with self.assertRaises(raises): + flip = RandomFlip(prob=1.0, axis=axis) + flip(self.imt) + + @parameterized.expand([ + ("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2]) + ]) + def test_correct_results(self, _, axis): + flip = RandomFlip(prob=1.0, axis=axis) + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, flip(self.imt))) + + +if __name__ == '__main__': + unittest.main()