diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 602baf48d1..db4eafb9ca 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -62,6 +62,24 @@ def __call__(self, img): return rescale_array(img, self.minv, self.maxv, self.dtype) +@export +class Flip: + """Reverses the order of elements along the given axis. Preserves shape. + Uses np.flip in practice. See numpy.flip for additional details. + + Args: + axes (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, axis=None): + assert axis is None or isinstance(axis, (int, list, tuple)), \ + "axis must be None, int or tuple of ints." + self.axis = axis + + def __call__(self, img): + return np.flip(img, self.axis) + + @export class ToTensor: """ diff --git a/tests/test_flip.py b/tests/test_flip.py new file mode 100644 index 0000000000..a70b9c92c5 --- /dev/null +++ b/tests/test_flip.py @@ -0,0 +1,44 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import Flip +from tests.utils import NumpyImageTestCase2D + + +class FlipTest(NumpyImageTestCase2D): + + @parameterized.expand([ + ("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', AssertionError) + ]) + def test_invalid_inputs(self, _, axis, raises): + with self.assertRaises(raises): + flip = Flip(axis) + flip(self.imt) + + @parameterized.expand([ + ("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2]) + ]) + def test_correct_results(self, _, axis): + flip = Flip(axis=axis) + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, flip(self.imt))) + + +if __name__ == '__main__': + unittest.main()