diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 56927a8033..91e8aa94e5 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -15,7 +15,9 @@ import torch from numpy.lib.stride_tricks import as_strided +from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, Transform +from monai.utils.enums import TransformBackends __all__ = ["SplitOnGrid", "TileOnGrid"] @@ -35,6 +37,8 @@ class SplitOnGrid(Transform): Note: the shape of the input image is inferred based on the first image used. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None ): @@ -50,17 +54,41 @@ def __init__( else: self.patch_size = patch_size - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: if self.grid_size == (1, 1) and self.patch_size is None: - return torch.stack([image]) + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + patch_size, steps = self.get_params(image.shape[1:]) - patches = ( - image.unfold(1, patch_size[0], steps[0]) - .unfold(2, patch_size[1], steps[1]) - .flatten(1, 2) - .transpose(0, 1) - .contiguous() - ) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, patch_size[0], steps[0]) + .unfold(2, patch_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + h_step, w_step = steps + c_stride, h_stride, w_stride = image.strides + patches = as_strided( + image, + shape=(*self.grid_size, 3, patch_size[0], patch_size[1]), + strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + return patches def get_params(self, image_size): @@ -177,17 +205,17 @@ def __call__(self, image: np.ndarray) -> np.ndarray: ) # extact tiles - xstep, ystep = self.step, self.step - xsize, ysize = self.tile_size, self.tile_size - clen, xlen, ylen = image.shape - cstride, xstride, ystride = image.strides + x_step, y_step = self.step, self.step + x_size, y_size = self.tile_size, self.tile_size + c_len, x_len, y_len = image.shape + c_stride, x_stride, y_stride = image.strides llw = as_strided( image, - shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize), - strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride), + shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - image = llw.reshape(-1, clen, xsize, ysize) + image = llw.reshape(-1, c_len, x_size, y_size) # if keeping all patches if self.tile_count is None: diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index f998e53c93..32df2cae3c 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -13,9 +13,9 @@ from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union import numpy as np -import torch from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import MapTransform, Randomizable from .array import SplitOnGrid, TileOnGrid @@ -35,9 +35,11 @@ class SplitOnGridd(MapTransform): If it's an integer, the value will be repeated for each dimension. The default is (0, 0), where the patch size will be inferred from the grid shape. - Note: the shape of the input image is infered based on the first image used. + Note: the shape of the input image is inferred based on the first image used. """ + backend = SplitOnGrid.backend + def __init__( self, keys: KeysCollection, @@ -48,7 +50,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.splitter(d[key]) diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index 4893c4c78a..a3bf2674f4 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -27,45 +27,51 @@ A = torch.cat([A1, A2], 1) TEST_CASE_0 = [{"grid_size": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] - TEST_CASE_1 = [{"grid_size": (2, 1)}, A, torch.stack([A1, A2])] - TEST_CASE_2 = [{"grid_size": (1, 2)}, A1, torch.stack([A11, A12])] - TEST_CASE_3 = [{"grid_size": (1, 2)}, A2, torch.stack([A21, A22])] - TEST_CASE_4 = [{"grid_size": (1, 1), "patch_size": (2, 2)}, A, torch.stack([A11])] - TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])] - TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])] - TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])] -TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] - +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) +TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] - - TEST_CASE_MC_2 = [{"grid_size": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + class TestSplitOnGrid(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] - ) - def test_split_pathce_single_call(self, input_parameters, img, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, image, expected): + input_image = in_type(image) splitter = SplitOnGrid(**input_parameters) - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) - @parameterized.expand([TEST_CASE_MC_0, TEST_CASE_MC_1, TEST_CASE_MC_2]) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGrid(**input_parameters) - for img, expected in zip(img_list, expected_list): - output = splitter(img) - np.testing.assert_equal(output.numpy(), expected.numpy()) + for image, expected in zip(img_list, expected_list): + input_image = in_type(image) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) if __name__ == "__main__": diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index f22e58515f..5f3e442640 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -11,11 +11,11 @@ import unittest -import numpy as np import torch from parameterized import parameterized from monai.apps.pathology.transforms import SplitOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) A12 = torch.randn(3, 2, 2) @@ -27,53 +27,67 @@ A = torch.cat([A1, A2], 1) TEST_CASE_0 = [{"keys": "image", "grid_size": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] - TEST_CASE_1 = [{"keys": "image", "grid_size": (2, 1)}, {"image": A}, torch.stack([A1, A2])] - TEST_CASE_2 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] - TEST_CASE_3 = [{"keys": "image", "grid_size": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] - TEST_CASE_4 = [{"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, {"image": A}, torch.stack([A11])] - TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])] - TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] - TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])] +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_CASE_MC_0 = [ {"keys": "image", "grid_size": (2, 2)}, [{"image": A}, {"image": A}], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], ] - - -TEST_CASE_MC_1 = [{"keys": "image", "grid_size": (2, 1)}, [{"image": A}] * 5, [torch.stack([A1, A2])] * 5] - - +TEST_CASE_MC_1 = [ + {"keys": "image", "grid_size": (2, 1)}, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, +] TEST_CASE_MC_2 = [ {"keys": "image", "grid_size": (1, 2)}, [{"image": A1}, {"image": A2}], [torch.stack([A11, A12]), torch.stack([A21, A22])], ] +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + class TestSplitOnGridDict(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] - ) - def test_split_pathce_single_call(self, input_parameters, img_dict, expected): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) splitter = SplitOnGridDict(**input_parameters) - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) - @parameterized.expand([TEST_CASE_MC_0, TEST_CASE_MC_1, TEST_CASE_MC_2]) - def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): splitter = SplitOnGridDict(**input_parameters) for img_dict, expected in zip(img_list, expected_list): - output = splitter(img_dict)[input_parameters["keys"]] - np.testing.assert_equal(output.numpy(), expected.numpy()) + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) if __name__ == "__main__":