diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 750b3fda20..665cbd196c 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -127,11 +127,13 @@ def _get_data(self, sample: Dict): def _transform(self, index: int): # Get a single entry of data sample: Dict = self.data[index] + # Extract patch image and associated metadata image, metadata = self._get_data(sample) + # Get the label label = self._get_label(sample) - # Create put all patch information together and apply transforms - patch = {"image": image, "label": label, "metadata": metadata} - return apply_transform(self.transform, patch) if self.transform else patch + # Apply transforms and output + output = {"image": image, "label": label, "metadata": metadata} + return apply_transform(self.transform, output) if self.transform else output diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8dee1f453e..00277ee0af 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -174,7 +174,7 @@ def get_data( # Verify location if location is None: location = (0, 0) - wsi_size = self.get_size(each_wsi, level) + wsi_size = self.get_size(each_wsi, 0) if location[0] > wsi_size[0] or location[1] > wsi_size[1]: raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}") diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 65df5d2b1b..6568e2c5d0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2535,62 +2535,62 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tup # Patch size self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) - def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: - if self.grid == (1, 1) and self.size is None: - if isinstance(image, torch.Tensor): - return torch.stack([image]) - elif isinstance(image, np.ndarray): - return np.stack([image]) # type: ignore - else: - raise ValueError(f"Input type [{type(image)}] is not supported.") + def __call__( + self, image: NdarrayOrTensor, size: Optional[Union[int, Tuple[int, int], np.ndarray]] = None + ) -> List[NdarrayOrTensor]: + input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid)) + + if self.grid == (1, 1) and input_size is None: + return [image] - size, steps = self._get_params(image.shape[1:]) - patches: NdarrayOrTensor + split_size, steps = self._get_params(image.shape[1:], input_size) + patches: List[NdarrayOrTensor] if isinstance(image, torch.Tensor): - patches = ( - image.unfold(1, size[0], steps[0]) - .unfold(2, size[1], steps[1]) + unfolded_image = ( + image.unfold(1, split_size[0], steps[0]) + .unfold(2, split_size[1], steps[1]) .flatten(1, 2) .transpose(0, 1) - .contiguous() ) + # Make a list of contiguous patches + patches = [p.contiguous() for p in unfolded_image] elif isinstance(image, np.ndarray): x_step, y_step = steps c_stride, x_stride, y_stride = image.strides n_channels = image.shape[0] - patches = as_strided( + strided_image = as_strided( image, - shape=(*self.grid, n_channels, size[0], size[1]), + shape=(*self.grid, n_channels, split_size[0], split_size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - # flatten the first two dimensions - patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) - # make it a contiguous array - patches = np.ascontiguousarray(patches) + # Flatten the first two dimensions + strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) + # Make a list of contiguous patches + patches = [np.ascontiguousarray(p) for p in strided_image] else: raise ValueError(f"Input type [{type(image)}] is not supported.") return patches - def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): + def _get_params( + self, image_size: Union[Sequence[int], np.ndarray], size: Optional[Union[Sequence[int], np.ndarray]] = None + ): """ Calculate the size and step required for splitting the image Args: The size of the input image """ - if self.size is not None: - # Set the split size to the given default size - if any(self.size[i] > image_size[i] for i in range(len(self.grid))): - raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") - split_size = self.size - else: + if size is None: # infer each sub-image size from the image size and the grid - split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) + + if any(size[i] > image_size[i] for i in range(len(self.grid))): + raise ValueError(f"The image size ({image_size})is smaller than the requested split size ({size})") steps = tuple( - (image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] for i in range(len(self.grid)) ) - return split_size, steps + return size, steps diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 47fe05700e..d354c89dbc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2160,7 +2160,8 @@ class GridSplitd(MapTransform): Args: keys: keys of the corresponding items to be transformed. grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) - size: a tuple or an integer that defines the output patch sizes. + size: a tuple or an integer that defines the output patch sizes, + or a dictionary that define it seperately for each key, like {"image": 3, "mask", (2, 2)}. If it's an integer, the value will be repeated for each dimension. The default is None, where the patch size will be inferred from the grid shape. allow_missing_keys: don't raise exception if key is missing. @@ -2174,17 +2175,23 @@ def __init__( self, keys: KeysCollection, grid: Tuple[int, int] = (2, 2), - size: Optional[Union[int, Tuple[int, int]]] = None, + size: Optional[Union[int, Tuple[int, int], Dict[Hashable, Union[int, Tuple[int, int], None]]]] = None, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) - self.splitter = GridSplit(grid=grid, size=size) + self.grid = grid + self.size = size if isinstance(size, dict) else {key: size for key in self.keys} + self.splitter = GridSplit(grid=grid) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: d = dict(data) + n_outputs = np.prod(self.grid) + output: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)] for key in self.key_iterator(d): - d[key] = self.splitter(d[key]) - return d + result = self.splitter(d[key], self.size[key]) + for i in range(n_outputs): + output[i][key] = result[i] + return output SpatialResampleD = SpatialResampleDict = SpatialResampled diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py index 6f0525029d..82734ffd93 100644 --- a/tests/test_grid_split.py +++ b/tests/test_grid_split.py @@ -26,14 +26,14 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] -TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])] -TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])] -TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])] -TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])] -TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])] -TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])] -TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])] +TEST_CASE_0 = [{"grid": (2, 2)}, A, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"grid": (2, 1)}, A, [A1, A2]] +TEST_CASE_2 = [{"grid": (1, 2)}, A1, [A11, A12]] +TEST_CASE_3 = [{"grid": (1, 2)}, A2, [A21, A22]] +TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, [A11]] +TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, [A]] +TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, [A11, A12, A21, A22]] +TEST_CASE_7 = [{"grid": (1, 1)}, A, [A]] TEST_CASE_8 = [ {"grid": (2, 2), "size": 2}, torch.arange(12).reshape(1, 3, 4).to(torch.float32), @@ -52,9 +52,9 @@ TEST_SINGLE.append([p, *TEST_CASE_7]) TEST_SINGLE.append([p, *TEST_CASE_8]) -TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] -TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] -TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] +TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [[A11, A12, A21, A22], [A11, A12, A21, A22]]] +TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [[A1, A2]] * 5] +TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [[A11, A12], [A21, A22]]] TEST_MULTIPLE = [] for p in TEST_NDARRAYS: @@ -69,7 +69,8 @@ def test_split_patch_single_call(self, in_type, input_parameters, image, expecte input_image = in_type(image) splitter = GridSplit(**input_parameters) output = splitter(input_image) - assert_allclose(output, expected, type_test=False) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch, expected_patch, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): @@ -77,7 +78,8 @@ def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, ex for image, expected in zip(img_list, expected_list): input_image = in_type(image) output = splitter(input_image) - assert_allclose(output, expected, type_test=False) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py index f325a16946..086dd2691d 100644 --- a/tests/test_grid_splitd.py +++ b/tests/test_grid_splitd.py @@ -26,16 +26,16 @@ A2 = torch.cat([A21, A22], 2) A = torch.cat([A1, A2], 1) -TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] -TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])] -TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] -TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] -TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])] -TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])] -TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] -TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])] +TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, [A1, A2]] +TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, [A11, A12]] +TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, [A21, A22]] +TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": {"image": (2, 2)}}, {"image": A}, [A11]] +TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": {"image": 4}}, {"image": A}, [A]] +TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": {"image": 2}}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, [A]] TEST_CASE_8 = [ - {"keys": "image", "grid": (2, 2), "size": 2}, + {"keys": "image", "grid": (2, 2), "size": {"image": 2}}, {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), ] @@ -55,18 +55,10 @@ TEST_CASE_MC_0 = [ {"keys": "image", "grid": (2, 2)}, [{"image": A}, {"image": A}], - [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], -] -TEST_CASE_MC_1 = [ - {"keys": "image", "grid": (2, 1)}, - [{"image": A}, {"image": A}, {"image": A}], - [torch.stack([A1, A2])] * 3, -] -TEST_CASE_MC_2 = [ - {"keys": "image", "grid": (1, 2)}, - [{"image": A1}, {"image": A2}], - [torch.stack([A11, A12]), torch.stack([A21, A22])], + [[A11, A12, A21, A22], [A11, A12, A21, A22]], ] +TEST_CASE_MC_1 = [{"keys": "image", "grid": (2, 1)}, [{"image": A}, {"image": A}, {"image": A}], [[A1, A2]] * 3] +TEST_CASE_MC_2 = [{"keys": "image", "grid": (1, 2)}, [{"image": A1}, {"image": A2}], [[A11, A12], [A21, A22]]] TEST_MULTIPLE = [] for p in TEST_NDARRAYS: @@ -82,8 +74,9 @@ def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expe for k, v in img_dict.items(): input_dict[k] = in_type(v) splitter = GridSplitd(**input_parameters) - output = splitter(input_dict)[input_parameters["keys"]] - assert_allclose(output, expected, type_test=False) + output = splitter(input_dict) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): @@ -92,8 +85,9 @@ def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, ex input_dict = {} for k, v in img_dict.items(): input_dict[k] = in_type(v) - output = splitter(input_dict)[input_parameters["keys"]] - assert_allclose(output, expected, type_test=False) + output = splitter(input_dict) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_patch_wsi_dataset_new.py b/tests/test_patch_wsi_dataset_new.py index 0be30536de..d128d45262 100644 --- a/tests/test_patch_wsi_dataset_new.py +++ b/tests/test_patch_wsi_dataset_new.py @@ -158,7 +158,7 @@ def setUpClass(cls): cls.backend = "cucim" -@skipUnless(has_osl, "Requires cucim") +@skipUnless(has_osl, "Requires openslide") class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests): @classmethod def setUpClass(cls):