diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ec38dfaffe..d37e69650c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3278,7 +3278,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl elif self.sort_fn == GridPatchSort.MAX: idx = argsort(-image_np.sum(tuple(range(1, n_dims)))) else: - raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') + raise ValueError(f'`sort_fn` should be either "min", "max", or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] idx_np = convert_data_type(idx, np.ndarray)[0] image_np = image_np[idx] @@ -3371,7 +3371,7 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait): overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), - lowest values (`"min"`), or in their default order (`None`). Default to None. + lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. @@ -3423,6 +3423,8 @@ def __init__( ) self.min_offset = min_offset self.max_offset = max_offset + self.num_patches = num_patches + self.sort_fn = sort_fn def randomize(self, array): if self.min_offset is None: @@ -3436,6 +3438,18 @@ def randomize(self, array): self.offset = tuple(self.R.randint(low=low, high=high + 1) for low, high in zip(min_offset, max_offset)) + def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: + if self.sort_fn == GridPatchSort.RANDOM: + idx = self.R.permutation(image_np.shape[0]) + idx = idx[: self.num_patches] + idx_np = convert_data_type(idx, np.ndarray)[0] + image_np = image_np[idx] # type: ignore + locations = locations[idx_np] + return image_np, locations + elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX): + raise ValueError(f'`sort_fn` should be either "min", "max", "random" or None! {self.sort_fn} provided!') + return super().filter_count(image_np, locations) + def __call__(self, array: NdarrayOrTensor, randomize: bool = True): if randomize: self.randomize(array) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4ba5849c46..79742f0582 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2436,7 +2436,7 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), - lowest values (`"min"`), or in their default order (`None`). Default to None. + lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index a118a0f420..494330584a 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -59,6 +59,13 @@ [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)], ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]] +TEST_CASE_11 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 2}, A, [A11, A12]] +TEST_CASE_12 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 4}, A, [A11, A12, A21, A22]] +TEST_CASE_13 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "random"}, + A, + [A[:, 1:3, 1:3]], +] TEST_CASE_META_0 = [ {"patch_size": (2, 2)}, @@ -92,6 +99,9 @@ TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_SINGLE.append([p, *TEST_CASE_9]) TEST_SINGLE.append([p, *TEST_CASE_10]) + TEST_SINGLE.append([p, *TEST_CASE_11]) + TEST_SINGLE.append([p, *TEST_CASE_12]) + TEST_SINGLE.append([p, *TEST_CASE_13]) class TestRandGridPatch(unittest.TestCase): diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index e93a6f6cd1..23ca4a7881 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -58,6 +58,13 @@ [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)], ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, {"image": A}, [A11]] +TEST_CASE_11 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 2}, {"image": A}, [A11, A12]] +TEST_CASE_12 = [{"patch_size": (2, 2), "sort_fn": "random", "num_patches": 4}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_13 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "random"}, + {"image": A}, + [A[:, 1:3, 1:3]], +] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -72,6 +79,9 @@ TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_SINGLE.append([p, *TEST_CASE_9]) TEST_SINGLE.append([p, *TEST_CASE_10]) + TEST_SINGLE.append([p, *TEST_CASE_11]) + TEST_SINGLE.append([p, *TEST_CASE_12]) + TEST_SINGLE.append([p, *TEST_CASE_13]) class TestRandGridPatchd(unittest.TestCase):