Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import Randomizable, Transform
from monai.utils import convert_data_type, convert_to_dst_type
from monai.utils.enums import TransformBackends

__all__ = ["SplitOnGrid", "TileOnGrid"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like SplitOnGrid(2, 2)(torch.arange(12).reshape(1, 3, 4)) works fine but SplitOnGrid(2, 2)(np.arange(12).reshape(1, 3, 4)) doesn't work, could you help fix it here? #3378 was merged too quickly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wyli, sorry I missed this. I will fix it now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand Down Expand Up @@ -129,6 +130,8 @@ class TileOnGrid(Randomizable, Transform):

"""

backend = [TransformBackends.NUMPY]

def __init__(
self,
tile_count: Optional[int] = None,
Expand Down Expand Up @@ -185,37 +188,39 @@ def randomize(self, img_size: Sequence[int]) -> None:
else:
self.random_idxs = np.array((0,))

def __call__(self, image: np.ndarray) -> np.ndarray:
def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
img_np: np.ndarray
img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore

# add random offset
self.randomize(img_size=image.shape)
self.randomize(img_size=img_np.shape)

if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
image = image[:, self.offset[0] :, self.offset[1] :]
img_np = img_np[:, self.offset[0] :, self.offset[1] :]

# pad to full size, divisible by tile_size
if self.pad_full:
c, h, w = image.shape
c, h, w = img_np.shape
pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
image = np.pad(
image,
img_np = np.pad(
img_np,
[[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
constant_values=self.background_val,
)

# extact tiles
x_step, y_step = self.step, self.step
x_size, y_size = self.tile_size, self.tile_size
c_len, x_len, y_len = image.shape
c_stride, x_stride, y_stride = image.strides
h_step, w_step = self.step, self.step
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel the previous variable x_step is better than h_step, x indicates the axis direction, h indicates height which is the length of x direction

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can fix this too. sorry that I didn't see your comments before merging.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

h_size, w_size = self.tile_size, self.tile_size
c_len, h_len, w_len = img_np.shape
c_stride, h_stride, w_stride = img_np.strides
llw = as_strided(
image,
shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
img_np,
shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size),
strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride),
writeable=False,
)
image = llw.reshape(-1, c_len, x_size, y_size)
img_np = llw.reshape(-1, c_len, h_size, w_size)

# if keeping all patches
if self.tile_count is None:
Expand All @@ -224,32 +229,34 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size
if self.filter_mode == "min":
# default, keep non-background tiles (small values)
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh)
image = image[idxs.reshape(-1)]
idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
img_np = img_np[idxs.reshape(-1)]
elif self.filter_mode == "max":
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh)
image = image[idxs.reshape(-1)]
idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
img_np = img_np[idxs.reshape(-1)]

else:
if len(image) > self.tile_count:
if len(img_np) > self.tile_count:

if self.filter_mode == "min":
# default, keep non-background tiles (smallest values)
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count]
image = image[idxs]
idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count]
img_np = img_np[idxs]
elif self.filter_mode == "max":
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :]
image = image[idxs]
idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :]
img_np = img_np[idxs]
else:
# random subset (more appropriate for WSIs without distinct background)
if self.random_idxs is not None:
image = image[self.random_idxs]
img_np = img_np[self.random_idxs]

elif len(image) < self.tile_count:
image = np.pad(
image,
[[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]],
elif len(img_np) < self.tile_count:
img_np = np.pad(
img_np,
[[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
constant_values=self.background_val,
)

image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)

return image
8 changes: 5 additions & 3 deletions monai/apps/pathology/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import copy
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union

import numpy as np

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import MapTransform, Randomizable
Expand Down Expand Up @@ -81,6 +79,8 @@ class TileOnGridd(Randomizable, MapTransform):

"""

backend = SplitOnGrid.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -112,7 +112,9 @@ def __init__(
def randomize(self, data: Any = None) -> None:
self.seed = self.R.randint(10000) # type: ignore

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]:
def __call__(
self, data: Mapping[Hashable, NdarrayOrTensor]
) -> Union[Dict[Hashable, NdarrayOrTensor], List[Dict[Hashable, NdarrayOrTensor]]]:

self.randomize()

Expand Down
28 changes: 20 additions & 8 deletions tests/test_tile_on_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parameterized import parameterized

from monai.apps.pathology.transforms import TileOnGrid
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
for tile_count in [16, 64]:
Expand All @@ -38,6 +39,10 @@
for step in [4, 8]:
TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}])

