diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d1083a641b..b61da87551 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -977,6 +977,12 @@ Intensity (Dict) :members: :special-members: __call__ +`RandCoarseShuffled` +"""""""""""""""""""" +.. autoclass:: RandCoarseShuffled + :members: + :special-members: __call__ + `HistogramNormalized` """"""""""""""""""""" .. autoclass:: HistogramNormalized diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b9ba303ed7..e4ec38f82d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -144,6 +144,9 @@ RandCoarseDropoutd, RandCoarseDropoutD, RandCoarseDropoutDict, + RandCoarseShuffled, + RandCoarseShuffleD, + RandCoarseShuffleDict, RandGaussianNoised, RandGaussianNoiseD, RandGaussianNoiseDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d10c3017a3..f6d4dfff5a 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1786,6 +1786,21 @@ class RandCoarseShuffle(RandCoarseTransform): Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). https://arxiv.org/abs/1707.07103 + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + """ def _transform_holes(self, img: np.ndarray): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index bc53fb6b7b..ca24980359 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -34,6 +34,7 @@ NormalizeIntensity, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, RandGaussianNoise, RandKSpaceSpikeNoise, RandRicianNoise, @@ -75,6 +76,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "RandCoarseShuffled", "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", @@ -126,6 +128,8 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "RandCoarseShuffleD", + "RandCoarseShuffleDict", "HistogramNormalizeD", "HistogramNormalizeDict", ] @@ -1478,6 +1482,13 @@ def __init__( prob=prob, ) + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseDropoutd": + self.dropper.set_random_state(seed, state) + super().set_random_state(seed, state) + return self + def randomize(self, img_size: Sequence[int]) -> None: self.dropper.randomize(img_size=img_size) @@ -1492,6 +1503,72 @@ def __call__(self, data): return d +class RandCoarseShuffled(Randomizable, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`. + Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions + for every key, if want to shuffle different regions for every key, please use this transform separately. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + holes: int, + spatial_size: Union[Sequence[int], int], + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.shuffle = RandCoarseShuffle( + holes=holes, + spatial_size=spatial_size, + max_holes=max_holes, + max_spatial_size=max_spatial_size, + prob=prob, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandCoarseShuffled": + self.shuffle.set_random_state(seed, state) + super().set_random_state(seed, state) + return self + + def randomize(self, img_size: Sequence[int]) -> None: + self.shuffle.randomize(img_size=img_size) + + def __call__(self, data): + d = dict(data) + # expect all the specified keys have same spatial shape + self.randomize(d[self.keys[0]].shape[1:]) + if self.shuffle._do_transform: + for key in self.key_iterator(d): + d[key] = self.shuffle(img=d[key]) + + return d + + class HistogramNormalized(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`. @@ -1562,3 +1639,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized +RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py index 97d492fd24..0b8cdc6cf8 100644 --- a/tests/test_rand_coarse_shuffle.py +++ b/tests/test_rand_coarse_shuffle.py @@ -45,7 +45,7 @@ class TestRandCoarseShuffle(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_local_patch_shuffle(self, input_param, input_data, expected_val): + def test_shuffle(self, input_param, input_data, expected_val): g = RandCoarseShuffle(**input_param) g.set_random_state(seed=12) result = g(**input_data) diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py new file mode 100644 index 0000000000..d2845fdaae --- /dev/null +++ b/tests/test_rand_coarse_shuffled.py @@ -0,0 +1,56 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RandCoarseShuffled + +TEST_CASES = [ + [ + {"keys": "img", "holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"keys": "img", "holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + np.asarray( + [ + [ + [[13, 17, 5], [6, 16, 25], [12, 15, 22]], + [[24, 7, 3], [9, 2, 23], [0, 4, 26]], + [[19, 11, 14], [1, 20, 8], [18, 10, 21]], + ] + ] + ), + ], + [ + {"keys": "img", "holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[7, 2], [1, 4]], [[5, 0], [3, 6]]], [[[8, 13], [10, 15]], [[14, 12], [11, 9]]]]), + ], +] + + +class TestRandCoarseShuffled(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shuffle(self, input_param, input_data, expected_val): + g = RandCoarseShuffled(**input_param) + g.set_random_state(seed=12) + result = g(input_data) + np.testing.assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main()