Merged

24 commits
44c9005
MetaTensor output for GridPatch (and RandGridPatch)
bhashemian Sep 6, 2022
612f5e8
Add unittests for grid patch MetaTensor
bhashemian Sep 6, 2022
c576a41
Add unittests for rand grid patch MetaTensor
bhashemian Sep 6, 2022
35f47eb
Update returns docstring
bhashemian Sep 6, 2022
907c75a
MetaTensor output for GridPatchd (and RandGridPatchd)
bhashemian Sep 6, 2022
e353964
Update unittests for grid patch dict
bhashemian Sep 6, 2022
c98edc5
Update unittests for rand grid patch dict
bhashemian Sep 6, 2022
51286e8
Make get_default_affine staticmethod
bhashemian Sep 7, 2022
000b02a
Update affine metadata and return types
bhashemian Sep 7, 2022
d6abd5d
Merge branch 'dev' of github.com:Project-MONAI/MONAI into metatensor-…
bhashemian Sep 7, 2022
7f6f648
Merge branch 'dev' of github.com:Project-MONAI/MONAI into metatensor-…
bhashemian Sep 8, 2022
a30ef5e
Merge branch 'dev' of github.com:Project-MONAI/MONAI into metatensor-…
bhashemian Sep 8, 2022
f26c8fa
Not explicitly set affine
bhashemian Sep 8, 2022
9915d10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2022
824da29
set_track_meta in tests
bhashemian Sep 8, 2022
3cc4e91
Update tests
bhashemian Sep 8, 2022
1319909
Add to min tests
bhashemian Sep 8, 2022
9d9cc00
Update spatial shape
bhashemian Sep 9, 2022
f6cb8e5
Update splitdimd with the option of list output
bhashemian Sep 9, 2022
3564b38
Upadate unittest to include list_output
bhashemian Sep 9, 2022
8dce85c
Merge branch 'dev' into metatensor-gridpatch
bhashemian Sep 9, 2022
0c07edd
Skip for older torch versions
bhashemian Sep 9, 2022
8a7dd77
Merge branch 'listoutput-splitdimd' of github.com:drbeh/MONAI into me…
bhashemian Sep 9, 2022
14870c4
formatting
bhashemian Sep 9, 2022
monai/data/meta_tensor.py (3 changes: 2 additions & 1 deletion)
@@ -312,7 +312,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
except AttributeError:
return NotImplemented

def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
@staticmethod
def get_default_affine(dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=torch.device("cpu"), dtype=dtype)

def as_tensor(self) -> torch.Tensor:
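As a quick illustration of the hunk above (a minimal sketch, assuming a MONAI build that includes this change), `get_default_affine` can now be called on the class itself, with no `MetaTensor` instance needed:

```python
import torch

from monai.data import MetaTensor

# No instance needed any more: the default affine is the 4x4 identity on CPU.
default_affine = MetaTensor.get_default_affine()            # torch.float64 by default
single_precision = MetaTensor.get_default_affine(torch.float32)
print(default_affine.shape, single_precision.dtype)          # torch.Size([4, 4]) torch.float32
```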
monai/transforms/spatial/array.py (39 changes: 24 additions & 15 deletions)
@@ -63,7 +63,7 @@
pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
@@ -3142,6 +3142,9 @@ class GridPatch(Transform):
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
@@ -3225,21 +3228,24 @@ def __call__(self, array: NdarrayOrTensor):
elif self.threshold:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Convert to original data type
output = list(
zip(
convert_to_dst_type(src=patched_image, dst=array)[0],
convert_to_dst_type(src=locations, dst=array, dtype=int)[0],
)
)

# Pad the patch list to have the requested number of patches
if self.num_patches and len(output) < self.num_patches:
patch = convert_to_dst_type(
src=np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0)), dst=array
)[0]
start_location = convert_to_dst_type(src=np.zeros(len(self.patch_size)), dst=array)[0]
output += [(patch, start_location)] * (self.num_patches - len(output))
if self.num_patches:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)

# Convert to MetaTensor
metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
metadata[WSIPatchKeys.LOCATION] = locations.T
metadata[WSIPatchKeys.COUNT] = len(locations)
metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T
output = MetaTensor(x=patched_image, meta=metadata)
output.is_batch = True

return output

@@ -3265,6 +3271,9 @@ class RandGridPatch(GridPatch, RandomizableTransform):
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
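For context on the `GridPatch` changes above, a minimal usage sketch (assuming a build with this diff applied; the metadata key strings `"location"` and `"count"` are assumed to be the values of `WSIPatchKeys.LOCATION` and `WSIPatchKeys.COUNT` used in the hunk): the transform now returns one batched `MetaTensor` instead of a list of `(patch, location)` pairs.

```python
import numpy as np

from monai.data import MetaTensor
from monai.transforms import GridPatch

# A 3-channel 4x4 image, wrapped so metadata travels with the tensor.
image = MetaTensor(x=np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1))

patcher = GridPatch(patch_size=(2, 2))
patches = patcher(image)            # a single MetaTensor with is_batch=True

print(patches.shape)                # (4, 3, 2, 2): four 2x2 patches, channel-first
print(patches.meta["count"])        # 4
first = patches[0]                  # indexing the batch yields a per-patch MetaTensor
print(first.meta["location"])       # start coordinate of the first patch, e.g. [0, 0]
```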
monai/transforms/spatial/dictionary.py (57 changes: 12 additions & 45 deletions)
@@ -15,7 +15,6 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from copy import deepcopy
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
@@ -58,12 +57,10 @@
GridSamplePadMode,
InterpolateMode,
NumpyPadMode,
WSIPatchKeys,
convert_to_tensor,
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
first,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import PytorchPadMode, TraceKeys
@@ -1851,25 +1848,11 @@ def __init__(
**pad_kwargs,
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
original_spatial_shape = d[first(self.keys)].shape[1:]
output = []
results = [self.patcher(d[key]) for key in self.keys]
num_patches = min(len(r) for r in results)
for patch in zip(*results):
new_dict = {k: v[0] for k, v in zip(self.keys, patch)}
# fill in the extra keys with unmodified data
for k in set(d.keys()).difference(set(self.keys)):
new_dict[k] = deepcopy(d[k])
# fill additional metadata
new_dict["original_spatial_shape"] = original_spatial_shape
new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item
new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size
new_dict[WSIPatchKeys.COUNT] = num_patches
new_dict["offset"] = self.patcher.offset
output.append(new_dict)
return output
for key in self.key_iterator(d):
d[key] = self.patcher(d[key])
return d


class RandGridPatchd(RandomizableTransform, MapTransform):
@@ -1942,31 +1925,15 @@ def set_random_state(
self.patcher.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
original_spatial_shape = d[first(self.keys)].shape[1:]
# all the keys share the same random noise
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return [d]
self.patcher.randomize(d[first_key]) # type: ignore
results = [self.patcher(d[key], randomize=False) for key in self.keys]

num_patches = min(len(r) for r in results)
output = []
for patch in zip(*results):
new_dict = {k: v[0] for k, v in zip(self.keys, patch)}
# fill in the extra keys with unmodified data
for k in set(d.keys()).difference(set(self.keys)):
new_dict[k] = deepcopy(d[k])
# fill additional metadata
new_dict["original_spatial_shape"] = original_spatial_shape
new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item
new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size
new_dict[WSIPatchKeys.COUNT] = num_patches
new_dict["offset"] = self.patcher.offset
output.append(new_dict)
return output
# All the keys share the same random noise
for key in self.key_iterator(d):
self.patcher.randomize(d[key])
break
for key in self.key_iterator(d):
d[key] = self.patcher(d[key], randomize=False)
return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
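A similar sketch for the dictionary-based transform after this change (again assuming a build with this diff): `GridPatchd` keeps the dictionary structure and replaces each selected key with the batched patch `MetaTensor`, rather than returning a list of per-patch dictionaries; keys not listed in `keys` pass through untouched.

```python
import numpy as np

from monai.transforms import GridPatchd

data = {
    "image": np.zeros((1, 4, 4), dtype=np.float32),
    "case_id": "sample-001",     # not in `keys`, left as-is
}

patcher = GridPatchd(keys="image", patch_size=(2, 2))
out = patcher(data)

print(type(out))                    # dict, not a list of dicts
print(out["image"].shape)           # (4, 1, 2, 2): all patches batched into one MetaTensor
print(out["image"].meta["count"])   # 4
print(out["case_id"])               # "sample-001"
```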
monai/transforms/utility/dictionary.py (24 changes: 22 additions & 2 deletions)
@@ -384,6 +384,7 @@ def __init__(
dim: int = 0,
keepdim: bool = True,
update_meta: bool = True,
list_output: bool = False,
allow_missing_keys: bool = False,
) -> None:
"""
@@ -399,15 +400,34 @@
dimension will be squeezed.
update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to
reflect the cropped image
list_output: if `True`, the output will be a list of dictionaries with the same keys as original.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.output_postfixes = output_postfixes
self.splitter = SplitDim(dim, keepdim, update_meta)
self.list_output = list_output
if self.list_output is None and self.output_postfixes is not None:
raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.")

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
def __call__(
self, data: Mapping[Hashable, torch.Tensor]
) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]:
d = dict(data)
for key in self.key_iterator(d):
all_keys = list(set(self.key_iterator(d)))

if self.list_output:
output = []
results = [self.splitter(d[key]) for key in all_keys]
for row in zip(*results):
new_dict = {k: v for k, v in zip(all_keys, row)}
# fill in the extra keys with unmodified data
for k in set(d.keys()).difference(set(all_keys)):
new_dict[k] = deepcopy(d[k])
output.append(new_dict)
return output

for key in all_keys:
rets = self.splitter(d[key])
postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes
if len(postfixes) != len(rets):
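Finally, a short sketch of the new `list_output` option on `SplitDimd` (behaviour assumed from the hunk above): the default still writes the slices back into one dictionary under postfixed keys, while `list_output=True` returns one dictionary per slice, with the non-split keys copied into each.

```python
import torch

from monai.transforms import SplitDimd

data = {"image": torch.zeros(3, 8, 8), "case_id": "sample-001"}

# Default: a single dict gains the postfixed keys image_0, image_1, image_2.
out = SplitDimd(keys="image", dim=0)(data)
print(sorted(k for k in out if str(k).startswith("image_")))   # ['image_0', 'image_1', 'image_2']

# list_output=True: one dict per slice along `dim`, extra keys deep-copied into each.
outs = SplitDimd(keys="image", dim=0, list_output=True)(data)
print(len(outs))                    # 3
print(outs[0]["image"].shape)       # torch.Size([1, 8, 8]) (keepdim=True by default)
print(outs[0]["case_id"])           # "sample-001"
```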
tests/min_tests.py (2 changes: 2 additions & 0 deletions)
@@ -60,6 +60,7 @@ def run_testsuit():
"test_foreground_mask",
"test_foreground_maskd",
"test_global_mutual_information_loss",
"test_grid_patch",
"test_handler_checkpoint_loader",
"test_handler_checkpoint_saver",
"test_handler_classification_saver",
@@ -135,6 +136,7 @@ def run_testsuit():
"test_png_saver",
"test_prepare_batch_default",
"test_prepare_batch_extra_input",
"test_rand_grid_patch",
"test_rand_rotate",
"test_rand_rotated",
"test_rand_zoom",
tests/test_grid_patch.py (74 changes: 55 additions & 19 deletions)
@@ -14,8 +14,9 @@
import numpy as np
from parameterized import parameterized

from monai.data import MetaTensor, set_track_meta
from monai.transforms.spatial.array import GridPatch
from tests.utils import TEST_NDARRAYS, assert_allclose
from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose

A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)
A11 = A[:, :2, :2]
@@ -46,34 +47,69 @@
]
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]]

TEST_CASE_MEAT_0 = [
{"patch_size": (2, 2)},
A,
[A11, A12, A21, A22],
[{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}],
]

TEST_CASE_MEAT_1 = [
{"patch_size": (2, 2)},
MetaTensor(x=A, meta={"path": "path/to/file"}),
[A11, A12, A21, A22],
[
{"location": [0, 0], "path": "path/to/file"},
{"location": [0, 2], "path": "path/to/file"},
{"location": [2, 0], "path": "path/to/file"},
{"location": [2, 2], "path": "path/to/file"},
],
]

TEST_SINGLE = []
TEST_CASES = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
TEST_SINGLE.append([p, *TEST_CASE_1])
TEST_SINGLE.append([p, *TEST_CASE_2])
TEST_SINGLE.append([p, *TEST_CASE_3])
TEST_SINGLE.append([p, *TEST_CASE_4])
TEST_SINGLE.append([p, *TEST_CASE_5])
TEST_SINGLE.append([p, *TEST_CASE_6])
TEST_SINGLE.append([p, *TEST_CASE_7])
TEST_SINGLE.append([p, *TEST_CASE_8])
TEST_SINGLE.append([p, *TEST_CASE_9])
TEST_SINGLE.append([p, *TEST_CASE_10])
TEST_SINGLE.append([p, *TEST_CASE_11])
TEST_SINGLE.append([p, *TEST_CASE_12])
TEST_SINGLE.append([p, *TEST_CASE_13])
TEST_CASES.append([p, *TEST_CASE_0])
TEST_CASES.append([p, *TEST_CASE_1])
TEST_CASES.append([p, *TEST_CASE_2])
TEST_CASES.append([p, *TEST_CASE_3])
TEST_CASES.append([p, *TEST_CASE_4])
TEST_CASES.append([p, *TEST_CASE_5])
TEST_CASES.append([p, *TEST_CASE_6])
TEST_CASES.append([p, *TEST_CASE_7])
TEST_CASES.append([p, *TEST_CASE_8])
TEST_CASES.append([p, *TEST_CASE_9])
TEST_CASES.append([p, *TEST_CASE_10])
TEST_CASES.append([p, *TEST_CASE_11])
TEST_CASES.append([p, *TEST_CASE_12])
TEST_CASES.append([p, *TEST_CASE_13])


class TestGridPatch(unittest.TestCase):
@parameterized.expand(TEST_SINGLE)
@parameterized.expand(TEST_CASES)
def test_grid_patch(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = GridPatch(**input_parameters)
output = list(splitter(input_image))
output = splitter(input_image)
self.assertEqual(len(output), len(expected))
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch[0], expected_patch, type_test=False)
assert_allclose(output_patch, expected_patch, type_test=False)

@parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1])
@SkipIfBeforePyTorchVersion((1, 9, 1))
def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta):
set_track_meta(True)
splitter = GridPatch(**input_parameters)
output = splitter(image)
self.assertEqual(len(output), len(expected))
if "path" in expected_meta[0]:
self.assertTrue(output.meta["path"] == expected_meta[0]["path"])
for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):
assert_allclose(output_patch, expected_patch, type_test=False)
self.assertTrue(isinstance(output_patch, MetaTensor))
self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"])
self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:]))
if "path" in expected_meta[0]:
self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"])


if __name__ == "__main__":
tests/test_grid_patchd.py (8 changes: 4 additions & 4 deletions)
@@ -74,10 +74,10 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected):
if k == image_key:
input_dict[k] = in_type(v)
splitter = GridPatchd(keys=image_key, **input_parameters)
output = list(splitter(input_dict))
self.assertEqual(len(output), len(expected))
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch[image_key], expected_patch, type_test=False)
output = splitter(input_dict)
self.assertEqual(len(output[image_key]), len(expected))
for output_patch, expected_patch in zip(output[image_key], expected):
assert_allclose(output_patch, expected_patch, type_test=False)


if __name__ == "__main__":