diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a144c8c138..00d8cb9053 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -314,6 +314,12 @@ Spatial :members: :special-members: __call__ +`RandAxisFlip` +"""""""""""""" +.. autoclass:: RandAxisFlip + :members: + :special-members: __call__ + `RandZoom` """""""""" .. autoclass:: RandZoom @@ -791,6 +797,12 @@ Spatial (Dict) :members: :special-members: __call__ +`RandAxisFlipd` +""""""""""""""" +.. autoclass:: RandAxisFlipd + :members: + :special-members: __call__ + `Rotated` """"""""" .. autoclass:: Rotated diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8b30d76bec..cd5b195bd3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -181,6 +181,7 @@ Rand3DElastic, RandAffine, RandAffineGrid, + RandAxisFlip, RandDeformGrid, RandFlip, RandRotate, @@ -209,6 +210,9 @@ RandAffined, RandAffineD, RandAffineDict, + RandAxisFlipd, + RandAxisFlipD, + RandAxisFlipDict, RandFlipd, RandFlipD, RandFlipDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3559d0eb3c..2867361b8e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -59,6 +59,7 @@ "RandRotate90", "RandRotate", "RandFlip", + "RandAxisFlip", "RandZoom", "AffineGrid", "RandAffineGrid", @@ -771,6 +772,37 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return self.flipper(img) +class RandAxisFlip(RandomizableTransform): + """ + Randomly select a spatial axis and flip along it. + See numpy.flip for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + + Args: + prob: Probability of flipping. + + """ + + def __init__(self, prob: float = 0.1) -> None: + RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + self._axis: Optional[int] = None + + def randomize(self, data: np.ndarray) -> None: + super().randomize(None) + self._axis = self.R.randint(data.ndim - 1) + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + self.randomize(data=img) + if not self._do_transform: + return img + flipper = Flip(spatial_axis=self._axis) + return flipper(img) + + class RandZoom(RandomizableTransform): """ Randomly zooms input arrays with given probability within given zoom range. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6693d75bcd..f29258bf28 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -58,6 +58,7 @@ "Rand3DElasticd", "Flipd", "RandFlipd", + "RandAxisFlipd", "Rotated", "RandRotated", "Zoomd", @@ -82,6 +83,8 @@ "FlipDict", "RandFlipD", "RandFlipDict", + "RandAxisFlipD", + "RandAxisFlipDict", "RotateD", "RotateDict", "RandRotateD", @@ -751,6 +754,39 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class RandAxisFlipd(RandomizableTransform, MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. + + See `numpy.flip` for additional details. + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + + Args: + keys: Keys to pick data for transformation. + prob: Probability of flipping. + + """ + + def __init__(self, keys: KeysCollection, prob: float = 0.1) -> None: + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) + self._axis: Optional[int] = None + + def randomize(self, data: np.ndarray) -> None: + super().randomize(None) + self._axis = self.R.randint(data.ndim - 1) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + self.randomize(data=data[self.keys[0]]) + flipper = Flip(spatial_axis=self._axis) + + d = dict(data) + for key in self.keys: + if self._do_transform: + d[key] = flipper(d[key]) + return d + + class Rotated(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. @@ -1051,6 +1087,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda Rand3DElasticD = Rand3DElasticDict = Rand3DElasticd FlipD = FlipDict = Flipd RandFlipD = RandFlipDict = RandFlipd +RandAxisFlipD = RandAxisFlipDict = RandAxisFlipd RotateD = RotateDict = Rotated RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py new file mode 100644 index 0000000000..0bc2eb130e --- /dev/null +++ b/tests/test_rand_axis_flip.py @@ -0,0 +1,32 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms import RandAxisFlip +from tests.utils import NumpyImageTestCase2D + + +class TestRandAxisFlip(NumpyImageTestCase2D): + def test_correct_results(self): + flip = RandAxisFlip(prob=1.0) + result = flip(self.imt[0]) + + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, flip._axis)) + self.assertTrue(np.allclose(np.stack(expected), result)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py new file mode 100644 index 0000000000..154d7813cb --- /dev/null +++ b/tests/test_rand_axis_flipd.py @@ -0,0 +1,32 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms import RandAxisFlipd +from tests.utils import NumpyImageTestCase3D + + +class TestRandAxisFlip(NumpyImageTestCase3D): + def test_correct_results(self): + flip = RandAxisFlipd(keys="img", prob=1.0) + result = flip({"img": self.imt[0]}) + + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, flip._axis)) + self.assertTrue(np.allclose(np.stack(expected), result["img"])) + + +if __name__ == "__main__": + unittest.main()