From 6727e49b0e5b330de53b098a32177522381865ab Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 13:45:18 +0000 Subject: [PATCH 1/4] Refactor some variable names Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../pathology/transforms/spatial/array.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index abfdf3cac0..4b8d19347e 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -75,12 +75,12 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: .contiguous() ) elif isinstance(image, np.ndarray): - h_step, w_step = steps - c_stride, h_stride, w_stride = image.strides + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides patches = as_strided( image, shape=(*self.grid_size, 3, patch_size[0], patch_size[1]), - strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) # flatten the first two dimensions @@ -210,17 +210,17 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: ) # extact tiles - h_step, w_step = self.step, self.step - h_size, w_size = self.tile_size, self.tile_size - c_len, h_len, w_len = img_np.shape - c_stride, h_stride, w_stride = img_np.strides + x_step, y_step = self.step, self.step + h_tile, w_tile = self.tile_size, self.tile_size + c_image, h_image, w_image = img_np.shape + c_stride, x_stride, y_stride = img_np.strides llw = as_strided( img_np, - shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size), - strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), + shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - img_np = llw.reshape(-1, c_len, h_size, w_size) + img_np = llw.reshape(-1, c_image, h_tile, w_tile) # if keeping all patches if self.tile_count is None: From dfb78c33681b8e98b1e847110281e393ee05958e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 13:54:21 +0000 Subject: [PATCH 2/4] Fix SplitOnGrid issue with numpy backend Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/spatial/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4b8d19347e..3591be0303 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -77,9 +77,10 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: elif isinstance(image, np.ndarray): x_step, y_step = steps c_stride, x_stride, y_stride = image.strides + c_image = image.shape[0] patches = as_strided( image, - shape=(*self.grid_size, 3, patch_size[0], patch_size[1]), + shape=(*self.grid_size, c_image, patch_size[0], patch_size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) From a63d8f14a3481d6ad9f14616e43710bc5a46eb00 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 14:25:31 +0000 Subject: [PATCH 3/4] Add unittest to cover the fixed issue Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_split_on_grid.py | 6 ++++++ tests/test_split_on_grid_dict.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index a3bf2674f4..8017fcccbd 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -34,6 +34,11 @@ TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])] TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])] TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid_size": (2, 2), "patch_size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -45,6 +50,7 @@ TEST_SINGLE.append([p, *TEST_CASE_5]) TEST_SINGLE.append([p, *TEST_CASE_6]) TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index 5f3e442640..7b96fc4190 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -34,6 +34,11 @@ TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])] TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid_size": (2, 2), "patch_size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -45,6 +50,7 @@ TEST_SINGLE.append([p, *TEST_CASE_5]) TEST_SINGLE.append([p, *TEST_CASE_6]) TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [ {"keys": "image", "grid_size": (2, 2)}, From 6cb978a4ce8654bb523204569a7713bc66292e75 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 23 Nov 2021 14:28:44 +0000 Subject: [PATCH 4/4] Rename c_image to n_channels Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/transforms/spatial/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 3591be0303..58fef8969f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -77,10 +77,10 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: elif isinstance(image, np.ndarray): x_step, y_step = steps c_stride, x_stride, y_stride = image.strides - c_image = image.shape[0] + n_channels = image.shape[0] patches = as_strided( image, - shape=(*self.grid_size, c_image, patch_size[0], patch_size[1]), + shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, )