diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index abfdf3cac0..58fef8969f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -75,12 +75,13 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: .contiguous() ) elif isinstance(image, np.ndarray): - h_step, w_step = steps - c_stride, h_stride, w_stride = image.strides + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] patches = as_strided( image, - shape=(*self.grid_size, 3, patch_size[0], patch_size[1]), - strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), + shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) # flatten the first two dimensions @@ -210,17 +211,17 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: ) # extact tiles - h_step, w_step = self.step, self.step - h_size, w_size = self.tile_size, self.tile_size - c_len, h_len, w_len = img_np.shape - c_stride, h_stride, w_stride = img_np.strides + x_step, y_step = self.step, self.step + h_tile, w_tile = self.tile_size, self.tile_size + c_image, h_image, w_image = img_np.shape + c_stride, x_stride, y_stride = img_np.strides llw = as_strided( img_np, - shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size), - strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), + shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), writeable=False, ) - img_np = llw.reshape(-1, c_len, h_size, w_size) + img_np = llw.reshape(-1, c_image, h_tile, w_tile) # if keeping all patches if self.tile_count is None: diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py index a3bf2674f4..8017fcccbd 100644 --- a/tests/test_split_on_grid.py +++ b/tests/test_split_on_grid.py @@ -34,6 +34,11 @@ TEST_CASE_5 = [{"grid_size": 1, "patch_size": 4}, A, torch.stack([A])] TEST_CASE_6 = [{"grid_size": 2, "patch_size": 2}, A, torch.stack([A11, A12, A21, A22])] TEST_CASE_7 = [{"grid_size": 1}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid_size": (2, 2), "patch_size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -45,6 +50,7 @@ TEST_SINGLE.append([p, *TEST_CASE_5]) TEST_SINGLE.append([p, *TEST_CASE_6]) TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [{"grid_size": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] TEST_CASE_MC_1 = [{"grid_size": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py index 5f3e442640..7b96fc4190 100644 --- a/tests/test_split_on_grid_dict.py +++ b/tests/test_split_on_grid_dict.py @@ -34,6 +34,11 @@ TEST_CASE_5 = [{"keys": "image", "grid_size": 1, "patch_size": 4}, {"image": A}, torch.stack([A])] TEST_CASE_6 = [{"keys": "image", "grid_size": 2, "patch_size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] TEST_CASE_7 = [{"keys": "image", "grid_size": 1}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid_size": (2, 2), "patch_size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -45,6 +50,7 @@ TEST_SINGLE.append([p, *TEST_CASE_5]) TEST_SINGLE.append([p, *TEST_CASE_6]) TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) TEST_CASE_MC_0 = [ {"keys": "image", "grid_size": (2, 2)},