
Refactor and Simplify Tile transform #4012

@bhashemian

Description

Refactor, simplify, and possibly generalize the TileOnGrid transform in pathology, so that it can be blended into core MONAI as laid out in #4005. The current implementation is included below for reference.

```python
from typing import Optional, Sequence

import numpy as np
from numpy.lib.stride_tricks import as_strided

from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import Randomizable, Transform
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type


class TileOnGrid(Randomizable, Transform):
    """
    Tile the 2D image into patches on a grid and keep a subset of them.
    This transform works only with np.ndarray inputs for 2D images.

    Args:
        tile_count: number of tiles to extract; if None, extracts all non-background tiles.
            Defaults to ``None``.
        tile_size: size of the square tile.
            Defaults to ``256``.
        step: step size between tiles.
            Defaults to ``None`` (same as ``tile_size``, i.e. non-overlapping tiles).
        random_offset: randomize the position of the grid instead of starting from the top-left corner.
            Defaults to ``False``.
        pad_full: pad the image to a size evenly divisible by ``tile_size``.
            Defaults to ``False``.
        background_val: the background constant (e.g. 255 for a white background).
            Defaults to ``255``.
        filter_mode: must be one of ["min", "max", "random"]. If the total number of tiles exceeds ``tile_count``,
            sort the tiles by intensity sum and keep the smallest (for "min"), largest (for "max"), or a random
            (for "random") subset. Defaults to ``min`` (which assumes the background has high values).
    """

    backend = [TransformBackends.NUMPY]

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        self.random_offset = random_offset
        self.pad_full = pad_full
        self.background_val = background_val
        self.filter_mode = filter_mode

        if step is None:
            # non-overlapping grid
            self.step = self.tile_size
        else:
            self.step = step

        self.offset = (0, 0)
        self.random_idxs = np.array((0,))

        if self.filter_mode not in ["min", "max", "random"]:
            raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))

    def randomize(self, img_size: Sequence[int]) -> None:
        c, h, w = img_size

        self.offset = (0, 0)
        if self.random_offset:
            pad_h = h % self.tile_size
            pad_w = w % self.tile_size
            self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0)
            h = h - self.offset[0]
            w = w - self.offset[1]

        if self.pad_full:
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            h = h + pad_h
            w = w + pad_w

        # number of tiles along each spatial dimension, and the total grid size
        h_n = (h - self.tile_size + self.step) // self.step
        w_n = (w - self.tile_size + self.step) // self.step
        tile_total = h_n * w_n

        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            self.random_idxs = np.array((0,))

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        img_np, *_ = convert_data_type(image, np.ndarray)

        # add random offset
        self.randomize(img_size=img_np.shape)

        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            img_np = img_np[:, self.offset[0] :, self.offset[1] :]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            c, h, w = img_np.shape
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.background_val,
            )

        # extract tiles: build a read-only strided view of shape (h_n, w_n, C, tile, tile), then flatten the grid
        x_step, y_step = self.step, self.step
        h_tile, w_tile = self.tile_size, self.tile_size
        c_image, h_image, w_image = img_np.shape
        c_stride, x_stride, y_stride = img_np.strides
        llw = as_strided(
            img_np,
            shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile),
            strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
            writeable=False,
        )
        img_np = llw.reshape(-1, c_image, h_tile, w_tile)  # type: ignore

        if self.tile_count is None:
            # keeping all patches: retain only patches with significant foreground content to speed up inference.
            # Note: this returns a variable number of tiles, so the batch_size must be 1 (per GPU), e.g. during inference.
            thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size
            if self.filter_mode == "min":
                # default: keep non-background tiles (small intensity sums)
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
                img_np = img_np[idxs.reshape(-1)]
            elif self.filter_mode == "max":
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
                img_np = img_np[idxs.reshape(-1)]
        else:
            if len(img_np) > self.tile_count:
                if self.filter_mode == "min":
                    # default: keep non-background tiles (smallest intensity sums)
                    idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count]
                    img_np = img_np[idxs]
                elif self.filter_mode == "max":
                    idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :]
                    img_np = img_np[idxs]
                else:
                    # random subset (more appropriate for WSIs without a distinct background)
                    if self.random_idxs is not None:
                        img_np = img_np[self.random_idxs]
            elif len(img_np) < self.tile_count:
                # fewer tiles than requested: pad with background tiles to reach tile_count
                img_np = np.pad(  # type: ignore
                    img_np,
                    [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
                    constant_values=self.background_val,
                )

        image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)
        return image
```
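
A minimal usage sketch (not part of the PR) of the current pathology `TileOnGrid`, applied to a synthetic channel-first RGB array; the parameter values here are purely illustrative:

```python
import numpy as np
from monai.apps.pathology.transforms import TileOnGrid

# synthetic whole-slide-like image: channel-first, white (255) background with a dark "tissue" block
img = np.full((3, 1000, 1000), 255, dtype=np.uint8)
img[:, 200:600, 300:700] = 50

tiler = TileOnGrid(
    tile_count=16,       # keep at most 16 tiles (pad with background tiles if fewer are found)
    tile_size=256,
    random_offset=True,  # jitter the grid origin
    pad_full=True,       # pad H and W up to multiples of tile_size
    background_val=255,
    filter_mode="min",   # prefer tiles with the lowest intensity sum, i.e. the most tissue
)
tiles = tiler(img)
print(tiles.shape)  # (16, 3, 256, 256)
```

With `tile_count=None` instead, the transform returns every tile that passes the background threshold, so the number of tiles varies per image and the downstream batch size must be 1.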
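
On the "simplify" part of this issue, one option worth considering: the hand-written `as_strided` call could be expressed with `numpy.lib.stride_tricks.sliding_window_view` (NumPy >= 1.20), which builds the same read-only sliding-window view without manually computed strides. This is a sketch of the equivalence only, not a proposal for the final API; `tile_chw` is a hypothetical helper name:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def tile_chw(img: np.ndarray, tile_size: int, step: int) -> np.ndarray:
    """Tile a CHW array into (N, C, tile_size, tile_size) patches on a regular grid."""
    # view of shape (C, H - tile_size + 1, W - tile_size + 1, tile_size, tile_size), no copy
    view = sliding_window_view(img, (tile_size, tile_size), axis=(1, 2))
    # subsample the grid with the given step, then flatten the two grid dimensions
    view = view[:, ::step, ::step]
    c, h_n, w_n = view.shape[:3]
    return view.transpose(1, 2, 0, 3, 4).reshape(h_n * w_n, c, tile_size, tile_size)


img = np.zeros((3, 512, 512), dtype=np.uint8)
print(tile_chw(img, 256, 256).shape)  # (4, 3, 256, 256), same tile order as the as_strided version
```

The offset, padding, and filtering logic above would stay as-is; only the stride bookkeeping goes away.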
