diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd6a351d59..c36c96186c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,11 +41,11 @@ repos: # - id: isort # name: Format imports - - repo: https://github.com/psf/black - rev: 21.7b0 - hooks: - - id: black - name: Format code + # - repo: https://github.com/psf/black + # rev: 21.7b0 + # hooks: + # - id: black + # name: Format code #- repo: https://github.com/executablebooks/mdformat # rev: 0.7.8 @@ -56,8 +56,8 @@ repos: # - mdformat_frontmatter # exclude: CHANGELOG.md - - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 - hooks: - - id: flake8 - name: Check PEP8 + # - repo: https://github.com/PyCQA/flake8 + # rev: 3.9.2 + # hooks: + # - id: flake8 + # name: Check PEP8 diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index da8ceda0e3..b8f57e0dbe 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -320,6 +320,11 @@ Intensity :members: :special-members: __call__ +`LocalPatchShuffling` +""""""""""""""""""""" +.. autoclass:: LocalPatchShuffling + :members: + :special-members: __call__ IO ^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5267af4048..41b0872698 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -85,6 +85,7 @@ GibbsNoise, HistogramNormalize, KSpaceSpikeNoise, + LocalPatchShuffling, MaskIntensity, NormalizeIntensity, RandAdjustContrast, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a68f6a0a2e..ddf79bdbd6 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -13,6 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +import copy from collections.abc import Iterable from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple, Union @@ -70,6 +71,7 @@ "RandKSpaceSpikeNoise", "RandCoarseDropout", "HistogramNormalize", + "LocalPatchShuffling", ] @@ -1742,3 +1744,95 @@ def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.nda max=self.max, dtype=self.dtype, ) + + +class LocalPatchShuffling(RandomizableTransform): + """ + Takes a 3D image and based on input of the local patch size, shuffles the pixels of the local patch within it. + This process is repeated a for N number of times where every time a different random block is selected for local + pixel shuffling. + + Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). + """ + + def __init__( + self, + prob: float = 1.0, + number_blocks: int = 1000, + blocksize_ratio: int = 10, + channel_wise: bool = True, + device: Optional[torch.device] = None, + image_only: bool = False, + ) -> None: + """ + Args: + prob: The chance of this transform occuring on the given volume. + number_blocks: Total number of time a random 3D block will be selected for local shuffling of pixels/voxels + contained in the block. + blocksize_ratio: This ratio can be used to estimate the local 3D block sizes that will be selected. + channel_wise: If True, treats each channel of the image separately. + device: device on which the tensor will be allocated. + image_only: if True return only the image volume, otherwise return (image, affine). + """ + RandomizableTransform.__init__(self, prob) + self.prob = prob + self.number_blocks = number_blocks + self.blocksize_ratio = blocksize_ratio + self.channel_wise = channel_wise + + def _local_patch_shuffle(self, img: Union[torch.Tensor, np.ndarray], number_blocks: int, blocksize_ratio: int): + im_shape = img.shape + img_copy = copy.deepcopy(img) + for _each_block in range(number_blocks): + + block_size_x = self.R.randint(1, im_shape[0] // blocksize_ratio) + block_size_y = self.R.randint(1, im_shape[1] // blocksize_ratio) + block_size_z = self.R.randint(1, im_shape[2] // blocksize_ratio) + + noise_x = self.R.randint(0, im_shape[0] - block_size_x) + noise_y = self.R.randint(0, im_shape[1] - block_size_y) + noise_z = self.R.randint(0, im_shape[2] - block_size_z) + + local_patch = img[ + noise_x : noise_x + block_size_x, + noise_y : noise_y + block_size_y, + noise_z : noise_z + block_size_z, + ] + + local_patch = local_patch.flatten() + self.R.shuffle(local_patch) + local_patch = local_patch.reshape((block_size_x, block_size_y, block_size_z)) + + img_copy[ + noise_x : noise_x + block_size_x, noise_y : noise_y + block_size_y, noise_z : noise_z + block_size_z + ] = local_patch + + shuffled_image = img_copy + return shuffled_image + + def __call__( + self, + img: Union[np.ndarray, torch.Tensor], + # spatial_size: Optional[Union[Sequence[int], int]] = None, + # mode: Optional[Union[GridSampleMode, str]] = None, + # padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + ): + """ + Args: + img: shape must be (num_channels, H, W[, D]), + + """ + + super().randomize(None) + if not self._do_transform: + return img + + if self.channel_wise: + # img = self._local_patch_shuffle(img=img) + for i, _d in enumerate(img): + img[i] = self._local_patch_shuffle( + img=img[i], blocksize_ratio=self.blocksize_ratio, number_blocks=self.number_blocks + ) + else: + raise AssertionError("If channel_wise is False, the image needs to be set to channel first") + return img diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a7d93f88f3..c3bd4a3433 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -12,7 +12,6 @@ A collection of "vanilla" transforms for spatial operations https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ - import warnings from typing import Any, List, Optional, Sequence, Tuple, Union diff --git a/tests/test_rand_local_patch_shuffle.py b/tests/test_rand_local_patch_shuffle.py new file mode 100644 index 0000000000..8e2eefb5d1 --- /dev/null +++ b/tests/test_rand_local_patch_shuffle.py @@ -0,0 +1,49 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import LocalPatchShuffling + +TEST_CASES = [ + [ + {"number_blocks": 10, "blocksize_ratio": 1, "prob": 0.0}, + {"img": np.arange(8).reshape((1, 2, 2, 2))}, + np.arange(8).reshape((1, 2, 2, 2)), + ], + [ + {"number_blocks": 10, "blocksize_ratio": 1, "prob": 1.0}, + {"img": np.arange(27).reshape((1, 3, 3, 3))}, + [ + [ + [[9, 1, 2], [3, 4, 5], [6, 7, 8]], + [[0, 10, 11], [12, 4, 14], [15, 16, 17]], + [[18, 19, 20], [21, 22, 23], [24, 25, 26]], + ] + ], + ], +] + + +class TestLocalPatchShuffle(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_local_patch_shuffle(self, input_param, input_data, expected_val): + g = LocalPatchShuffling(**input_param) + g.set_random_state(seed=12) + result = g(**input_data) + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main()