From 2202e95d8ccf47ef88a13a1342c5b015347458a2 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 22 Nov 2021 23:55:01 +0000 Subject: [PATCH 1/5] Update TileOnGrid backend to support torch.Tensor as input/output Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../pathology/transforms/spatial/array.py | 55 +++++++++++-------- .../transforms/spatial/dictionary.py | 6 +- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 91e8aa94e5..4766eae791 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -17,6 +17,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, Transform +from monai.utils import convert_data_type, convert_to_dst_type from monai.utils.enums import TransformBackends __all__ = ["SplitOnGrid", "TileOnGrid"] @@ -129,6 +130,8 @@ class TileOnGrid(Randomizable, Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, tile_count: Optional[int] = None, @@ -185,21 +188,23 @@ def randomize(self, img_size: Sequence[int]) -> None: else: self.random_idxs = np.array((0,)) - def __call__(self, image: np.ndarray) -> np.ndarray: + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + img_np: np.ndarray + img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore # add random offset - self.randomize(img_size=image.shape) + self.randomize(img_size=img_np.shape) if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): - image = image[:, self.offset[0] :, self.offset[1] :] + img_np = img_np[:, self.offset[0] :, self.offset[1] :] # pad to full size, divisible by tile_size if self.pad_full: - c, h, w = image.shape + c, h, w = img_np.shape pad_h = (self.tile_size - h % self.tile_size) % self.tile_size pad_w = (self.tile_size - w % self.tile_size) % self.tile_size - image = np.pad( - image, + img_np = np.pad( + img_np, [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], constant_values=self.background_val, ) @@ -207,15 +212,15 @@ def __call__(self, image: np.ndarray) -> np.ndarray: # extact tiles x_step, y_step = self.step, self.step x_size, y_size = self.tile_size, self.tile_size - c_len, x_len, y_len = image.shape - c_stride, x_stride, y_stride = image.strides + c_len, x_len, y_len = img_np.shape + c_stride, x_stride, y_stride = img_np.strides llw = as_strided( - image, + img_np, shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - image = llw.reshape(-1, c_len, x_size, y_size) + img_np = llw.reshape(-1, c_len, x_size, y_size) # if keeping all patches if self.tile_count is None: @@ -224,32 +229,34 @@ def __call__(self, image: np.ndarray) -> np.ndarray: thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size if self.filter_mode == "min": # default, keep non-background tiles (small values) - idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) - image = image[idxs.reshape(-1)] + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh) + img_np = img_np[idxs.reshape(-1)] elif self.filter_mode == "max": - idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) - image = image[idxs.reshape(-1)] + idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh) + img_np = img_np[idxs.reshape(-1)] else: - if len(image) > self.tile_count: + if len(img_np) > self.tile_count: if self.filter_mode == "min": # default, keep non-background tiles (smallest values) - idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] - image = image[idxs] + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] + img_np = img_np[idxs] elif self.filter_mode == "max": - idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] - image = image[idxs] + idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :] + img_np = img_np[idxs] else: # random subset (more appropriate for WSIs without distinct background) if self.random_idxs is not None: - image = image[self.random_idxs] + img_np = img_np[self.random_idxs] - elif len(image) < self.tile_count: - image = np.pad( - image, - [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], + elif len(img_np) < self.tile_count: + img_np = np.pad( + img_np, + [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]], constant_values=self.background_val, ) + image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype) + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 32df2cae3c..5817b1f1c2 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -12,8 +12,6 @@ import copy from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union -import numpy as np - from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import MapTransform, Randomizable @@ -112,7 +110,9 @@ def __init__( def randomize(self, data: Any = None) -> None: self.seed = self.R.randint(10000) # type: ignore - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]: + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]: self.randomize() From bed4ef56f56966c816d1b2bd65dd387109f0c9ba Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 22 Nov 2021 23:57:12 +0000 Subject: [PATCH 2/5] Update array unittests to include Tensors on CPU/CUDA Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_tile_on_grid.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 8ad091d7d6..1a3fc8d44d 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.apps.pathology.transforms import TileOnGrid +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] for tile_count in [16, 64]: @@ -38,6 +39,10 @@ for step in [4, 8]: TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) TEST_CASES2 = [] for tile_count in [16, 64]: @@ -56,6 +61,11 @@ ] ) +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) + def make_image( tile_count: int, @@ -104,25 +114,27 @@ def make_image( class TestTileOnGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_tile_patch_single_call(self, input_parameters): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): img, tiles = make_image(**input_parameters) + input_img = in_type(img) tiler = TileOnGrid(**input_parameters) - output = tiler(img) - np.testing.assert_equal(output, tiles) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) - @parameterized.expand(TEST_CASES2) - def test_tile_patch_random_call(self, input_parameters): + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): img, tiles = make_image(**input_parameters, seed=123) + input_img = in_type(img) tiler = TileOnGrid(**input_parameters) tiler.set_random_state(seed=123) - output = tiler(img) - np.testing.assert_equal(output, tiles) + output = tiler(input_img) + assert_allclose(output, tiles, type_test=False) if __name__ == "__main__": From e752f3e37006f9be7b74dd3ee59a935731f007fd Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 22 Nov 2021 23:57:30 +0000 Subject: [PATCH 3/5] Update dictionary unittests to include Tensors on CPU/CUDA Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_tile_on_grid_dict.py | 40 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index 48d36f8f63..78b7907805 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -13,9 +13,11 @@ from typing import Optional import numpy as np +import torch from parameterized import parameterized from monai.apps.pathology.transforms import TileOnGridDict +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] for tile_count in [16, 64]: @@ -36,11 +38,14 @@ ] ) - for tile_size in [8, 16]: for step in [4, 8]: TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}]) +TESTS = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES: + TESTS.append([p, *tc]) TEST_CASES2 = [] for tile_count in [16, 64]: @@ -61,6 +66,10 @@ ] ) +TESTS2 = [] +for p in TEST_NDARRAYS: + for tc in TEST_CASES2: + TESTS2.append([p, *tc]) for tile_size in [8, 16]: for step in [4, 8]: @@ -114,27 +123,31 @@ def make_image( class TestTileOnGridDict(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_tile_patch_single_call(self, input_parameters): + @parameterized.expand(TESTS) + def test_tile_patch_single_call(self, in_type, input_parameters): key = "image" input_parameters["keys"] = key img, tiles = make_image(**input_parameters) + input_img = in_type(img) splitter = TileOnGridDict(**input_parameters) - output = splitter({key: img}) + output = splitter({key: input_img}) if input_parameters.get("return_list_of_dicts", False): - output = np.stack([ix[key] for ix in output], axis=0) + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) else: output = output[key] - np.testing.assert_equal(tiles, output) + assert_allclose(output, tiles, type_test=False) - @parameterized.expand(TEST_CASES2) - def test_tile_patch_random_call(self, input_parameters): + @parameterized.expand(TESTS2) + def test_tile_patch_random_call(self, in_type, input_parameters): key = "image" input_parameters["keys"] = key @@ -142,18 +155,21 @@ def test_tile_patch_random_call(self, input_parameters): random_state = np.random.RandomState(123) seed = random_state.randint(10000) img, tiles = make_image(**input_parameters, seed=seed) + input_img = in_type(img) splitter = TileOnGridDict(**input_parameters) splitter.set_random_state(seed=123) - output = splitter({key: img}) + output = splitter({key: input_img}) if input_parameters.get("return_list_of_dicts", False): - output = np.stack([ix[key] for ix in output], axis=0) + if isinstance(input_img, torch.Tensor): + output = torch.stack([ix[key] for ix in output], axis=0) + else: + output = np.stack([ix[key] for ix in output], axis=0) else: output = output[key] - - np.testing.assert_equal(tiles, output) + assert_allclose(output, tiles, type_test=False) if __name__ == "__main__": From ba037586d61ca64c86e0af4c474698fb51514f1a Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 00:09:54 +0000 Subject: [PATCH 4/5] Change x_->h_, y_->w_ for consistency Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/spatial/array.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4766eae791..abfdf3cac0 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -210,17 +210,17 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: ) # extact tiles - x_step, y_step = self.step, self.step - x_size, y_size = self.tile_size, self.tile_size - c_len, x_len, y_len = img_np.shape - c_stride, x_stride, y_stride = img_np.strides + h_step, w_step = self.step, self.step + h_size, w_size = self.tile_size, self.tile_size + c_len, h_len, w_len = img_np.shape + c_stride, h_stride, w_stride = img_np.strides llw = as_strided( img_np, - shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size), - strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size), + strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), writeable=False, ) - img_np = llw.reshape(-1, c_len, x_size, y_size) + img_np = llw.reshape(-1, c_len, h_size, w_size) # if keeping all patches if self.tile_count is None: From 722bc65b0e8cbac565643619f7fd1b62363b3866 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 02:39:46 +0000 Subject: [PATCH 5/5] Add backend to dict version Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/spatial/dictionary.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 5817b1f1c2..aae98e7c8d 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -79,6 +79,8 @@ class TileOnGridd(Randomizable, MapTransform): """ + backend = SplitOnGrid.backend + def __init__( self, keys: KeysCollection,