Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import Randomizable, Transform
from monai.utils.enums import TransformBackends

__all__ = ["SplitOnGrid", "TileOnGrid"]

Expand All @@ -35,6 +37,8 @@ class SplitOnGrid(Transform):
Note: the shape of the input image is inferred based on the first image used.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None
):
Expand All @@ -50,17 +54,41 @@ def __init__(
else:
self.patch_size = patch_size

def __call__(self, image: torch.Tensor) -> torch.Tensor:
def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
if self.grid_size == (1, 1) and self.patch_size is None:
return torch.stack([image])
if isinstance(image, torch.Tensor):
return torch.stack([image])
elif isinstance(image, np.ndarray):
return np.stack([image])
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

patch_size, steps = self.get_params(image.shape[1:])
patches = (
image.unfold(1, patch_size[0], steps[0])
.unfold(2, patch_size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
patches: NdarrayOrTensor
if isinstance(image, torch.Tensor):
patches = (
image.unfold(1, patch_size[0], steps[0])
.unfold(2, patch_size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
elif isinstance(image, np.ndarray):
h_step, w_step = steps
c_stride, h_stride, w_stride = image.strides
patches = as_strided(
image,
shape=(*self.grid_size, 3, patch_size[0], patch_size[1]),
strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride),
writeable=False,
)
# flatten the first two dimensions
patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:])
# make it a contiguous array
patches = np.ascontiguousarray(patches)
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

return patches

def get_params(self, image_size):
Expand Down Expand Up @@ -177,17 +205,17 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
)

# extact tiles
xstep, ystep = self.step, self.step
xsize, ysize = self.tile_size, self.tile_size
clen, xlen, ylen = image.shape
cstride, xstride, ystride = image.strides
x_step, y_step = self.step, self.step
x_size, y_size = self.tile_size, self.tile_size
c_len, x_len, y_len = image.shape
c_stride, x_stride, y_stride = image.strides
llw = as_strided(
image,
shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize),
strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride),
shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
writeable=False,
)
image = llw.reshape(-1, clen, xsize, ysize)
image = llw.reshape(-1, c_len, x_size, y_size)

# if keeping all patches
if self.tile_count is None:
Expand Down
8 changes: 5 additions & 3 deletions monai/apps/pathology/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import MapTransform, Randomizable

from .array import SplitOnGrid, TileOnGrid
Expand All @@ -35,9 +35,11 @@ class SplitOnGridd(MapTransform):
If it's an integer, the value will be repeated for each dimension.
The default is (0, 0), where the patch size will be inferred from the grid shape.

