diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 6c2b15ef5b..5a7d81ad8e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -312,7 +312,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): except AttributeError: return NotImplemented - def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: + @staticmethod + def get_default_affine(dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=torch.device("cpu"), dtype=dtype) def as_tensor(self) -> torch.Tensor: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f28d029f29..30c4f246cc 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -63,7 +63,7 @@ pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends +from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -3142,6 +3142,9 @@ class GridPatch(Transform): pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + Returns: + MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -3225,21 +3228,24 @@ def __call__(self, array: NdarrayOrTensor): elif self.threshold: patched_image, locations = self.filter_threshold(patched_image, locations) - # Convert to original data type - output = list( - zip( - convert_to_dst_type(src=patched_image, dst=array)[0], - convert_to_dst_type(src=locations, dst=array, dtype=int)[0], - ) - ) - # Pad the patch list to have the requested number of patches - if self.num_patches and len(output) < self.num_patches: - patch = convert_to_dst_type( - src=np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0)), dst=array - )[0] - start_location = convert_to_dst_type(src=np.zeros(len(self.patch_size)), dst=array)[0] - output += [(patch, start_location)] * (self.num_patches - len(output)) + if self.num_patches: + padding = self.num_patches - len(patched_image) + if padding > 0: + patched_image = np.pad( + patched_image, + [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), + constant_values=self.pad_kwargs.get("constant_values", 0), + ) + locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) + + # Convert to MetaTensor + metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() + metadata[WSIPatchKeys.LOCATION] = locations.T + metadata[WSIPatchKeys.COUNT] = len(locations) + metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T + output = MetaTensor(x=patched_image, meta=metadata) + output.is_batch = True return output @@ -3265,6 +3271,9 @@ class RandGridPatch(GridPatch, RandomizableTransform): pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + Returns: + MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 720be615fa..bb6b0f1714 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,7 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from copy import deepcopy from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -58,12 +57,10 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, - WSIPatchKeys, convert_to_tensor, ensure_tuple, ensure_tuple_rep, fall_back_tuple, - first, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PytorchPadMode, TraceKeys @@ -1851,25 +1848,11 @@ def __init__( **pad_kwargs, ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - original_spatial_shape = d[first(self.keys)].shape[1:] - output = [] - results = [self.patcher(d[key]) for key in self.keys] - num_patches = min(len(r) for r in results) - for patch in zip(*results): - new_dict = {k: v[0] for k, v in zip(self.keys, patch)} - # fill in the extra keys with unmodified data - for k in set(d.keys()).difference(set(self.keys)): - new_dict[k] = deepcopy(d[k]) - # fill additional metadata - new_dict["original_spatial_shape"] = original_spatial_shape - new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item - new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size - new_dict[WSIPatchKeys.COUNT] = num_patches - new_dict["offset"] = self.patcher.offset - output.append(new_dict) - return output + for key in self.key_iterator(d): + d[key] = self.patcher(d[key]) + return d class RandGridPatchd(RandomizableTransform, MapTransform): @@ -1942,31 +1925,15 @@ def set_random_state( self.patcher.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - original_spatial_shape = d[first(self.keys)].shape[1:] - # all the keys share the same random noise - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return [d] - self.patcher.randomize(d[first_key]) # type: ignore - results = [self.patcher(d[key], randomize=False) for key in self.keys] - - num_patches = min(len(r) for r in results) - output = [] - for patch in zip(*results): - new_dict = {k: v[0] for k, v in zip(self.keys, patch)} - # fill in the extra keys with unmodified data - for k in set(d.keys()).difference(set(self.keys)): - new_dict[k] = deepcopy(d[k]) - # fill additional metadata - new_dict["original_spatial_shape"] = original_spatial_shape - new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item - new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size - new_dict[WSIPatchKeys.COUNT] = num_patches - new_dict["offset"] = self.patcher.offset - output.append(new_dict) - return output + # All the keys share the same random noise + for key in self.key_iterator(d): + self.patcher.randomize(d[key]) + break + for key in self.key_iterator(d): + d[key] = self.patcher(d[key], randomize=False) + return d SpatialResampleD = SpatialResampleDict = SpatialResampled diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index cde6bd8cc2..2039b3a5a5 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -384,6 +384,7 @@ def __init__( dim: int = 0, keepdim: bool = True, update_meta: bool = True, + list_output: bool = False, allow_missing_keys: bool = False, ) -> None: """ @@ -399,15 +400,34 @@ def __init__( dimension will be squeezed. update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to reflect the cropped image + list_output: it `True`, the output will be a list of dictionaries with the same keys as original. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitDim(dim, keepdim, update_meta) + self.list_output = list_output + if self.list_output is None and self.output_postfixes is not None: + raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor] + ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]: d = dict(data) - for key in self.key_iterator(d): + all_keys = list(set(self.key_iterator(d))) + + if self.list_output: + output = [] + results = [self.splitter(d[key]) for key in all_keys] + for row in zip(*results): + new_dict = {k: v for k, v in zip(all_keys, row)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(all_keys)): + new_dict[k] = deepcopy(d[k]) + output.append(new_dict) + return output + + for key in all_keys: rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): diff --git a/tests/min_tests.py b/tests/min_tests.py index 0f1e4e61ec..4371f3ad33 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -60,6 +60,7 @@ def run_testsuit(): "test_foreground_mask", "test_foreground_maskd", "test_global_mutual_information_loss", + "test_grid_patch", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", @@ -135,6 +136,7 @@ def run_testsuit(): "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", + "test_rand_grid_patch", "test_rand_rotate", "test_rand_rotated", "test_rand_zoom", diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 8a105afcd2..03b33147dd 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -14,8 +14,9 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import GridPatch -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -46,34 +47,69 @@ ] TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]] +TEST_CASE_MEAT_0 = [ + {"patch_size": (2, 2)}, + A, + [A11, A12, A21, A22], + [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], +] + +TEST_CASE_MEAT_1 = [ + {"patch_size": (2, 2)}, + MetaTensor(x=A, meta={"path": "path/to/file"}), + [A11, A12, A21, A22], + [ + {"location": [0, 0], "path": "path/to/file"}, + {"location": [0, 2], "path": "path/to/file"}, + {"location": [2, 0], "path": "path/to/file"}, + {"location": [2, 2], "path": "path/to/file"}, + ], +] -TEST_SINGLE = [] +TEST_CASES = [] for p in TEST_NDARRAYS: - TEST_SINGLE.append([p, *TEST_CASE_0]) - TEST_SINGLE.append([p, *TEST_CASE_1]) - TEST_SINGLE.append([p, *TEST_CASE_2]) - TEST_SINGLE.append([p, *TEST_CASE_3]) - TEST_SINGLE.append([p, *TEST_CASE_4]) - TEST_SINGLE.append([p, *TEST_CASE_5]) - TEST_SINGLE.append([p, *TEST_CASE_6]) - TEST_SINGLE.append([p, *TEST_CASE_7]) - TEST_SINGLE.append([p, *TEST_CASE_8]) - TEST_SINGLE.append([p, *TEST_CASE_9]) - TEST_SINGLE.append([p, *TEST_CASE_10]) - TEST_SINGLE.append([p, *TEST_CASE_11]) - TEST_SINGLE.append([p, *TEST_CASE_12]) - TEST_SINGLE.append([p, *TEST_CASE_13]) + TEST_CASES.append([p, *TEST_CASE_0]) + TEST_CASES.append([p, *TEST_CASE_1]) + TEST_CASES.append([p, *TEST_CASE_2]) + TEST_CASES.append([p, *TEST_CASE_3]) + TEST_CASES.append([p, *TEST_CASE_4]) + TEST_CASES.append([p, *TEST_CASE_5]) + TEST_CASES.append([p, *TEST_CASE_6]) + TEST_CASES.append([p, *TEST_CASE_7]) + TEST_CASES.append([p, *TEST_CASE_8]) + TEST_CASES.append([p, *TEST_CASE_9]) + TEST_CASES.append([p, *TEST_CASE_10]) + TEST_CASES.append([p, *TEST_CASE_11]) + TEST_CASES.append([p, *TEST_CASE_12]) + TEST_CASES.append([p, *TEST_CASE_13]) class TestGridPatch(unittest.TestCase): - @parameterized.expand(TEST_SINGLE) + @parameterized.expand(TEST_CASES) def test_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = GridPatch(**input_parameters) - output = list(splitter(input_image)) + output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[0], expected_patch, type_test=False) + assert_allclose(output_patch, expected_patch, type_test=False) + + @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @SkipIfBeforePyTorchVersion((1, 9, 1)) + def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + set_track_meta(True) + splitter = GridPatch(**input_parameters) + output = splitter(image) + self.assertEqual(len(output), len(expected)) + if "path" in expected_meta[0]: + self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) + for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(isinstance(output_patch, MetaTensor)) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) + if "path" in expected_meta[0]: + self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) if __name__ == "__main__": diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 8f1e238b42..0f1bea5f8a 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -74,10 +74,10 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): if k == image_key: input_dict[k] = in_type(v) splitter = GridPatchd(keys=image_key, **input_parameters) - output = list(splitter(input_dict)) - self.assertEqual(len(output), len(expected)) - for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[image_key], expected_patch, type_test=False) + output = splitter(input_dict) + self.assertEqual(len(output[image_key]), len(expected)) + for output_patch, expected_patch in zip(output[image_key], expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 3957dc1ce8..417915fbab 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -14,9 +14,10 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import RandGridPatch from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose set_determinism(1234) @@ -57,6 +58,25 @@ ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]] +TEST_CASE_MEAT_0 = [ + {"patch_size": (2, 2)}, + A, + [A11, A12, A21, A22], + [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], +] + +TEST_CASE_MEAT_1 = [ + {"patch_size": (2, 2)}, + MetaTensor(x=A, meta={"path": "path/to/file"}), + [A11, A12, A21, A22], + [ + {"location": [0, 0], "path": "path/to/file"}, + {"location": [0, 2], "path": "path/to/file"}, + {"location": [2, 0], "path": "path/to/file"}, + {"location": [2, 2], "path": "path/to/file"}, + ], +] + TEST_SINGLE = [] for p in TEST_NDARRAYS: TEST_SINGLE.append([p, *TEST_CASE_0]) @@ -78,10 +98,28 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = RandGridPatch(**input_parameters) splitter.set_random_state(1234) - output = list(splitter(input_image)) + output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[0], expected_patch, type_test=False) + assert_allclose(output_patch, expected_patch, type_test=False) + + @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @SkipIfBeforePyTorchVersion((1, 9, 1)) + def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + set_track_meta(True) + splitter = RandGridPatch(**input_parameters) + splitter.set_random_state(1234) + output = splitter(image) + self.assertEqual(len(output), len(expected)) + if "path" in expected_meta[0]: + self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) + for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(isinstance(output_patch, MetaTensor)) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) + if "path" in expected_meta[0]: + self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 656fbd9e36..4f3ec3bb6a 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -83,10 +83,10 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected) input_dict[k] = in_type(v) splitter = RandGridPatchd(keys=image_key, **input_parameters) splitter.set_random_state(1234) - output = list(splitter(input_dict)) - self.assertEqual(len(output), len(expected)) - for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[image_key], expected_patch, type_test=False) + output = splitter(input_dict) + self.assertEqual(len(output[image_key]), len(expected)) + for output_patch, expected_patch in zip(output[image_key], expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 1e39439b86..ee8cc043e4 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -25,7 +25,8 @@ for p in TEST_NDARRAYS: for keepdim in (True, False): for update_meta in (True, False): - TESTS.append((keepdim, p, update_meta)) + for list_output in (True, False): + TESTS.append((keepdim, p, update_meta, list_output)) class TestSplitDimd(unittest.TestCase): @@ -39,14 +40,18 @@ def setUpClass(cls): cls.data: MetaTensor = loader(data) @parameterized.expand(TESTS) - def test_correct(self, keepdim, im_type, update_meta): + def test_correct(self, keepdim, im_type, update_meta, list_output): data = deepcopy(self.data) data["i"] = im_type(data["i"]) arr = data["i"] for dim in range(arr.ndim): - out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data) - self.assertIsInstance(out, dict) - self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) + out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data) + if list_output: + self.assertIsInstance(out, list) + self.assertEqual(len(out), arr.shape[dim]) + else: + self.assertIsInstance(out, dict) + self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) # if updating metadata, pick some random points and # check same world coordinates between input and output if update_meta: @@ -55,14 +60,20 @@ def test_correct(self, keepdim, im_type, update_meta): split_im_idx = idx[dim] split_idx = deepcopy(idx) split_idx[dim] = 0 - split_im = out[f"i_{split_im_idx}"] + if list_output: + split_im = out[split_im_idx]["i"] + else: + split_im = out[f"i_{split_im_idx}"] if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor): # idx[1:] to remove channel and then add 1 for 4th element real_world = data.affine @ torch.tensor(idx[1:] + [1]).double() real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double() assert_allclose(real_world, real_world2) - out = out["i_0"] + if list_output: + out = out[0]["i"] + else: + out = out["i_0"] expected_ndim = arr.ndim if keepdim else arr.ndim - 1 self.assertEqual(out.ndim, expected_ndim) # assert is a shallow copy