Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
d1829f3
Add random sorting option to sort_fn argument in GridPatch transform
jmnolte Jul 7, 2023
c2f4137
Add random sorting option to sort_fn argument in GridPatch transform
jmnolte Jul 7, 2023
e8eecfb
Merge branch '6699-random-sort-GridPatch' of https://github.com/jmnol…
jmnolte Jul 7, 2023
4ab6070
DCO Remediation Commit for jmnolte <jakob.nolte@web.de>
jmnolte Jul 7, 2023
f960336
Merge branch 'dev' into 6699-random-sort-GridPatch
jmnolte Jul 10, 2023
044a8cf
Refactor random selection to ensure reproducible random states
jmnolte Jul 10, 2023
6854644
Merge branch origin/6699-random-sort-GridPatch
jmnolte Jul 10, 2023
5b9d5ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2023
6ac9175
Fix num_patches in init of RandGridPatch class
jmnolte Jul 10, 2023
1265109
Merge branch origin/6699-random-sort-GridPatch
jmnolte Jul 10, 2023
3cd7566
6699 Override parent filter count method in parent class
jmnolte Jul 12, 2023
bcdb1a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
ab5d849
Update monai/transforms/spatial/array.py
jmnolte Jul 12, 2023
95745e8
Update monai/transforms/spatial/array.py
jmnolte Jul 12, 2023
b322fdf
Merge branch 'dev' into 6699-random-sort-GridPatch
jmnolte Jul 12, 2023
8675f57
6699 Fix type error
jmnolte Jul 12, 2023
c4e780b
6699 Update RandGridPatchd docstring and fix typing error in RandGrid…
jmnolte Jul 13, 2023
c0b29c5
Merge remote-tracking branch 'upstream/dev' into 6699-random-sort-Gri…
wyli Jul 13, 2023
65e3450
type ignore idx type
wyli Jul 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3278,7 +3278,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl
elif self.sort_fn == GridPatchSort.MAX:
idx = argsort(-image_np.sum(tuple(range(1, n_dims))))
else:
raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!')
raise ValueError(f'`sort_fn` should be either "min", "max", or None! {self.sort_fn} provided!')
idx = idx[: self.num_patches]
idx_np = convert_data_type(idx, np.ndarray)[0]
image_np = image_np[idx]
Expand Down Expand Up @@ -3371,7 +3371,7 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Expand Down Expand Up @@ -3423,6 +3423,8 @@ def __init__(
)
self.min_offset = min_offset
self.max_offset = max_offset
self.num_patches = num_patches
self.sort_fn = sort_fn

def randomize(self, array):
if self.min_offset is None:
Expand All @@ -3436,6 +3438,18 @@ def randomize(self, array):

self.offset = tuple(self.R.randint(low=low, high=high + 1) for low, high in zip(min_offset, max_offset))

def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]:
if self.sort_fn == GridPatchSort.RANDOM:
idx = self.R.permutation(image_np.shape[0])
idx = idx[: self.num_patches]
idx_np = convert_data_type(idx, np.ndarray)[0]
image_np = image_np[idx] # type: ignore
locations = locations[idx_np]
return image_np, locations
elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):
raise ValueError(f'`sort_fn` should be either "min", "max", "random" or None! {self.sort_fn} provided!')
return super().filter_count(image_np, locations)

def __call__(self, array: NdarrayOrTensor, randomize: bool = True):
if randomize:
self.randomize(array)
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2436,7 +2436,7 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait):
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Expand Down
10 changes: 10 additions & 0 deletions tests/test_rand_grid_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@
[np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)],
]
TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]]
TEST_CASE_11 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 2}, A, [A11, A12]]
TEST_CASE_12 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 4}, A, [A11, A12, A21, A22]]
TEST_CASE_13 = [
{"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "random"},
A,
[A[:, 1:3, 1:3]],
]

TEST_CASE_META_0 = [
{"patch_size": (2, 2)},
Expand Down Expand Up @@ -92,6 +99,9 @@
TEST_SINGLE.append([p, *TEST_CASE_8])
TEST_SINGLE.append([p, *TEST_CASE_9])
TEST_SINGLE.append([p, *TEST_CASE_10])
TEST_SINGLE.append([p, *TEST_CASE_11])
TEST_SINGLE.append([p, *TEST_CASE_12])
TEST_SINGLE.append([p, *TEST_CASE_13])


class TestRandGridPatch(unittest.TestCase):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_rand_grid_patchd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
[np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)],
]
TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, {"image": A}, [A11]]
TEST_CASE_11 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 2}, {"image": A}, [A11, A12]]
TEST_CASE_12 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 4}, {"image": A}, [A11, A12, A21, A22]]
TEST_CASE_13 = [
{"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "random"},
{"image": A},
[A[:, 1:3, 1:3]],
]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
Expand All @@ -72,6 +79,9 @@
TEST_SINGLE.append([p, *TEST_CASE_8])
TEST_SINGLE.append([p, *TEST_CASE_9])
TEST_SINGLE.append([p, *TEST_CASE_10])
TEST_SINGLE.append([p, *TEST_CASE_11])
TEST_SINGLE.append([p, *TEST_CASE_12])
TEST_SINGLE.append([p, *TEST_CASE_13])


class TestRandGridPatchd(unittest.TestCase):
Expand Down