From 252a1c5f888e248c5a546496783fa68b7b2a1cb9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 10 May 2022 01:18:38 +0000 Subject: [PATCH 1/9] Add split patch support to patch wsi dataset Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 36 +++++++++++++++++++++----- tests/test_patch_wsi_dataset_new.py | 40 ++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 750b3fda20..24b14c3091 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -10,13 +10,13 @@ # limitations under the License. import inspect -from typing import Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from monai.data import Dataset from monai.data.wsi_reader import BaseWSIReader, WSIReader -from monai.transforms import apply_transform +from monai.transforms import GridSplit, apply_transform from monai.utils import ensure_tuple_rep __all__ = ["PatchWSIDataset"] @@ -38,6 +38,8 @@ class PatchWSIDataset(Dataset): - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + split_grid: a tuple define the shape of the grid upon which the image is split. + split_size: a tuple or an integer that defines the output sub patch sizes. 
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class Note: @@ -59,6 +61,8 @@ def __init__( level: Optional[int] = None, transform: Optional[Callable] = None, reader="cuCIM", + split_grid: Optional[Union[int, Tuple[int, int]]] = None, + split_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): super().__init__(data, transform) @@ -91,6 +95,15 @@ def __init__( # Initialized an empty whole slide image object dict self.wsi_object_dict: Dict = {} + # Create the splitter to split patches into subpatches on a grid + self.split_size = split_size + if split_grid is None: + self.split_grid = None + self.splitter = None + else: + self.split_grid = ensure_tuple_rep(split_grid, 2) + self.splitter = GridSplit(grid=self.split_grid, size=self.split_size) # type: ignore + def _get_wsi_object(self, sample: Dict): image_path = sample["image"] if image_path not in self.wsi_object_dict: @@ -122,7 +135,7 @@ def _get_data(self, sample: Dict): location = self._get_location(sample) level = self._get_level(sample) size = self._get_size(sample) - return self.wsi_reader.get_data(wsi=wsi_obj, location=location, size=size, level=level) + return self.wsi_reader.get_data(wsi=wsi_obj, location=location[::-1], size=size, level=level) def _transform(self, index: int): # Get a single entry of data @@ -131,7 +144,18 @@ def _transform(self, index: int): image, metadata = self._get_data(sample) # Get the label label = self._get_label(sample) + output: Union[Dict, List] + if self.splitter: + # Split the extracted patch (image) into sub-patches + output = [] + sub_patches = self.splitter(image) + metadata["sub_patch"] = {"grid": self.split_grid, "size": self.split_size} + for i in range(len(sub_patches)): + sub_meta = metadata.copy() + sub_meta["sub_patch"]["id"] = i + output.append({"image": sub_patches[i], "label": label[i : i + 1], "metadata": metadata}) + else: + output = {"image": image, "label": label, "metadata": metadata} - # Create put all patch 
information together and apply transforms - patch = {"image": image, "label": label, "metadata": metadata} - return apply_transform(self.transform, patch) if self.transform else patch + # Apply transforms and output + return apply_transform(self.transform, output) if self.transform else output diff --git a/tests/test_patch_wsi_dataset_new.py b/tests/test_patch_wsi_dataset_new.py index 0be30536de..b88aba8f0f 100644 --- a/tests/test_patch_wsi_dataset_new.py +++ b/tests/test_patch_wsi_dataset_new.py @@ -90,6 +90,35 @@ ], ] +TEST_CASE_SPLIT_0 = [ + { + "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 2, 0, 1]}], + "size": (2, 2), + "split_grid": (2, 2), + }, + [ + {"image": np.array([[[247]], [[245]], [[246]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([2])}, + {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([1])}, + ], +] + +TEST_CASE_SPLIT_1 = [ + { + "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}], + "size": (8, 8), + "split_grid": (2, 2), + "split_size": 1, + }, + [ + {"image": np.array([[[246]], [[245]], [[250]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[247]], [[245]], [[246]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([1])}, + ], +] + @skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") def setUpModule(): # noqa: N802 @@ -150,6 +179,15 @@ def test_read_patches_str_multi(self, input_parameters, expected): self.assertIsNone(assert_array_equal(dataset[i]["label"], expected[i]["label"])) self.assertIsNone(assert_array_equal(dataset[i]["image"], expected[i]["image"])) + 
@parameterized.expand([TEST_CASE_SPLIT_0, TEST_CASE_SPLIT_1]) + def test_read_patches_str(self, input_parameters, expected): + dataset = PatchWSIDataset(reader=self.backend, **input_parameters) + for i, sample in enumerate(dataset[0]): + self.assertTupleEqual(sample["label"].shape, expected[i]["label"].shape) + self.assertTupleEqual(sample["image"].shape, expected[i]["image"].shape) + self.assertIsNone(assert_array_equal(sample["label"], expected[i]["label"])) + self.assertIsNone(assert_array_equal(sample["image"], expected[i]["image"])) + @skipUnless(has_cucim, "Requires cucim") class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests): @@ -158,7 +196,7 @@ def setUpClass(cls): cls.backend = "cucim" -@skipUnless(has_osl, "Requires cucim") +@skipUnless(has_osl, "Requires openslide") class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests): @classmethod def setUpClass(cls): From 347e6b6179aa0d69ebf7181ab51c1d7cb29b3018 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 10 May 2022 13:21:11 +0000 Subject: [PATCH 2/9] Rename the split test Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_patch_wsi_dataset_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_patch_wsi_dataset_new.py b/tests/test_patch_wsi_dataset_new.py index b88aba8f0f..e1fe95eda0 100644 --- a/tests/test_patch_wsi_dataset_new.py +++ b/tests/test_patch_wsi_dataset_new.py @@ -180,7 +180,7 @@ def test_read_patches_str_multi(self, input_parameters, expected): self.assertIsNone(assert_array_equal(dataset[i]["image"], expected[i]["image"])) @parameterized.expand([TEST_CASE_SPLIT_0, TEST_CASE_SPLIT_1]) - def test_read_patches_str(self, input_parameters, expected): + def test_read_split_patches(self, input_parameters, expected): dataset = PatchWSIDataset(reader=self.backend, **input_parameters) for i, sample in enumerate(dataset[0]): self.assertTupleEqual(sample["label"].shape, 
expected[i]["label"].shape) From d52d173f9452315d90a3c7a573770dc99c25874c Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 11 May 2022 16:55:27 +0000 Subject: [PATCH 3/9] Update GridSplit and revert PatchWSIDataset Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 28 ++----------------- monai/transforms/spatial/array.py | 28 +++++++++---------- monai/transforms/spatial/dictionary.py | 11 ++++++-- tests/test_grid_split.py | 28 ++++++++++--------- tests/test_grid_splitd.py | 32 ++++++++++++---------- tests/test_patch_wsi_dataset_new.py | 38 -------------------------- 6 files changed, 56 insertions(+), 109 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 24b14c3091..ff8b0ec0f4 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -38,8 +38,6 @@ class PatchWSIDataset(Dataset): - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. - split_grid: a tuple define the shape of the grid upon which the image is split. - split_size: a tuple or an integer that defines the output sub patch sizes. 
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class Note: @@ -61,8 +59,6 @@ def __init__( level: Optional[int] = None, transform: Optional[Callable] = None, reader="cuCIM", - split_grid: Optional[Union[int, Tuple[int, int]]] = None, - split_size: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): super().__init__(data, transform) @@ -95,15 +91,6 @@ def __init__( # Initialized an empty whole slide image object dict self.wsi_object_dict: Dict = {} - # Create the splitter to split patches into subpatches on a grid - self.split_size = split_size - if split_grid is None: - self.split_grid = None - self.splitter = None - else: - self.split_grid = ensure_tuple_rep(split_grid, 2) - self.splitter = GridSplit(grid=self.split_grid, size=self.split_size) # type: ignore - def _get_wsi_object(self, sample: Dict): image_path = sample["image"] if image_path not in self.wsi_object_dict: @@ -140,22 +127,13 @@ def _get_data(self, sample: Dict): def _transform(self, index: int): # Get a single entry of data sample: Dict = self.data[index] + # Extract patch image and associated metadata image, metadata = self._get_data(sample) + # Get the label label = self._get_label(sample) - output: Union[Dict, List] - if self.splitter: - # Split the extracted patch (image) into sub-patches - output = [] - sub_patches = self.splitter(image) - metadata["sub_patch"] = {"grid": self.split_grid, "size": self.split_size} - for i in range(len(sub_patches)): - sub_meta = metadata.copy() - sub_meta["sub_patch"]["id"] = i - output.append({"image": sub_patches[i], "label": label[i : i + 1], "metadata": metadata}) - else: - output = {"image": image, "label": label, "metadata": metadata} # Apply transforms and output + output = {"image": image, "label": label, "metadata": metadata} return apply_transform(self.transform, output) if self.transform else output diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6b67762b95..bf88cab506 
100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2493,39 +2493,37 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tup # Patch size self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) - def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, image: NdarrayOrTensor) -> List[NdarrayOrTensor]: if self.grid == (1, 1) and self.size is None: if isinstance(image, torch.Tensor): - return torch.stack([image]) + return [image] elif isinstance(image, np.ndarray): - return np.stack([image]) # type: ignore + return [image] else: raise ValueError(f"Input type [{type(image)}] is not supported.") size, steps = self._get_params(image.shape[1:]) - patches: NdarrayOrTensor + patches: List[NdarrayOrTensor] if isinstance(image, torch.Tensor): - patches = ( - image.unfold(1, size[0], steps[0]) - .unfold(2, size[1], steps[1]) - .flatten(1, 2) - .transpose(0, 1) - .contiguous() + unfolded_image = ( + image.unfold(1, size[0], steps[0]).unfold(2, size[1], steps[1]).flatten(1, 2).transpose(0, 1) ) + # Make a list of contiguous patches + patches = [p.contiguous() for p in unfolded_image] elif isinstance(image, np.ndarray): x_step, y_step = steps c_stride, x_stride, y_stride = image.strides n_channels = image.shape[0] - patches = as_strided( + strided_image = as_strided( image, shape=(*self.grid, n_channels, size[0], size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - # flatten the first two dimensions - patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) - # make it a contiguous array - patches = np.ascontiguousarray(patches) + # Flatten the first two dimensions + strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) + # Make a list of contiguous patches + patches = [np.ascontiguousarray(p) for p in strided_image] else: raise ValueError(f"Input type [{type(image)}] is not 
supported.") diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 47fe05700e..654267c5b4 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2178,13 +2178,18 @@ def __init__( allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) + self.grid = grid self.splitter = GridSplit(grid=grid, size=size) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) + n_outputs = np.prod(self.grid) + output: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] for key in self.key_iterator(d): - d[key] = self.splitter(d[key]) - return d + result = self.splitter(d[key]) + for i in range(n_outputs): + output[i][key] = result[i] + return output SpatialResampleD = SpatialResampleDict = SpatialResampled diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py index 6f0525029d..82734ffd93 100644 --- a/tests/test_grid_split.py +++ b/tests/test_grid_split.py @@ -26,14 +26,14 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] -TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])] -TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])] -TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])] -TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])] -TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])] -TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])] -TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])] +TEST_CASE_0 = [{"grid": (2, 2)}, A, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"grid": (2, 1)}, A, [A1, A2]] +TEST_CASE_2 = [{"grid": (1, 2)}, A1, [A11, A12]] +TEST_CASE_3 = [{"grid": (1, 2)}, A2, [A21, A22]] +TEST_CASE_4 = 
[{"grid": (1, 1), "size": (2, 2)}, A, [A11]] +TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, [A]] +TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, [A11, A12, A21, A22]] +TEST_CASE_7 = [{"grid": (1, 1)}, A, [A]] TEST_CASE_8 = [ {"grid": (2, 2), "size": 2}, torch.arange(12).reshape(1, 3, 4).to(torch.float32), @@ -52,9 +52,9 @@ TEST_SINGLE.append([p, *TEST_CASE_7]) TEST_SINGLE.append([p, *TEST_CASE_8]) -TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] -TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] -TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] +TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [[A11, A12, A21, A22], [A11, A12, A21, A22]]] +TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [[A1, A2]] * 5] +TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [[A11, A12], [A21, A22]]] TEST_MULTIPLE = [] for p in TEST_NDARRAYS: @@ -69,7 +69,8 @@ def test_split_patch_single_call(self, in_type, input_parameters, image, expecte input_image = in_type(image) splitter = GridSplit(**input_parameters) output = splitter(input_image) - assert_allclose(output, expected, type_test=False) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch, expected_patch, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): @@ -77,7 +78,8 @@ def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, ex for image, expected in zip(img_list, expected_list): input_image = in_type(image) output = splitter(input_image) - assert_allclose(output, expected, type_test=False) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index f325a16946..c9274d4c83 100644 --- 
a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -26,14 +26,14 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] -TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])] -TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] -TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] -TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])] -TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])] -TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] -TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])] +TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, [A1, A2]] +TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, [A11, A12]] +TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, [A21, A22]] +TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, [A11]] +TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, [A]] +TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, [A]] TEST_CASE_8 = [ {"keys": "image", "grid": (2, 2), "size": 2}, {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, @@ -55,17 +55,17 @@ TEST_CASE_MC_0 = [ {"keys": "image", "grid": (2, 2)}, [{"image": A}, {"image": A}], - [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], + [[A11, A12, A21, A22], [A11, A12, A21, A22]], ] TEST_CASE_MC_1 = [ {"keys": "image", "grid": (2, 1)}, [{"image": A}, {"image": A}, {"image": 
A}], - [torch.stack([A1, A2])] * 3, + [[A1, A2]] * 3, ] TEST_CASE_MC_2 = [ {"keys": "image", "grid": (1, 2)}, [{"image": A1}, {"image": A2}], - [torch.stack([A11, A12]), torch.stack([A21, A22])], + [[A11, A12], [A21, A22]], ] TEST_MULTIPLE = [] @@ -82,8 +82,9 @@ def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expe for k, v in img_dict.items(): input_dict[k] = in_type(v) splitter = GridSplitd(**input_parameters) - output = splitter(input_dict)[input_parameters["keys"]] - assert_allclose(output, expected, type_test=False) + output = splitter(input_dict) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): @@ -92,8 +93,9 @@ def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, ex input_dict = {} for k, v in img_dict.items(): input_dict[k] = in_type(v) - output = splitter(input_dict)[input_parameters["keys"]] - assert_allclose(output, expected, type_test=False) + output = splitter(input_dict) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_patch_wsi_dataset_new.py b/tests/test_patch_wsi_dataset_new.py index e1fe95eda0..d128d45262 100644 --- a/tests/test_patch_wsi_dataset_new.py +++ b/tests/test_patch_wsi_dataset_new.py @@ -90,35 +90,6 @@ ], ] -TEST_CASE_SPLIT_0 = [ - { - "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 2, 0, 1]}], - "size": (2, 2), - "split_grid": (2, 2), - }, - [ - {"image": np.array([[[247]], [[245]], [[246]]], dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([2])}, - {"image": np.array([[[246]], [[246]], [[244]]], 
dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([1])}, - ], -] - -TEST_CASE_SPLIT_1 = [ - { - "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}], - "size": (8, 8), - "split_grid": (2, 2), - "split_size": 1, - }, - [ - {"image": np.array([[[246]], [[245]], [[250]]], dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[247]], [[245]], [[246]]], dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[246]], [[246]], [[244]]], dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([1])}, - ], -] - @skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") def setUpModule(): # noqa: N802 @@ -179,15 +150,6 @@ def test_read_patches_str_multi(self, input_parameters, expected): self.assertIsNone(assert_array_equal(dataset[i]["label"], expected[i]["label"])) self.assertIsNone(assert_array_equal(dataset[i]["image"], expected[i]["image"])) - @parameterized.expand([TEST_CASE_SPLIT_0, TEST_CASE_SPLIT_1]) - def test_read_split_patches(self, input_parameters, expected): - dataset = PatchWSIDataset(reader=self.backend, **input_parameters) - for i, sample in enumerate(dataset[0]): - self.assertTupleEqual(sample["label"].shape, expected[i]["label"].shape) - self.assertTupleEqual(sample["image"].shape, expected[i]["image"].shape) - self.assertIsNone(assert_array_equal(sample["label"], expected[i]["label"])) - self.assertIsNone(assert_array_equal(sample["image"], expected[i]["image"])) - @skipUnless(has_cucim, "Requires cucim") class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests): From a6e2df3af8ba783640129077989311b08cdab713 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 11 May 2022 18:06:19 +0000 Subject: [PATCH 4/9] Add support for various sizes to split different keys Signed-off-by: Behrooz 
<3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 37 +++++++++++++++----------- monai/transforms/spatial/dictionary.py | 7 ++--- tests/test_grid_splitd.py | 8 +++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index bf88cab506..9c3a87a79d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2493,8 +2493,12 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tup # Patch size self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) - def __call__(self, image: NdarrayOrTensor) -> List[NdarrayOrTensor]: - if self.grid == (1, 1) and self.size is None: + def __call__( + self, image: NdarrayOrTensor, size: Optional[Union[int, Tuple[int, int], np.ndarray]] = None + ) -> List[NdarrayOrTensor]: + input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid)) + + if self.grid == (1, 1) and input_size is None: if isinstance(image, torch.Tensor): return [image] elif isinstance(image, np.ndarray): @@ -2502,11 +2506,14 @@ def __call__(self, image: NdarrayOrTensor) -> List[NdarrayOrTensor]: else: raise ValueError(f"Input type [{type(image)}] is not supported.") - size, steps = self._get_params(image.shape[1:]) + split_size, steps = self._get_params(image.shape[1:], input_size) patches: List[NdarrayOrTensor] if isinstance(image, torch.Tensor): unfolded_image = ( - image.unfold(1, size[0], steps[0]).unfold(2, size[1], steps[1]).flatten(1, 2).transpose(0, 1) + image.unfold(1, split_size[0], steps[0]) + .unfold(2, split_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) ) # Make a list of contiguous patches patches = [p.contiguous() for p in unfolded_image] @@ -2516,7 +2523,7 @@ def __call__(self, image: NdarrayOrTensor) -> List[NdarrayOrTensor]: n_channels = image.shape[0] strided_image = as_strided( image, - shape=(*self.grid, n_channels, size[0], 
size[1]), + shape=(*self.grid, n_channels, split_size[0], split_size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) @@ -2529,24 +2536,24 @@ def __call__(self, image: NdarrayOrTensor) -> List[NdarrayOrTensor]: return patches - def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): + def _get_params( + self, image_size: Union[Sequence[int], np.ndarray], size: Optional[Union[Sequence[int], np.ndarray]] = None + ): """ Calculate the size and step required for splitting the image Args: The size of the input image """ - if self.size is not None: - # Set the split size to the given default size - if any(self.size[i] > image_size[i] for i in range(len(self.grid))): - raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") - split_size = self.size - else: + if size is None: # infer each sub-image size from the image size and the grid - split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + + if any(size[i] > image_size[i] for i in range(len(self.grid))): + raise ValueError(f"The image size ({image_size})is smaller than the requested split size ({size})") steps = tuple( - (image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] for i in range(len(self.grid)) ) - return split_size, steps + return size, steps diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 654267c5b4..4caca10613 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2174,19 +2174,20 @@ def __init__( self, keys: KeysCollection, grid: Tuple[int, int] = (2, 2), - size: Optional[Union[int, Tuple[int, int]]] = None, + size: Optional[Union[int, Tuple[int, int], Dict[Hashable, 
Union[int, Tuple[int, int], None]]]] = None, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) self.grid = grid - self.splitter = GridSplit(grid=grid, size=size) + self.size = size if isinstance(size, dict) else {key: size for key in self.keys} + self.splitter = GridSplit(grid=grid) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) n_outputs = np.prod(self.grid) output: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] for key in self.key_iterator(d): - result = self.splitter(d[key]) + result = self.splitter(d[key], self.size[key]) for i in range(n_outputs): output[i][key] = result[i] return output diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index c9274d4c83..b03be5ff5a 100644 --- a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -30,12 +30,12 @@ TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, [A1, A2]] TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, [A11, A12]] TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, [A21, A22]] -TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, [A11]] -TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, [A]] -TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": {"image": (2, 2)}}, {"image": A}, [A11]] +TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": {"image": 4}}, {"image": A}, [A]] +TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": {"image": 2}}, {"image": A}, [A11, A12, A21, A22]] TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, [A]] TEST_CASE_8 = [ - {"keys": "image", "grid": (2, 2), "size": 2}, + {"keys": "image", "grid": (2, 2), "size": {"image": 2}}, {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, torch.Tensor([[[[0, 1], [4, 
5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] From da4d5fe48041b757eab44a79f1e9c5a72a8ec5ca Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 11 May 2022 19:44:38 +0000 Subject: [PATCH 5/9] Few fixes Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 2 +- monai/data/wsi_reader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index ff8b0ec0f4..13693e0c17 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -122,7 +122,7 @@ def _get_data(self, sample: Dict): location = self._get_location(sample) level = self._get_level(sample) size = self._get_size(sample) - return self.wsi_reader.get_data(wsi=wsi_obj, location=location[::-1], size=size, level=level) + return self.wsi_reader.get_data(wsi=wsi_obj, location=location, size=size, level=level) def _transform(self, index: int): # Get a single entry of data diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8dee1f453e..00277ee0af 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -174,7 +174,7 @@ def get_data( # Verify location if location is None: location = (0, 0) - wsi_size = self.get_size(each_wsi, level) + wsi_size = self.get_size(each_wsi, 0) if location[0] > wsi_size[0] or location[1] > wsi_size[1]: raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}") From 2d21126820419343e746b5afb0d927c5f9a3f9dd Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 11 May 2022 20:16:27 +0000 Subject: [PATCH 6/9] Fix imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 13693e0c17..665cbd196c 100644 --- 
a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -10,13 +10,13 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np from monai.data import Dataset from monai.data.wsi_reader import BaseWSIReader, WSIReader -from monai.transforms import GridSplit, apply_transform +from monai.transforms import apply_transform from monai.utils import ensure_tuple_rep __all__ = ["PatchWSIDataset"] From 2a5f88c3e420fa0923616861bd4adbf228659550 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 11 May 2022 20:29:41 +0000 Subject: [PATCH 7/9] formatting Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_splitd.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index b03be5ff5a..086dd2691d 100644 --- a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -57,16 +57,8 @@ [{"image": A}, {"image": A}], [[A11, A12, A21, A22], [A11, A12, A21, A22]], ] -TEST_CASE_MC_1 = [ - {"keys": "image", "grid": (2, 1)}, - [{"image": A}, {"image": A}, {"image": A}], - [[A1, A2]] * 3, -] -TEST_CASE_MC_2 = [ - {"keys": "image", "grid": (1, 2)}, - [{"image": A1}, {"image": A2}], - [[A11, A12], [A21, A22]], -] +TEST_CASE_MC_1 = [{"keys": "image", "grid": (2, 1)}, [{"image": A}, {"image": A}, {"image": A}], [[A1, A2]] * 3] +TEST_CASE_MC_2 = [{"keys": "image", "grid": (1, 2)}, [{"image": A1}, {"image": A2}], [[A11, A12], [A21, A22]]] TEST_MULTIPLE = [] for p in TEST_NDARRAYS: From 4aa70cd1f1ab3b05b0f55bddaf6cf95ee82a6600 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 12 May 2022 12:46:50 +0000 Subject: [PATCH 8/9] simplify return for grid (1, 1) Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- 
monai/transforms/spatial/array.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2d9dab53d3..6568e2c5d0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2541,12 +2541,7 @@ def __call__( input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid)) if self.grid == (1, 1) and input_size is None: - if isinstance(image, torch.Tensor): - return [image] - elif isinstance(image, np.ndarray): - return [image] - else: - raise ValueError(f"Input type [{type(image)}] is not supported.") + return [image] split_size, steps = self._get_params(image.shape[1:], input_size) patches: List[NdarrayOrTensor] From 83fdcf67c668963ef73fb41316ddaeb8e2aec184 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 12 May 2022 17:16:43 +0000 Subject: [PATCH 9/9] Update docsting for size in gridsplitd Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4caca10613..d354c89dbc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2160,7 +2160,8 @@ class GridSplitd(MapTransform): Args: keys: keys of the corresponding items to be transformed. grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) - size: a tuple or an integer that defines the output patch sizes. + size: a tuple or an integer that defines the output patch sizes, + or a dictionary that defines it separately for each key, like {"image": 3, "mask": (2, 2)}. If it's an integer, the value will be repeated for each dimension. The default is None, where the patch size will be inferred from the grid shape. 
allow_missing_keys: don't raise exception if key is missing.