diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 91e8aa94e5..abfdf3cac0 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -17,6 +17,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, Transform +from monai.utils import convert_data_type, convert_to_dst_type from monai.utils.enums import TransformBackends __all__ = ["SplitOnGrid", "TileOnGrid"] @@ -129,6 +130,8 @@ class TileOnGrid(Randomizable, Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, tile_count: Optional[int] = None, @@ -185,37 +188,39 @@ def randomize(self, img_size: Sequence[int]) -> None: else: self.random_idxs = np.array((0,)) - def __call__(self, image: np.ndarray) -> np.ndarray: + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + img_np: np.ndarray + img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore # add random offset - self.randomize(img_size=image.shape) + self.randomize(img_size=img_np.shape) if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): - image = image[:, self.offset[0] :, self.offset[1] :] + img_np = img_np[:, self.offset[0] :, self.offset[1] :] # pad to full size, divisible by tile_size if self.pad_full: - c, h, w = image.shape + c, h, w = img_np.shape pad_h = (self.tile_size - h % self.tile_size) % self.tile_size pad_w = (self.tile_size - w % self.tile_size) % self.tile_size - image = np.pad( - image, + img_np = np.pad( + img_np, [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], constant_values=self.background_val, ) # extact tiles - x_step, y_step = self.step, self.step - x_size, y_size = self.tile_size, self.tile_size - c_len, x_len, y_len = image.shape - c_stride, x_stride, y_stride = image.strides + h_step, w_step = self.step, self.step + h_size, w_size = self.tile_size, self.tile_size + c_len, h_len, w_len = img_np.shape + c_stride, h_stride, w_stride = img_np.strides llw = as_strided( - image, - shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size), - strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + img_np, + shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size), + strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), writeable=False, ) - image = llw.reshape(-1, c_len, x_size, y_size) + img_np = llw.reshape(-1, c_len, h_size, w_size) # if keeping all patches if self.tile_count is None: @@ -224,32 +229,34 @@ def __call__(self, image: np.ndarray) -> np.ndarray: thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size if self.filter_mode == "min": # default, keep non-background tiles (small values) - idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) - image = image[idxs.reshape(-1)] + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh) + img_np = img_np[idxs.reshape(-1)] elif self.filter_mode == "max": - idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) - image = image[idxs.reshape(-1)] + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh) + img_np = img_np[idxs.reshape(-1)] else: - if len(image) > self.tile_count: + if len(img_np) > self.tile_count: if self.filter_mode == "min": # default, keep non-background tiles (smallest values) - idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] - image = image[idxs] + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] + img_np = img_np[idxs] elif self.filter_mode == "max": - idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] - image = image[idxs] + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :] + img_np = img_np[idxs] else: # random subset (more appropriate for WSIs without distinct background) if self.random_idxs is not None: - image = image[self.random_idxs] + img_np = img_np[self.random_idxs] - elif len(image) < self.tile_count: - image = np.pad( - image, - [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], + elif len(img_np) < self.tile_count: + img_np = np.pad( + img_np, + [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]], constant_values=self.background_val, ) + image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype) + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 32df2cae3c..aae98e7c8d 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -12,8 +12,6 @@ import copy from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union -import numpy as np - from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import MapTransform, Randomizable @@ -81,6 +79,8 @@ class TileOnGridd(Randomizable, MapTransform): """ + backend = SplitOnGrid.backend + def __init__( self, keys: KeysCollection, @@ -112,7 +112,9 @@ def __init__( def randomize(self, data: Any = None) -> None: self.seed = self.R.randint(10000) # type: ignore - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]: + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]: self.randomize() diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 8ad091d7d6..1a3fc8d44d 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.apps.pathology.transforms import TileOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] for tile_count in [16, 64]: @@ -38,6 +39,10 @@ for step in [4, 8]: TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) TEST_CASES2 = [] for tile_count in [16, 64]: @@ -56,6 +61,11 @@ ] ) +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + def make_image( tile_count: int, @@ -104,25 +114,27 @@ def make_image( class TestTileOnGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_tile_patch_single_call(self, input_parameters): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): img, tiles = make_image(**input_parameters) + input_img = in_type(img) tiler = TileOnGrid(**input_parameters) - output = tiler(img) - np.testing.assert_equal(output, tiles) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) - @parameterized.expand(TEST_CASES2) - def test_tile_patch_random_call(self, input_parameters): + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): img, tiles = make_image(**input_parameters, seed=123) + input_img = in_type(img) tiler = TileOnGrid(**input_parameters) tiler.set_random_state(seed=123) - output = tiler(img) - np.testing.assert_equal(output, tiles) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) if __name__ == "__main__": diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index 48d36f8f63..78b7907805 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -13,9 +13,11 @@ from typing import Optional import numpy as np +import torch from parameterized import parameterized from monai.apps.pathology.transforms import TileOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] for tile_count in [16, 64]: @@ -36,11 +38,14 @@ ] ) - for tile_size in [8, 16]: for step in [4, 8]: TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) TEST_CASES2 = [] for tile_count in [16, 64]: @@ -61,6 +66,10 @@ ] ) +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) for tile_size in [8, 16]: for step in [4, 8]: @@ -114,27 +123,31 @@ def make_image( class TestTileOnGridDict(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_tile_patch_single_call(self, input_parameters): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): key = "image" input_parameters["keys"] = key img, tiles = make_image(**input_parameters) + input_img = in_type(img) splitter = TileOnGridDict(**input_parameters) - output = splitter({key: img}) + output = splitter({key: input_img}) if input_parameters.get("return_list_of_dicts", False): - output = np.stack([ix[key] for ix in output], axis=0) + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) else: output = output[key] - np.testing.assert_equal(tiles, output) + assert_allclose(output, tiles, type_test=False) - @parameterized.expand(TEST_CASES2) - def test_tile_patch_random_call(self, input_parameters): + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): key = "image" input_parameters["keys"] = key @@ -142,18 +155,21 @@ def test_tile_patch_random_call(self, input_parameters): random_state = np.random.RandomState(123) seed = random_state.randint(10000) img, tiles = make_image(**input_parameters, seed=seed) + input_img = in_type(img) splitter = TileOnGridDict(**input_parameters) splitter.set_random_state(seed=123) - output = splitter({key: img}) + output = splitter({key: input_img}) if input_parameters.get("return_list_of_dicts", False): - output = np.stack([ix[key] for ix in output], axis=0) + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) else: output = output[key] - - np.testing.assert_equal(tiles, output) + assert_allclose(output, tiles, type_test=False) if __name__ == "__main__":