TESTS = []
for p in TEST_NDARRAYS:
for tc in TEST_CASES:
TESTS.append([p, *tc])

TEST_CASES2 = []
for tile_count in [16, 64]:
Expand All @@ -56,6 +61,11 @@
]
)

TESTS2 = []
for p in TEST_NDARRAYS:
for tc in TEST_CASES2:
TESTS2.append([p, *tc])


def make_image(
tile_count: int,
Expand Down Expand Up @@ -104,25 +114,27 @@ def make_image(


class TestTileOnGrid(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_tile_patch_single_call(self, input_parameters):
@parameterized.expand(TESTS)
def test_tile_patch_single_call(self, in_type, input_parameters):

img, tiles = make_image(**input_parameters)
input_img = in_type(img)

tiler = TileOnGrid(**input_parameters)
output = tiler(img)
np.testing.assert_equal(output, tiles)
output = tiler(input_img)
assert_allclose(output, tiles, type_test=False)

@parameterized.expand(TEST_CASES2)
def test_tile_patch_random_call(self, input_parameters):
@parameterized.expand(TESTS2)
def test_tile_patch_random_call(self, in_type, input_parameters):

img, tiles = make_image(**input_parameters, seed=123)
input_img = in_type(img)

tiler = TileOnGrid(**input_parameters)
tiler.set_random_state(seed=123)

output = tiler(img)
np.testing.assert_equal(output, tiles)
output = tiler(input_img)
assert_allclose(output, tiles, type_test=False)


if __name__ == "__main__":
Expand Down
40 changes: 28 additions & 12 deletions tests/test_tile_on_grid_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from typing import Optional

import numpy as np
import torch
from parameterized import parameterized

from monai.apps.pathology.transforms import TileOnGridDict
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
for tile_count in [16, 64]:
Expand All @@ -36,11 +38,14 @@
]
)


for tile_size in [8, 16]:
for step in [4, 8]:
TEST_CASES.append([{"tile_count": 16, "step": step, "tile_size": tile_size}])

TESTS = []
for p in TEST_NDARRAYS:
for tc in TEST_CASES:
TESTS.append([p, *tc])

TEST_CASES2 = []
for tile_count in [16, 64]:
Expand All @@ -61,6 +66,10 @@
]
)

TESTS2 = []
for p in TEST_NDARRAYS:
for tc in TEST_CASES2:
TESTS2.append([p, *tc])

for tile_size in [8, 16]:
for step in [4, 8]:
Expand Down Expand Up @@ -114,46 +123,53 @@ def make_image(


class TestTileOnGridDict(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_tile_patch_single_call(self, input_parameters):
@parameterized.expand(TESTS)
def test_tile_patch_single_call(self, in_type, input_parameters):

key = "image"
input_parameters["keys"] = key

img, tiles = make_image(**input_parameters)
input_img = in_type(img)

splitter = TileOnGridDict(**input_parameters)

output = splitter({key: img})
output = splitter({key: input_img})

if input_parameters.get("return_list_of_dicts", False):
output = np.stack([ix[key] for ix in output], axis=0)
if isinstance(input_img, torch.Tensor):
output = torch.stack([ix[key] for ix in output], axis=0)
else:
output = np.stack([ix[key] for ix in output], axis=0)
else:
output = output[key]

np.testing.assert_equal(tiles, output)
assert_allclose(output, tiles, type_test=False)

@parameterized.expand(TEST_CASES2)
def test_tile_patch_random_call(self, input_parameters):
@parameterized.expand(TESTS2)
def test_tile_patch_random_call(self, in_type, input_parameters):

key = "image"
input_parameters["keys"] = key

random_state = np.random.RandomState(123)
seed = random_state.randint(10000)
img, tiles = make_image(**input_parameters, seed=seed)
input_img = in_type(img)

splitter = TileOnGridDict(**input_parameters)
splitter.set_random_state(seed=123)

output = splitter({key: img})
output = splitter({key: input_img})

if input_parameters.get("return_list_of_dicts", False):
output = np.stack([ix[key] for ix in output], axis=0)
if isinstance(input_img, torch.Tensor):
output = torch.stack([ix[key] for ix in output], axis=0)
else:
output = np.stack([ix[key] for ix in output], axis=0)
else:
output = output[key]

np.testing.assert_equal(tiles, output)
assert_allclose(output, tiles, type_test=False)


if __name__ == "__main__":
Expand Down