diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 676e0274fe..a93c48984c 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -737,6 +737,13 @@ Spatial
     :members:
     :special-members: __call__

+`GridSplit`
+"""""""""""
+.. autoclass:: GridSplit
+    :members:
+    :special-members: __call__
+
+
 Smooth Field
 ^^^^^^^^^^^^

@@ -1506,6 +1513,13 @@ Spatial (Dict)
     :members:
     :special-members: __call__

+`GridSplitd`
+""""""""""""
+.. autoclass:: GridSplitd
+    :members:
+    :special-members: __call__
+
+
 `RandRotate90d`
 """""""""""""""
 .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 581e368ba0..c2385499b3 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -311,6 +311,7 @@
     AffineGrid,
     Flip,
     GridDistortion,
+    GridSplit,
     Orientation,
     Rand2DElastic,
     Rand3DElastic,
@@ -342,6 +343,9 @@
     GridDistortiond,
     GridDistortionD,
     GridDistortionDict,
+    GridSplitd,
+    GridSplitD,
+    GridSplitDict,
     Orientationd,
     OrientationD,
     OrientationDict,
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 37f1c3edc3..6b67762b95 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -18,6 +18,7 @@
 import numpy as np
 import torch
+from numpy.lib.stride_tricks import as_strided

 from monai.config import USE_COMPILED, DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
@@ -65,6 +66,7 @@
     "Orientation",
     "Flip",
     "GridDistortion",
+    "GridSplit",
     "Resize",
     "Rotate",
     "Zoom",
@@ -2462,3 +2464,91 @@ def __call__(
         if not self._do_transform:
             return img
         return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)
