14 changes: 14 additions & 0 deletions docs/source/transforms.rst
@@ -737,6 +737,13 @@ Spatial
:members:
:special-members: __call__

`GridSplit`
"""""""""""
.. autoclass:: GridSplit
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^

@@ -1506,6 +1513,13 @@ Spatial (Dict)
:members:
:special-members: __call__

`GridSplitd`
""""""""""""
.. autoclass:: GridSplitd
:members:
:special-members: __call__


`RandRotate90d`
"""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -311,6 +311,7 @@
AffineGrid,
Flip,
GridDistortion,
GridSplit,
Orientation,
Rand2DElastic,
Rand3DElastic,
@@ -342,6 +343,9 @@
GridDistortiond,
GridDistortionD,
GridDistortionDict,
GridSplitd,
GridSplitD,
GridSplitDict,
Orientationd,
OrientationD,
OrientationDict,
90 changes: 90 additions & 0 deletions monai/transforms/spatial/array.py
@@ -18,6 +18,7 @@

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
@@ -65,6 +66,7 @@
"Orientation",
"Flip",
"GridDistortion",
"GridSplit",
"Resize",
"Rotate",
"Zoom",
@@ -2462,3 +2464,91 @@ def __call__(
if not self._do_transform:
return img
return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)


class GridSplit(Transform):
"""
Split the image into patches based on the provided 2D grid.

Args:
grid: a tuple defining the shape of the grid upon which the image is split. Defaults to (2, 2).
size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is None, where the patch size will be inferred from the grid shape.

Example:
Given an image (torch.Tensor or numpy.ndarray) with a size of (3, 10, 10) and a grid of (2, 2),
it will return a Tensor or array with a size of (4, 3, 5, 5).
If `size` is provided, the returned shape will be (4, 3, size, size) instead.

Note: This transform currently supports only images with two spatial dimensions.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None):
# Grid size
self.grid = grid

# Patch size
self.size = None if size is None else ensure_tuple_rep(size, len(self.grid))

def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
if self.grid == (1, 1) and self.size is None:
if isinstance(image, torch.Tensor):
return torch.stack([image])
elif isinstance(image, np.ndarray):
return np.stack([image]) # type: ignore
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

size, steps = self._get_params(image.shape[1:])
patches: NdarrayOrTensor
if isinstance(image, torch.Tensor):
patches = (
image.unfold(1, size[0], steps[0])
.unfold(2, size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
elif isinstance(image, np.ndarray):
x_step, y_step = steps
c_stride, x_stride, y_stride = image.strides
n_channels = image.shape[0]
patches = as_strided(
image,
shape=(*self.grid, n_channels, size[0], size[1]),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
writeable=False,
)
# flatten the first two dimensions
patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:])
# make it a contiguous array
patches = np.ascontiguousarray(patches)
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

return patches

def _get_params(self, image_size: Union[Sequence[int], np.ndarray]):
"""
Calculate the size and step required for splitting the image
Args:
The size of the input image
"""
if self.size is not None:
# Set the split size to the given default size
if any(self.size[i] > image_size[i] for i in range(len(self.grid))):
raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})")
split_size = self.size
else:
# infer each sub-image size from the image size and the grid
split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid)))

steps = tuple(
(image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i]
for i in range(len(self.grid))
)

return split_size, steps
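For orientation, a minimal usage sketch of the GridSplit transform added above (illustrative only, not part of the diff; it assumes this branch is installed). The shapes follow the docstring example, and the overlap behaviour falls out of the step computation in `_get_params`:

import torch

from monai.transforms import GridSplit

img = torch.randn(3, 10, 10)  # a 3-channel 10x10 image, as in the docstring example

# default: split_size = (10 // 2, 10 // 2) = (5, 5), steps = (5, 5), no overlap
patches = GridSplit(grid=(2, 2))(img)
print(patches.shape)  # torch.Size([4, 3, 5, 5])

# explicit size: steps = (10 - 6) // (2 - 1) = 4, so adjacent patches overlap by 2 pixels
patches = GridSplit(grid=(2, 2), size=6)(img)
print(patches.shape)  # torch.Size([4, 3, 6, 6])

# numpy inputs take the as_strided path and come back as contiguous numpy arrays
np_patches = GridSplit(grid=(2, 2))(img.numpy())
print(np_patches.shape)  # (4, 3, 5, 5)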
39 changes: 39 additions & 0 deletions monai/transforms/spatial/dictionary.py
@@ -34,6 +34,7 @@
AffineGrid,
Flip,
GridDistortion,
GridSplit,
Orientation,
Rand2DElastic,
Rand3DElastic,
@@ -129,6 +130,9 @@
"ZoomDict",
"RandZoomD",
"RandZoomDict",
"GridSplitd",
"GridSplitD",
"GridSplitDict",
]

GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
@@ -2149,6 +2153,40 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
return d


class GridSplitd(MapTransform):
"""
Split the image into patches based on the provided 2D grid.

Args:
keys: keys of the corresponding items to be transformed.
grid: a tuple defining the shape of the grid upon which the image is split. Defaults to (2, 2).
size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is None, where the patch size will be inferred from the grid shape.
allow_missing_keys: don't raise exception if key is missing.

Note: This transform currently supports only images with two spatial dimensions.
"""

backend = GridSplit.backend

def __init__(
self,
keys: KeysCollection,
grid: Tuple[int, int] = (2, 2),
size: Optional[Union[int, Tuple[int, int]]] = None,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.splitter = GridSplit(grid=grid, size=size)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.splitter(d[key])
return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
@@ -2169,3 +2207,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
RandRotateD = RandRotateDict = RandRotated
ZoomD = ZoomDict = Zoomd
RandZoomD = RandZoomDict = RandZoomd
GridSplitD = GridSplitDict = GridSplitd
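A corresponding sketch for the dictionary wrapper (again illustrative; it assumes a data dict with an "image" key):

import torch

from monai.transforms import GridSplitd

data = {"image": torch.randn(3, 10, 10), "label": 0}
splitter = GridSplitd(keys="image", grid=(2, 2))
out = splitter(data)
print(out["image"].shape)  # torch.Size([4, 3, 5, 5]); "label" passes through unchanged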
84 changes: 84 additions & 0 deletions tests/test_grid_split.py
@@ -0,0 +1,84 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import GridSplit
from tests.utils import TEST_NDARRAYS, assert_allclose

A11 = torch.randn(3, 2, 2)
A12 = torch.randn(3, 2, 2)
A21 = torch.randn(3, 2, 2)
A22 = torch.randn(3, 2, 2)

A1 = torch.cat([A11, A12], 2)
A2 = torch.cat([A21, A22], 2)
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])]
TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])]
TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])]
TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])]
TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])]
TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])]
TEST_CASE_8 = [
{"grid": (2, 2), "size": 2},
torch.arange(12).reshape(1, 3, 4).to(torch.float32),
torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),
]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
TEST_SINGLE.append([p, *TEST_CASE_1])
TEST_SINGLE.append([p, *TEST_CASE_2])
TEST_SINGLE.append([p, *TEST_CASE_3])
TEST_SINGLE.append([p, *TEST_CASE_4])
TEST_SINGLE.append([p, *TEST_CASE_5])
TEST_SINGLE.append([p, *TEST_CASE_6])
TEST_SINGLE.append([p, *TEST_CASE_7])
TEST_SINGLE.append([p, *TEST_CASE_8])

TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]
TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5]
TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])


class TestGridSplit(unittest.TestCase):
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = GridSplit(**input_parameters)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)

@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
splitter = GridSplit(**input_parameters)
for image, expected in zip(img_list, expected_list):
input_image = in_type(image)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)


if __name__ == "__main__":
unittest.main()
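TEST_CASE_8 above is the one non-obvious case: with grid=(2, 2) and size=2 on a (1, 3, 4) image, _get_params yields steps of ((3 - 2) // 1, (4 - 2) // 1) = (1, 2), so the two patch rows start at image rows 0 and 1 and overlap by one pixel, while the columns start at 0 and 2 with no overlap. A quick sanity check (illustrative only):

import torch

from monai.transforms import GridSplit

img = torch.arange(12).reshape(1, 3, 4).to(torch.float32)
patches = GridSplit(grid=(2, 2), size=2)(img)
print(patches[0])  # [[0., 1.], [4., 5.]] -- rows 0-1, cols 0-1
print(patches[2])  # [[4., 5.], [8., 9.]] -- rows 1-2, cols 0-1, overlapping patch 0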