From 44c90058684706fe55078fbc7e7bc849609375f0 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 17:41:18 +0000 Subject: [PATCH 01/19] MetaTensor output for GridPatch (and RandGridPatch) Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f28d029f29..7b68c1ecf1 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3225,21 +3225,22 @@ def __call__(self, array: NdarrayOrTensor): elif self.threshold: patched_image, locations = self.filter_threshold(patched_image, locations) - # Convert to original data type - output = list( - zip( - convert_to_dst_type(src=patched_image, dst=array)[0], - convert_to_dst_type(src=locations, dst=array, dtype=int)[0], - ) - ) - # Pad the patch list to have the requested number of patches - if self.num_patches and len(output) < self.num_patches: - patch = convert_to_dst_type( - src=np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0)), dst=array - )[0] - start_location = convert_to_dst_type(src=np.zeros(len(self.patch_size)), dst=array)[0] - output += [(patch, start_location)] * (self.num_patches - len(output)) + if self.num_patches: + padding = self.num_patches - len(patched_image) + if padding > 0: + patched_image = np.pad( + patched_image, + [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), + constant_values=self.pad_kwargs.get("constant_values", 0), + ) + locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) + + # Convert to MetaTensor + metadata = array.meta if isinstance(array, MetaTensor) else {} + metadata["location"] = locations.T + output = MetaTensor(x=patched_image, meta=metadata) + output.is_batch = True return output From 612f5e8b0b1a3a393b7a2558727ec535aa187ee4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 17:41:39 +0000 Subject: [PATCH 02/19] Add unittests for grid patch MetaTensor Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patch.py | 68 +++++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 8a105afcd2..5eaa169403 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -14,6 +14,7 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms.spatial.array import GridPatch from tests.utils import TEST_NDARRAYS, assert_allclose @@ -46,34 +47,65 @@ ] TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]] +TEST_CASE_MEAT_0 = [ + {"patch_size": (2, 2)}, + A, + [A11, A12, A21, A22], + [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], +] + +TEST_CASE_MEAT_1 = [ + {"patch_size": (2, 2)}, + MetaTensor(x=A, meta={"path": "path/to/file"}), + [A11, A12, A21, A22], + [ + {"location": [0, 0], "path": "path/to/file"}, + {"location": [0, 2], "path": "path/to/file"}, + {"location": [2, 0], "path": "path/to/file"}, + {"location": [2, 2], "path": "path/to/file"}, + ], +] -TEST_SINGLE = [] +TEST_CASES = [] for p in TEST_NDARRAYS: - TEST_SINGLE.append([p, *TEST_CASE_0]) - TEST_SINGLE.append([p, *TEST_CASE_1]) - TEST_SINGLE.append([p, *TEST_CASE_2]) - TEST_SINGLE.append([p, *TEST_CASE_3]) - TEST_SINGLE.append([p, *TEST_CASE_4]) - TEST_SINGLE.append([p, *TEST_CASE_5]) - TEST_SINGLE.append([p, *TEST_CASE_6]) - TEST_SINGLE.append([p, *TEST_CASE_7]) - TEST_SINGLE.append([p, *TEST_CASE_8]) - TEST_SINGLE.append([p, *TEST_CASE_9]) - TEST_SINGLE.append([p, *TEST_CASE_10]) - TEST_SINGLE.append([p, *TEST_CASE_11]) - TEST_SINGLE.append([p, *TEST_CASE_12]) - TEST_SINGLE.append([p, *TEST_CASE_13]) + TEST_CASES.append([p, *TEST_CASE_0]) + TEST_CASES.append([p, *TEST_CASE_1]) + TEST_CASES.append([p, *TEST_CASE_2]) + TEST_CASES.append([p, *TEST_CASE_3]) + TEST_CASES.append([p, *TEST_CASE_4]) + TEST_CASES.append([p, *TEST_CASE_5]) + TEST_CASES.append([p, *TEST_CASE_6]) + TEST_CASES.append([p, *TEST_CASE_7]) + TEST_CASES.append([p, *TEST_CASE_8]) + TEST_CASES.append([p, *TEST_CASE_9]) + TEST_CASES.append([p, *TEST_CASE_10]) + TEST_CASES.append([p, *TEST_CASE_11]) + TEST_CASES.append([p, *TEST_CASE_12]) + TEST_CASES.append([p, *TEST_CASE_13]) class TestGridPatch(unittest.TestCase): - @parameterized.expand(TEST_SINGLE) + @parameterized.expand(TEST_CASES) def test_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = GridPatch(**input_parameters) - output = list(splitter(input_image)) + output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[0], expected_patch, type_test=False) + assert_allclose(output_patch, expected_patch, type_test=False) + + @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + splitter = GridPatch(**input_parameters) + output = splitter(image) + self.assertEqual(len(output), len(expected)) + if "path" in expected_meta[0]: + self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) + for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + assert_allclose(output_patch, expected_patch, type_test=False) + if "path" in expected_meta[0]: + self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if __name__ == "__main__": From c576a41d4c166f2202def20475aed985ad01cda9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 17:41:52 +0000 Subject: [PATCH 03/19] Add unittests for rand grid patch MetaTensor Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_rand_grid_patch.py | 38 +++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 3957dc1ce8..4ae2dbdd34 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -14,6 +14,7 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms.spatial.array import RandGridPatch from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose @@ -57,6 +58,25 @@ ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]] +TEST_CASE_MEAT_0 = [ + {"patch_size": (2, 2)}, + A, + [A11, A12, A21, A22], + [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], +] + +TEST_CASE_MEAT_1 = [ + {"patch_size": (2, 2)}, + MetaTensor(x=A, meta={"path": "path/to/file"}), + [A11, A12, A21, A22], + [ + {"location": [0, 0], "path": "path/to/file"}, + {"location": [0, 2], "path": "path/to/file"}, + {"location": [2, 0], "path": "path/to/file"}, + {"location": [2, 2], "path": "path/to/file"}, + ], +] + TEST_SINGLE = [] for p in TEST_NDARRAYS: TEST_SINGLE.append([p, *TEST_CASE_0]) @@ -78,10 +98,24 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = RandGridPatch(**input_parameters) splitter.set_random_state(1234) - output = list(splitter(input_image)) + output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[0], expected_patch, type_test=False) + assert_allclose(output_patch, expected_patch, type_test=False) + + @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + splitter = RandGridPatch(**input_parameters) + splitter.set_random_state(1234) + output = splitter(image) + self.assertEqual(len(output), len(expected)) + if "path" in expected_meta[0]: + self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) + for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + assert_allclose(output_patch, expected_patch, type_test=False) + if "path" in expected_meta[0]: + self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if __name__ == "__main__": From 35f47ebc50c71fbdc1318fd631ab131dbffc0430 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 17:48:37 +0000 Subject: [PATCH 04/19] Update returns docstring Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7b68c1ecf1..ca806dfb4f 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3142,6 +3142,9 @@ class GridPatch(Transform): pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + Returns: + MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -3266,6 +3269,9 @@ class RandGridPatch(GridPatch, RandomizableTransform): pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + Returns: + MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] From 907c75a14528cc7745e871775621243eda7e2a86 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 18:09:33 +0000 Subject: [PATCH 05/19] MetaTensor output for GridPatchd (and RandGridPatchd) Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 6 ++-- monai/transforms/spatial/dictionary.py | 50 ++++++-------------------- 2 files changed, 14 insertions(+), 42 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ca806dfb4f..0881d31431 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -63,7 +63,7 @@ pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends +from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -3241,7 +3241,9 @@ def __call__(self, array: NdarrayOrTensor): # Convert to MetaTensor metadata = array.meta if isinstance(array, MetaTensor) else {} - metadata["location"] = locations.T + metadata[WSIPatchKeys.LOCATION] = locations.T + metadata[WSIPatchKeys.COUNT] = len(locations) + metadata[WSIPatchKeys.SIZE] = self.patch_size output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 720be615fa..eb1aa1a0b3 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1853,23 +1853,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: d = dict(data) - original_spatial_shape = d[first(self.keys)].shape[1:] - output = [] - results = [self.patcher(d[key]) for key in self.keys] - num_patches = min(len(r) for r in results) - for patch in zip(*results): - new_dict = {k: v[0] for k, v in zip(self.keys, patch)} - # fill in the extra keys with unmodified data - for k in set(d.keys()).difference(set(self.keys)): - new_dict[k] = deepcopy(d[k]) - # fill additional metadata - new_dict["original_spatial_shape"] = original_spatial_shape - new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item - new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size - new_dict[WSIPatchKeys.COUNT] = num_patches - new_dict["offset"] = self.patcher.offset - output.append(new_dict) - return output + for key in self.key_iterator(d): + d[key] = self.patcher(d[key]) + return d class RandGridPatchd(RandomizableTransform, MapTransform): @@ -1944,29 +1930,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: d = dict(data) - original_spatial_shape = d[first(self.keys)].shape[1:] - # all the keys share the same random noise - first_key: Union[Hashable, List] = self.first_key(d) - if first_key == []: - return [d] - self.patcher.randomize(d[first_key]) # type: ignore - results = [self.patcher(d[key], randomize=False) for key in self.keys] - - num_patches = min(len(r) for r in results) - output = [] - for patch in zip(*results): - new_dict = {k: v[0] for k, v in zip(self.keys, patch)} - # fill in the extra keys with unmodified data - for k in set(d.keys()).difference(set(self.keys)): - new_dict[k] = deepcopy(d[k]) - # fill additional metadata - new_dict["original_spatial_shape"] = original_spatial_shape - new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item - new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size - new_dict[WSIPatchKeys.COUNT] = num_patches - new_dict["offset"] = self.patcher.offset - output.append(new_dict) - return output + # All the keys share the same random noise + for key in self.key_iterator(d): + self.patcher.randomize(d[key]) + break + for key in self.key_iterator(d): + d[key] = self.patcher(d[key], randomize=False) + return d SpatialResampleD = SpatialResampleDict = SpatialResampled From e353964e3dc59fff399e0b60b4215e766141b935 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 18:10:18 +0000 Subject: [PATCH 06/19] Update unittests for grid patch dict Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patchd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 8f1e238b42..0f1bea5f8a 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -74,10 +74,10 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): if k == image_key: input_dict[k] = in_type(v) splitter = GridPatchd(keys=image_key, **input_parameters) - output = list(splitter(input_dict)) - self.assertEqual(len(output), len(expected)) - for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[image_key], expected_patch, type_test=False) + output = splitter(input_dict) + self.assertEqual(len(output[image_key]), len(expected)) + for output_patch, expected_patch in zip(output[image_key], expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": From c98edc5d562242ad3fc4fca8d736ca32b27042dc Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 6 Sep 2022 18:10:27 +0000 Subject: [PATCH 07/19] Update unittests for rand grid patch dict Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_rand_grid_patchd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 656fbd9e36..4f3ec3bb6a 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -83,10 +83,10 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected) input_dict[k] = in_type(v) splitter = RandGridPatchd(keys=image_key, **input_parameters) splitter.set_random_state(1234) - output = list(splitter(input_dict)) - self.assertEqual(len(output), len(expected)) - for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch[image_key], expected_patch, type_test=False) + output = splitter(input_dict) + self.assertEqual(len(output[image_key]), len(expected)) + for output_patch, expected_patch in zip(output[image_key], expected): + assert_allclose(output_patch, expected_patch, type_test=False) if __name__ == "__main__": From 51286e87f428293155157ae025b3329b3bcd64c7 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 7 Sep 2022 14:57:36 +0000 Subject: [PATCH 08/19] Make get_default_affine staticmethod Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/meta_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 6c2b15ef5b..5a7d81ad8e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -312,7 +312,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): except AttributeError: return NotImplemented - def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: + @staticmethod + def get_default_affine(dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=torch.device("cpu"), dtype=dtype) def as_tensor(self) -> torch.Tensor: From 000b02a59c4a058e3e6705004a0ff1cf4cb2d9f4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 7 Sep 2022 15:04:08 +0000 Subject: [PATCH 09/19] Update affine metadata and return types Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 4 ++-- monai/transforms/spatial/dictionary.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0881d31431..00a244effe 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3240,10 +3240,10 @@ def __call__(self, array: NdarrayOrTensor): locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) # Convert to MetaTensor - metadata = array.meta if isinstance(array, MetaTensor) else {} + metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() metadata[WSIPatchKeys.LOCATION] = locations.T metadata[WSIPatchKeys.COUNT] = len(locations) - metadata[WSIPatchKeys.SIZE] = self.patch_size + metadata["affine"] = torch.stack([MetaTensor.get_default_affine()] * len(locations)) output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index eb1aa1a0b3..13e10b4233 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1851,7 +1851,7 @@ def __init__( **pad_kwargs, ) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.patcher(d[key]) @@ -1928,7 +1928,7 @@ def set_random_state( self.patcher.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) # All the keys share the same random noise for key in self.key_iterator(d): From f26c8fab5c492c0743f394c28ba52a22ca98d633 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:00:42 +0000 Subject: [PATCH 10/19] Not explicitly set affine Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 00a244effe..ec234b2fa6 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3243,7 +3243,6 @@ def __call__(self, array: NdarrayOrTensor): metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() metadata[WSIPatchKeys.LOCATION] = locations.T metadata[WSIPatchKeys.COUNT] = len(locations) - metadata["affine"] = torch.stack([MetaTensor.get_default_affine()] * len(locations)) output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True From 9915d105f3fb4f79fbd0c1fabc67fef7e53ea679 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:07:49 +0000 Subject: [PATCH 11/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/spatial/dictionary.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 13e10b4233..bb6b0f1714 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,7 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from copy import deepcopy from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -58,12 +57,10 @@ GridSamplePadMode, InterpolateMode, NumpyPadMode, - WSIPatchKeys, convert_to_tensor, ensure_tuple, ensure_tuple_rep, fall_back_tuple, - first, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PytorchPadMode, TraceKeys From 824da296c253c7cc3d9ba3ab5beac08a236aa643 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:45:09 +0000 Subject: [PATCH 12/19] set_track_meta in tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patch.py | 3 ++- tests/test_rand_grid_patch.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 5eaa169403..48f75c6e1f 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.data import MetaTensor +from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import GridPatch from tests.utils import TEST_NDARRAYS, assert_allclose @@ -96,6 +96,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + set_track_meta(True) splitter = GridPatch(**input_parameters) output = splitter(image) self.assertEqual(len(output), len(expected)) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 4ae2dbdd34..e99718ea84 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.data import MetaTensor +from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import RandGridPatch from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose @@ -105,6 +105,7 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): + set_track_meta(True) splitter = RandGridPatch(**input_parameters) splitter.set_random_state(1234) output = splitter(image) From 3cc4e917145d9d0803d25dcfc138d0619069054e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 20:07:13 +0000 Subject: [PATCH 13/19] Update tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patch.py | 3 ++- tests/test_rand_grid_patch.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 48f75c6e1f..d4ee4723f8 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -103,10 +103,11 @@ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta) if "path" in expected_meta[0]: self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + self.assertTrue(isinstance(output_patch, MetaTensor)) assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if "path" in expected_meta[0]: self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) - self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index e99718ea84..e2e2c20604 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -113,10 +113,11 @@ def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_ if "path" in expected_meta[0]: self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): + self.assertTrue(isinstance(output_patch, MetaTensor)) assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if "path" in expected_meta[0]: self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) - self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) if __name__ == "__main__": From 1319909234cc2d2b71683295d54c5c6ec490d80f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 20:11:21 +0000 Subject: [PATCH 14/19] Add to min tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/min_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 0f1e4e61ec..4371f3ad33 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -60,6 +60,7 @@ def run_testsuit(): "test_foreground_mask", "test_foreground_maskd", "test_global_mutual_information_loss", + "test_grid_patch", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", @@ -135,6 +136,7 @@ def run_testsuit(): "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", + "test_rand_grid_patch", "test_rand_rotate", "test_rand_rotated", "test_rand_zoom", From 9d9cc000c7e12be27f4ecfa5149cd5e910b60ec5 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 13:47:25 +0000 Subject: [PATCH 15/19] Update spatial shape Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 1 + tests/test_grid_patch.py | 3 ++- tests/test_rand_grid_patch.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ec234b2fa6..30c4f246cc 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3243,6 +3243,7 @@ def __call__(self, array: NdarrayOrTensor): metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() metadata[WSIPatchKeys.LOCATION] = locations.T metadata[WSIPatchKeys.COUNT] = len(locations) + metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index d4ee4723f8..95caf97fc6 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -103,9 +103,10 @@ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta) if "path" in expected_meta[0]: self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): - self.assertTrue(isinstance(output_patch, MetaTensor)) assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(isinstance(output_patch, MetaTensor)) self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) if "path" in expected_meta[0]: self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index e2e2c20604..17a4894dce 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -113,9 +113,10 @@ def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_ if "path" in expected_meta[0]: self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): - self.assertTrue(isinstance(output_patch, MetaTensor)) assert_allclose(output_patch, expected_patch, type_test=False) + self.assertTrue(isinstance(output_patch, MetaTensor)) self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) if "path" in expected_meta[0]: self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) From f6cb8e5e1e474eaf5f39d76b83d8df053f935b74 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 14:19:50 +0000 Subject: [PATCH 16/19] Update splitdimd with the option of list output Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index cde6bd8cc2..2039b3a5a5 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -384,6 +384,7 @@ def __init__( dim: int = 0, keepdim: bool = True, update_meta: bool = True, + list_output: bool = False, allow_missing_keys: bool = False, ) -> None: """ @@ -399,15 +400,34 @@ def __init__( dimension will be squeezed. update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to reflect the cropped image + list_output: it `True`, the output will be a list of dictionaries with the same keys as original. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitDim(dim, keepdim, update_meta) + self.list_output = list_output + if self.list_output is None and self.output_postfixes is not None: + raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor] + ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]: d = dict(data) - for key in self.key_iterator(d): + all_keys = list(set(self.key_iterator(d))) + + if self.list_output: + output = [] + results = [self.splitter(d[key]) for key in all_keys] + for row in zip(*results): + new_dict = {k: v for k, v in zip(all_keys, row)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(all_keys)): + new_dict[k] = deepcopy(d[k]) + output.append(new_dict) + return output + + for key in all_keys: rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): From 3564b38bd641ceb185f5548dd2cc80ccffb23e87 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 14:20:06 +0000 Subject: [PATCH 17/19] Upadate unittest to include list_output Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_splitdimd.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 1e39439b86..ee8cc043e4 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -25,7 +25,8 @@ for p in TEST_NDARRAYS: for keepdim in (True, False): for update_meta in (True, False): - TESTS.append((keepdim, p, update_meta)) + for list_output in (True, False): + TESTS.append((keepdim, p, update_meta, list_output)) class TestSplitDimd(unittest.TestCase): @@ -39,14 +40,18 @@ def setUpClass(cls): cls.data: MetaTensor = loader(data) @parameterized.expand(TESTS) - def test_correct(self, keepdim, im_type, update_meta): + def test_correct(self, keepdim, im_type, update_meta, list_output): data = deepcopy(self.data) data["i"] = im_type(data["i"]) arr = data["i"] for dim in range(arr.ndim): - out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data) - self.assertIsInstance(out, dict) - self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) + out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data) + if list_output: + self.assertIsInstance(out, list) + self.assertEqual(len(out), arr.shape[dim]) + else: + self.assertIsInstance(out, dict) + self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim]) # if updating metadata, pick some random points and # check same world coordinates between input and output if update_meta: @@ -55,14 +60,20 @@ def test_correct(self, keepdim, im_type, update_meta): split_im_idx = idx[dim] split_idx = deepcopy(idx) split_idx[dim] = 0 - split_im = out[f"i_{split_im_idx}"] + if list_output: + split_im = out[split_im_idx]["i"] + else: + split_im = out[f"i_{split_im_idx}"] if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor): # idx[1:] to remove channel and then add 1 for 4th element real_world = data.affine @ torch.tensor(idx[1:] + [1]).double() real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double() assert_allclose(real_world, real_world2) - out = out["i_0"] + if list_output: + out = out[0]["i"] + else: + out = out["i_0"] expected_ndim = arr.ndim if keepdim else arr.ndim - 1 self.assertEqual(out.ndim, expected_ndim) # assert is a shallow copy From 0c07eddf8322c78cd7dd3c7ac7886a9ac9382526 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 16:01:02 +0000 Subject: [PATCH 18/19] Skip for older torch versions Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patch.py | 3 ++- tests/test_rand_grid_patch.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 95caf97fc6..90a312a93f 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -16,7 +16,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import GridPatch -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -95,6 +95,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): assert_allclose(output_patch, expected_patch, type_test=False) @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) splitter = GridPatch(**input_parameters) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 17a4894dce..3a614b320d 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -17,7 +17,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import RandGridPatch from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion set_determinism(1234) @@ -104,6 +104,7 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): assert_allclose(output_patch, expected_patch, type_test=False) @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) splitter = RandGridPatch(**input_parameters) From 14870c41afacd33ad13f2bf007cfdf79811e86f7 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 9 Sep 2022 16:18:57 +0000 Subject: [PATCH 19/19] formatting Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_grid_patch.py | 2 +- tests/test_rand_grid_patch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 90a312a93f..03b33147dd 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -16,7 +16,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import GridPatch -from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 3a614b320d..417915fbab 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -17,7 +17,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms.spatial.array import RandGridPatch from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose set_determinism(1234)