Note: the shape of the input image is infered based on the first image used.
Note: the shape of the input image is inferred based on the first image used.
"""

backend = SplitOnGrid.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -48,7 +50,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.splitter(d[key])
Expand Down
52 changes: 29 additions & 23 deletions tests/test_split_on_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.apps.pathology.transforms import SplitOnGrid
from tests.utils import TEST_NDARRAYS, assert_allclose

A11 = torch.randn(3, 2, 2)
A12 = torch.randn(3, 2, 2)
Expand All @@ -27,45 +27,51 @@
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"grid_size": (2, 2)}, A, torch.stack([A11, A12, A21, A22])]

TEST_CASE_1 = [{"grid_size": (2, 1)}, A, torch.stack([A1, A2])]

TEST_CASE_2 = [{"grid_size": (1, 2)}, A1, torch.stack([A11, A12])]

TEST_CASE_3 = [{"grid_size": (1, 2)}, A2, torch.stack([A21, A22])]

TEST_CASE_4 = [{"grid_size": (1, 1), "patch_size": (2, 2)}, A, torch.stack([A11])]

TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])]

TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])]

TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])]

TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
TEST_SINGLE.append([p, *TEST_CASE_1])
TEST_SINGLE.append([p, *TEST_CASE_2])
TEST_SINGLE.append([p, *TEST_CASE_3])
TEST_SINGLE.append([p, *TEST_CASE_4])
TEST_SINGLE.append([p, *TEST_CASE_5])
TEST_SINGLE.append([p, *TEST_CASE_6])
TEST_SINGLE.append([p, *TEST_CASE_7])

TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]
TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5]


TEST_CASE_MC_2 = [{"grid_size": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])


class TestSplitOnGrid(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]
)
def test_split_pathce_single_call(self, input_parameters, img, expected):
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = SplitOnGrid(**input_parameters)
output = splitter(img)
np.testing.assert_equal(output.numpy(), expected.numpy())
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)

@parameterized.expand([TEST_CASE_MC_0, TEST_CASE_MC_1, TEST_CASE_MC_2])
def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list):
@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
splitter = SplitOnGrid(**input_parameters)
for img, expected in zip(img_list, expected_list):
output = splitter(img)
np.testing.assert_equal(output.numpy(), expected.numpy())
for image, expected in zip(img_list, expected_list):
input_image = in_type(image)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)


if __name__ == "__main__":
Expand Down
60 changes: 37 additions & 23 deletions tests/test_split_on_grid_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.apps.pathology.transforms import SplitOnGridDict
from tests.utils import TEST_NDARRAYS, assert_allclose

A11 = torch.randn(3, 2, 2)
A12 = torch.randn(3, 2, 2)
Expand All @@ -27,53 +27,67 @@
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"keys": "image", "grid_size": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])]

TEST_CASE_1 = [{"keys": "image", "grid_size": (2, 1)}, {"image": A}, torch.stack([A1, A2])]

TEST_CASE_2 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A1}, torch.stack([A11, A12])]

TEST_CASE_3 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A2}, torch.stack([A21, A22])]

TEST_CASE_4 = [{"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, {"image": A}, torch.stack([A11])]

TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])]

TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])]

TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
TEST_SINGLE.append([p, *TEST_CASE_1])
TEST_SINGLE.append([p, *TEST_CASE_2])
TEST_SINGLE.append([p, *TEST_CASE_3])
TEST_SINGLE.append([p, *TEST_CASE_4])
TEST_SINGLE.append([p, *TEST_CASE_5])
TEST_SINGLE.append([p, *TEST_CASE_6])
TEST_SINGLE.append([p, *TEST_CASE_7])

TEST_CASE_MC_0 = [
{"keys": "image", "grid_size": (2, 2)},
[{"image": A}, {"image": A}],
[torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])],
]


TEST_CASE_MC_1 = [{"keys": "image", "grid_size": (2, 1)}, [{"image": A}] * 5, [torch.stack([A1, A2])] * 5]


TEST_CASE_MC_1 = [
{"keys": "image", "grid_size": (2, 1)},
[{"image": A}, {"image": A}, {"image": A}],
[torch.stack([A1, A2])] * 3,
]
TEST_CASE_MC_2 = [
{"keys": "image", "grid_size": (1, 2)},
[{"image": A1}, {"image": A2}],
[torch.stack([A11, A12]), torch.stack([A21, A22])],
]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])


class TestSplitOnGridDict(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]
)
def test_split_pathce_single_call(self, input_parameters, img_dict, expected):
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected):
input_dict = {}
for k, v in img_dict.items():
input_dict[k] = in_type(v)
splitter = SplitOnGridDict(**input_parameters)
output = splitter(img_dict)[input_parameters["keys"]]
np.testing.assert_equal(output.numpy(), expected.numpy())
output = splitter(input_dict)[input_parameters["keys"]]
assert_allclose(output, expected, type_test=False)

@parameterized.expand([TEST_CASE_MC_0, TEST_CASE_MC_1, TEST_CASE_MC_2])
def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list):
@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
splitter = SplitOnGridDict(**input_parameters)
for img_dict, expected in zip(img_list, expected_list):
output = splitter(img_dict)[input_parameters["keys"]]
np.testing.assert_equal(output.numpy(), expected.numpy())
input_dict = {}
for k, v in img_dict.items():
input_dict[k] = in_type(v)
output = splitter(input_dict)[input_parameters["keys"]]
assert_allclose(output, expected, type_test=False)


if __name__ == "__main__":
Expand Down