8 changes: 5 additions & 3 deletions monai/data/wsi_datasets.py
@@ -127,11 +127,13 @@ def _get_data(self, sample: Dict):
def _transform(self, index: int):
# Get a single entry of data
sample: Dict = self.data[index]

# Extract patch image and associated metadata
image, metadata = self._get_data(sample)

# Get the label
label = self._get_label(sample)

# Create put all patch information together and apply transforms
patch = {"image": image, "label": label, "metadata": metadata}
return apply_transform(self.transform, patch) if self.transform else patch
# Apply transforms and output
output = {"image": image, "label": label, "metadata": metadata}
return apply_transform(self.transform, output) if self.transform else output
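
To pin down the resulting contract, here is a self-contained sketch that mirrors the patched `_transform`: each indexed access bundles one patch with its label and metadata, and the optional transform is applied to the whole dictionary. The helper name `transform_entry` and the toy values are illustrative assumptions, not part of the codebase.

```python
from typing import Callable, Dict, Optional


def transform_entry(image, label, metadata: Dict, transform: Optional[Callable] = None) -> Dict:
    # Mirrors the patched _transform: bundle one patch with its label and
    # metadata, then apply the optional transform to the whole dictionary.
    output = {"image": image, "label": label, "metadata": metadata}
    return transform(output) if transform else output


# Toy call; a real dataset supplies patch arrays and WSI metadata.
print(transform_entry(image=[[0.0]], label=1, metadata={"location": (0, 0)}))
```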
2 changes: 1 addition & 1 deletion monai/data/wsi_reader.py
@@ -174,7 +174,7 @@ def get_data(
# Verify location
if location is None:
location = (0, 0)
wsi_size = self.get_size(each_wsi, level)
wsi_size = self.get_size(each_wsi, 0)
if location[0] > wsi_size[0] or location[1] > wsi_size[1]:
raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}")
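
Note on this one-line fix: `location` is interpreted in level-0 (full-resolution) coordinates, so checking it against the size of the requested `level` could wrongly reject valid locations whenever that level is downsampled. A minimal standalone sketch of the corrected check, assuming only a reader object exposing `get_size(wsi, level)` as in the hunk:

```python
def verify_location(reader, wsi, location, level):
    """Sketch of the corrected bounds check; mirrors the hunk above."""
    if location is None:
        location = (0, 0)
    # Validate against the level-0 size: `location` is given in
    # full-resolution coordinates regardless of the requested level,
    # so `level` no longer participates in the check.
    wsi_size = reader.get_size(wsi, 0)
    if location[0] > wsi_size[0] or location[1] > wsi_size[1]:
        raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}")
    return location
```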

60 changes: 30 additions & 30 deletions monai/transforms/spatial/array.py
@@ -2535,62 +2535,62 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None):
# Patch size
self.size = None if size is None else ensure_tuple_rep(size, len(self.grid))

def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
if self.grid == (1, 1) and self.size is None:
if isinstance(image, torch.Tensor):
return torch.stack([image])
elif isinstance(image, np.ndarray):
return np.stack([image]) # type: ignore
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")
def __call__(
self, image: NdarrayOrTensor, size: Optional[Union[int, Tuple[int, int], np.ndarray]] = None
) -> List[NdarrayOrTensor]:
input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid))

if self.grid == (1, 1) and input_size is None:
return [image]

size, steps = self._get_params(image.shape[1:])
patches: NdarrayOrTensor
split_size, steps = self._get_params(image.shape[1:], input_size)
patches: List[NdarrayOrTensor]
if isinstance(image, torch.Tensor):
patches = (
image.unfold(1, size[0], steps[0])
.unfold(2, size[1], steps[1])
unfolded_image = (
image.unfold(1, split_size[0], steps[0])
.unfold(2, split_size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
# Make a list of contiguous patches
patches = [p.contiguous() for p in unfolded_image]
elif isinstance(image, np.ndarray):
x_step, y_step = steps
c_stride, x_stride, y_stride = image.strides
n_channels = image.shape[0]
patches = as_strided(
strided_image = as_strided(
image,
shape=(*self.grid, n_channels, size[0], size[1]),
shape=(*self.grid, n_channels, split_size[0], split_size[1]),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
writeable=False,
)
# flatten the first two dimensions
patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:])
# make it a contiguous array
patches = np.ascontiguousarray(patches)
# Flatten the first two dimensions
strided_image = strided_image.reshape(-1, *strided_image.shape[2:])
# Make a list of contiguous patches
patches = [np.ascontiguousarray(p) for p in strided_image]
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

return patches

def _get_params(self, image_size: Union[Sequence[int], np.ndarray]):
def _get_params(
self, image_size: Union[Sequence[int], np.ndarray], size: Optional[Union[Sequence[int], np.ndarray]] = None
):
"""
Calculate the size and step required for splitting the image.
Args:
image_size: the size of the input image.
size: the size of each split sub-image; if None, it is inferred from ``image_size`` and ``self.grid``.
"""
if self.size is not None:
# Set the split size to the given default size
if any(self.size[i] > image_size[i] for i in range(len(self.grid))):
raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})")
split_size = self.size
else:
if size is None:
# infer each sub-image size from the image size and the grid
split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid)))
size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid)))

if any(size[i] > image_size[i] for i in range(len(self.grid))):
raise ValueError(f"The image size ({image_size}) is smaller than the requested split size ({size})")

steps = tuple(
(image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i]
(image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i]
for i in range(len(self.grid))
)

return split_size, steps
return size, steps
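
To make the new contract concrete, a hedged usage sketch: `GridSplit` now returns a list of patches rather than a stacked array, and `__call__` accepts an optional per-call `size` that overrides the one given at construction. Names follow the diff; the array shape is an illustrative assumption, and the import path assumes the transform is exported via `monai.transforms` as usual for this module.

```python
import numpy as np

from monai.transforms import GridSplit

# Illustrative channel-first image of shape (C, H, W).
image = np.arange(2 * 4 * 4, dtype=np.float32).reshape(2, 4, 4)

splitter = GridSplit(grid=(2, 2))
patches = splitter(image)              # list of 4 contiguous patches
print(len(patches), patches[0].shape)  # 4 (2, 2, 2)

# Per-call override: a (3, 3) split size on a 4x4 image yields
# overlapping patches, since steps = (4 - 3) // (2 - 1) = 1.
overlapping = splitter(image, size=3)
print(overlapping[0].shape)            # (2, 3, 3)
```

Returning a list of per-patch contiguous copies, instead of one stacked array, is also what lets the dictionary version below hand each grid cell to its own output dictionary.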
19 changes: 13 additions & 6 deletions monai/transforms/spatial/dictionary.py
@@ -2160,7 +2160,8 @@ class GridSplitd(MapTransform):
Args:
keys: keys of the corresponding items to be transformed.
grid: a tuple defining the shape of the grid upon which the image is split. Defaults to (2, 2).
size: a tuple or an integer that defines the output patch sizes.
size: a tuple or an integer that defines the output patch sizes,
or a dictionary that defines it separately for each key, e.g. {"image": 3, "mask": (2, 2)}.
If it's an integer, the value will be repeated for each dimension.
The default is None, where the patch size will be inferred from the grid shape.
allow_missing_keys: don't raise exception if key is missing.
@@ -2174,17 +2175,23 @@ def __init__(
self,
keys: KeysCollection,
grid: Tuple[int, int] = (2, 2),
size: Optional[Union[int, Tuple[int, int]]] = None,
size: Optional[Union[int, Tuple[int, int], Dict[Hashable, Union[int, Tuple[int, int], None]]]] = None,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.splitter = GridSplit(grid=grid, size=size)
self.grid = grid
self.size = size if isinstance(size, dict) else {key: size for key in self.keys}
self.splitter = GridSplit(grid=grid)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
d = dict(data)
n_outputs = np.prod(self.grid)
output: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)]
for key in self.key_iterator(d):
d[key] = self.splitter(d[key])
return d
result = self.splitter(d[key], self.size[key])
for i in range(n_outputs):
output[i][key] = result[i]
return output


SpatialResampleD = SpatialResampleDict = SpatialResampled
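
A matching hedged sketch for the dictionary transform: `GridSplitd` now returns one dictionary per grid cell, and `size` may be given per key. The shapes below are illustrative assumptions; the import path follows the module's usual exports.

```python
import numpy as np

from monai.transforms import GridSplitd

data = {
    "image": np.zeros((3, 4, 4), dtype=np.float32),  # illustrative shapes
    "mask": np.zeros((1, 4, 4), dtype=np.float32),
}
splitter = GridSplitd(keys=["image", "mask"], grid=(2, 2), size={"image": 2, "mask": 2})
outputs = splitter(data)

print(len(outputs))               # 4, i.e. np.prod(grid) dictionaries
print(outputs[0]["image"].shape)  # (3, 2, 2)
print(outputs[0]["mask"].shape)   # (1, 2, 2)
```

Note that non-key entries of the input dictionary are copied into every output dictionary, which keeps any metadata alongside each patch.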
28 changes: 15 additions & 13 deletions tests/test_grid_split.py
@@ -26,14 +26,14 @@
A2 = torch.cat([A21, A22], 2)
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])]
TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])]
TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])]
TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])]
TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])]
TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])]
TEST_CASE_0 = [{"grid": (2, 2)}, A, [A11, A12, A21, A22]]
TEST_CASE_1 = [{"grid": (2, 1)}, A, [A1, A2]]
TEST_CASE_2 = [{"grid": (1, 2)}, A1, [A11, A12]]
TEST_CASE_3 = [{"grid": (1, 2)}, A2, [A21, A22]]
TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, [A11]]
TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, [A]]
TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, [A11, A12, A21, A22]]
TEST_CASE_7 = [{"grid": (1, 1)}, A, [A]]
TEST_CASE_8 = [
{"grid": (2, 2), "size": 2},
torch.arange(12).reshape(1, 3, 4).to(torch.float32),
@@ -52,9 +52,9 @@
TEST_SINGLE.append([p, *TEST_CASE_7])
TEST_SINGLE.append([p, *TEST_CASE_8])

TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]
TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5]
TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]]
TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [[A11, A12, A21, A22], [A11, A12, A21, A22]]]
TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [[A1, A2]] * 5]
TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [[A11, A12], [A21, A22]]]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
@@ -69,15 +69,17 @@ def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = GridSplit(**input_parameters)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch, expected_patch, type_test=False)

@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
splitter = GridSplit(**input_parameters)
for image, expected in zip(img_list, expected_list):
input_image = in_type(image)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch, expected_patch, type_test=False)


if __name__ == "__main__":
42 changes: 18 additions & 24 deletions tests/test_grid_splitd.py
@@ -26,16 +26,16 @@
A2 = torch.cat([A21, A22], 2)
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])]
TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])]
TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])]
TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])]
TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])]
TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])]
TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])]
TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])]
TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, [A11, A12, A21, A22]]
TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, [A1, A2]]
TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, [A11, A12]]
TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, [A21, A22]]
TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": {"image": (2, 2)}}, {"image": A}, [A11]]
TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": {"image": 4}}, {"image": A}, [A]]
TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": {"image": 2}}, {"image": A}, [A11, A12, A21, A22]]
TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, [A]]
TEST_CASE_8 = [
{"keys": "image", "grid": (2, 2), "size": 2},
{"keys": "image", "grid": (2, 2), "size": {"image": 2}},
{"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)},
torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),
]
@@ -55,18 +55,10 @@
TEST_CASE_MC_0 = [
{"keys": "image", "grid": (2, 2)},
[{"image": A}, {"image": A}],
[torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])],
]
TEST_CASE_MC_1 = [
{"keys": "image", "grid": (2, 1)},
[{"image": A}, {"image": A}, {"image": A}],
[torch.stack([A1, A2])] * 3,
]
TEST_CASE_MC_2 = [
{"keys": "image", "grid": (1, 2)},
[{"image": A1}, {"image": A2}],
[torch.stack([A11, A12]), torch.stack([A21, A22])],
[[A11, A12, A21, A22], [A11, A12, A21, A22]],
]
TEST_CASE_MC_1 = [{"keys": "image", "grid": (2, 1)}, [{"image": A}, {"image": A}, {"image": A}], [[A1, A2]] * 3]
TEST_CASE_MC_2 = [{"keys": "image", "grid": (1, 2)}, [{"image": A1}, {"image": A2}], [[A11, A12], [A21, A22]]]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
@@ -82,8 +74,9 @@ def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected):
for k, v in img_dict.items():
input_dict[k] = in_type(v)
splitter = GridSplitd(**input_parameters)
output = splitter(input_dict)[input_parameters["keys"]]
assert_allclose(output, expected, type_test=False)
output = splitter(input_dict)
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False)

@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
@@ -92,8 +85,9 @@ def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
input_dict = {}
for k, v in img_dict.items():
input_dict[k] = in_type(v)
output = splitter(input_dict)[input_parameters["keys"]]
assert_allclose(output, expected, type_test=False)
output = splitter(input_dict)
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch[input_parameters["keys"]], expected_patch, type_test=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/test_patch_wsi_dataset_new.py
@@ -158,7 +158,7 @@ def setUpClass(cls):
cls.backend = "cucim"


@skipUnless(has_osl, "Requires cucim")
@skipUnless(has_osl, "Requires openslide")
class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
@classmethod
def setUpClass(cls):