+
+
+class GridSplit(Transform):
+    """
+    Split the image into patches based on the provided grid in 2D.
+
+    Args:
+        grid: a tuple defining the shape of the grid upon which the image is split. Defaults to (2, 2).
+        size: a tuple or an integer that defines the output patch sizes.
+            If it's an integer, the value will be repeated for each dimension.
+            The default is None, where the patch size will be inferred from the grid shape.
+
+    Example:
+        Given an image (torch.Tensor or numpy.ndarray) with the size of (3, 10, 10) and a grid of (2, 2),
+        it will return a Tensor or array with the size of (4, 3, 5, 5).
+        Here, if the `size` is provided, the returned shape will be (4, 3, size, size).
+
+    Note: This transform currently supports only images with two spatial dimensions.
+ """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None): + # Grid size + self.grid = grid + + # Patch size + self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) + + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + if self.grid == (1, 1) and self.size is None: + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) # type: ignore + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + size, steps = self._get_params(image.shape[1:]) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, size[0], steps[0]) + .unfold(2, size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + patches = as_strided( + image, + shape=(*self.grid, n_channels, size[0], size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + return patches + + def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): + """ + Calculate the size and step required for splitting the image + Args: + The size of the input image + """ + if self.size is not None: + # Set the split size to the given default size + if any(self.size[i] > image_size[i] for i in range(len(self.grid))): + raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") + split_size = self.size + else: + # infer each sub-image size from the image size and the grid + split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + + steps = tuple( + (image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + for i in range(len(self.grid)) + ) + + return split_size, steps diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d42a11fd2f..47fe05700e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -34,6 +34,7 @@ AffineGrid, Flip, GridDistortion, + GridSplit, Orientation, Rand2DElastic, Rand3DElastic, @@ -129,6 +130,9 @@ "ZoomDict", "RandZoomD", "RandZoomDict", + "GridSplitd", + "GridSplitD", + "GridSplitDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -2149,6 +2153,40 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class GridSplitd(MapTransform): + """ + Split the image into patches based on the provided grid in 2D. + + Args: + keys: keys of the corresponding items to be transformed. + grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) + size: a tuple or an integer that defines the output patch sizes. + If it's an integer, the value will be repeated for each dimension. + The default is None, where the patch size will be inferred from the grid shape. + allow_missing_keys: don't raise exception if key is missing. 
+
+    Note: This transform currently supports only images with two spatial dimensions.
+    """
+
+    backend = GridSplit.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        grid: Tuple[int, int] = (2, 2),
+        size: Optional[Union[int, Tuple[int, int]]] = None,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
+        self.splitter = GridSplit(grid=grid, size=size)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.splitter(d[key])
+        return d
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2169,3 +2207,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
 RandRotateD = RandRotateDict = RandRotated
 ZoomD = ZoomDict = Zoomd
 RandZoomD = RandZoomDict = RandZoomd
+GridSplitD = GridSplitDict = GridSplitd
diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py
new file mode 100644
index 0000000000..6f0525029d
--- /dev/null
+++ b/tests/test_grid_split.py
@@ -0,0 +1,84 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import GridSplit
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+A11 = torch.randn(3, 2, 2)
+A12 = torch.randn(3, 2, 2)
+A21 = torch.randn(3, 2, 2)
+A22 = torch.randn(3, 2, 2)
+
+A1 = torch.cat([A11, A12], 2)
+A2 = torch.cat([A21, A22], 2)
+A = torch.cat([A1, A2], 1)
+
+TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])]
+TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])]
+TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])]
+TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])]
+TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])]
+TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])]
+TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])]
+TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])]
+TEST_CASE_8 = [
+    {"grid": (2, 2), "size": 2},
+    torch.arange(12).reshape(1, 3, 4).to(torch.float32),
+    torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),
+]
+
+TEST_SINGLE = []
+for p in TEST_NDARRAYS:
+    TEST_SINGLE.append([p, *TEST_CASE_0])
+    TEST_SINGLE.append([p, *TEST_CASE_1])
+    TEST_SINGLE.append([p, *TEST_CASE_2])
+    TEST_SINGLE.append([p, *TEST_CASE_3])
+    TEST_SINGLE.append([p, *TEST_CASE_4])
+    TEST_SINGLE.append([p, *TEST_CASE_5])
+    TEST_SINGLE.append([p, *TEST_CASE_6])
+    TEST_SINGLE.append([p, *TEST_CASE_7])
+    TEST_SINGLE.append([p, *TEST_CASE_8])
+
+TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]
+TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5]
+TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]]
+
+TEST_MULTIPLE = []
+for p in TEST_NDARRAYS:
+    TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])
+    TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])
+    TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])
+
+
+class TestGridSplit(unittest.TestCase):
+    @parameterized.expand(TEST_SINGLE)
+    def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
+        input_image = in_type(image)
+        splitter = GridSplit(**input_parameters)
+        output = splitter(input_image)
+        assert_allclose(output, expected, type_test=False)
+
+    @parameterized.expand(TEST_MULTIPLE)
+    def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
+        splitter = GridSplit(**input_parameters)
+        for image, expected in zip(img_list, expected_list):
+            input_image = in_type(image)
+            output = splitter(input_image)
+            assert_allclose(output, expected, type_test=False)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py
new file mode 100644
index 0000000000..f325a16946
--- /dev/null
+++ b/tests/test_grid_splitd.py
@@ -0,0 +1,100 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import GridSplitd
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+A11 = torch.randn(3, 2, 2)
+A12 = torch.randn(3, 2, 2)
+A21 = torch.randn(3, 2, 2)
+A22 = torch.randn(3, 2, 2)
+
+A1 = torch.cat([A11, A12], 2)
+A2 = torch.cat([A21, A22], 2)
+A = torch.cat([A1, A2], 1)
+
+TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])]
+TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])]
+TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])]
+TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])]
+TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])]
+TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])]
+TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])]
+TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])]
+TEST_CASE_8 = [
+    {"keys": "image", "grid": (2, 2), "size": 2},
+    {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)},
+    torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),
+]
+
+TEST_SINGLE = []
+for p in TEST_NDARRAYS:
+    TEST_SINGLE.append([p, *TEST_CASE_0])
+    TEST_SINGLE.append([p, *TEST_CASE_1])
+    TEST_SINGLE.append([p, *TEST_CASE_2])
+    TEST_SINGLE.append([p, *TEST_CASE_3])
+    TEST_SINGLE.append([p, *TEST_CASE_4])
+    TEST_SINGLE.append([p, *TEST_CASE_5])
+    TEST_SINGLE.append([p, *TEST_CASE_6])
+    TEST_SINGLE.append([p, *TEST_CASE_7])
+    TEST_SINGLE.append([p, *TEST_CASE_8])
+
+TEST_CASE_MC_0 = [
{"keys": "image", "grid": (2, 2)}, + [{"image": A}, {"image": A}], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] +TEST_CASE_MC_1 = [ + {"keys": "image", "grid": (2, 1)}, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, +] +TEST_CASE_MC_2 = [ + {"keys": "image", "grid": (1, 2)}, + [{"image": A1}, {"image": A2}], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + + +class TestGridSplitd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + splitter = GridSplitd(**input_parameters) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): + splitter = GridSplitd(**input_parameters) + for img_dict, expected in zip(img_list, expected_list): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main()