-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Closed
Labels
refactor: Non-breaking feature enhancements
Description
Refactor, simplify, and possibly generalize the `TileOnGrid` transform from the pathology app so that it can be merged into core MONAI, as laid out in #4005.
MONAI/monai/apps/pathology/transforms/spatial/array.py
Lines 110 to 262 in a676e38
class TileOnGrid(Randomizable, Transform):
    """
    Tile a channel-first 2D image (C, H, W) into square patches on a grid and keep a subset of them.

    The input is converted to ``np.ndarray`` internally and the result is converted back to the
    input's type and dtype. The output has shape (N, C, tile_size, tile_size).

    Args:
        tile_count: number of tiles to extract. If None, extracts all non-background tiles
            (a variable number of tiles). Defaults to ``None``.
        tile_size: size of the square tile. Defaults to ``256``.
        step: step size between neighbouring tiles. Defaults to ``None`` (same as tile_size,
            i.e. a non-overlapping grid).
        random_offset: randomize the position of the grid instead of starting from the top-left
            corner. The offset is drawn in ``[0, dim % tile_size)`` for each spatial dim.
            Defaults to ``False``.
        pad_full: pad the image (split evenly on both sides, filled with ``background_val``)
            to a size evenly divisible by tile_size. Defaults to ``False``.
        background_val: the background constant (e.g. 255 for white background).
            Defaults to ``255``.
        filter_mode: mode must be in ["min", "max", "random"]. If the total number of tiles is more
            than tile_count, sort by intensity sum and take the smallest (for min), largest (for max)
            or a random (for random) subset. If fewer than tile_count tiles exist, the result is
            padded with tiles filled with ``background_val``. When tile_count is None, "min" keeps
            tiles whose intensity sum is below 99.9% of a tile completely filled with
            ``background_val`` (summed over all channels), "max" keeps the rest, and "random" keeps
            all tiles. Defaults to ``min`` (which assumes background is high value).

    Raises:
        ValueError: if ``filter_mode`` is not one of "min", "max", "random".

    """

    backend = [TransformBackends.NUMPY]

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        self.random_offset = random_offset
        self.pad_full = pad_full
        self.background_val = background_val
        self.filter_mode = filter_mode

        # default to a non-overlapping grid
        self.step = self.tile_size if step is None else step

        # per-call random state, (re)populated by ``randomize``
        self.offset = (0, 0)
        self.random_idxs = np.array((0,))

        if self.filter_mode not in ["min", "max", "random"]:
            raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))

    def _pad_amount(self, size: int) -> int:
        """Return how many pixels must be added to ``size`` to make it a multiple of ``tile_size``."""
        return (self.tile_size - size % self.tile_size) % self.tile_size

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Draw the grid offset and the random tile subset for an image of shape ``img_size`` = (C, H, W).

        ``self.random_idxs`` is drawn whenever more than ``tile_count`` tiles would be produced,
        regardless of ``filter_mode`` (it is only consumed by the "random" mode). Drawing it
        unconditionally keeps the random stream identical across filter modes.
        """
        _, h, w = img_size
        self.offset = (0, 0)
        if self.random_offset:
            # the grid can be shifted by at most the leftover that does not fit a whole tile
            rem_h = h % self.tile_size
            rem_w = w % self.tile_size
            self.offset = (self.R.randint(rem_h) if rem_h > 0 else 0, self.R.randint(rem_w) if rem_w > 0 else 0)
            h = h - self.offset[0]
            w = w - self.offset[1]

        # mirror the padding done in ``__call__`` so the tile count matches the extracted grid
        if self.pad_full:
            h = h + self._pad_amount(h)
            w = w + self._pad_amount(w)

        h_n = (h - self.tile_size + self.step) // self.step
        w_n = (w - self.tile_size + self.step) // self.step
        tile_total = h_n * w_n

        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            self.random_idxs = np.array((0,))

    def _pad_to_full(self, img_np: np.ndarray) -> np.ndarray:
        """Pad H and W (split evenly on both sides) with ``background_val`` up to a multiple of ``tile_size``."""
        _, h, w = img_np.shape
        pad_h = self._pad_amount(h)
        pad_w = self._pad_amount(w)
        return np.pad(  # type: ignore
            img_np,
            [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
            constant_values=self.background_val,
        )

    def _extract_tiles(self, img_np: np.ndarray) -> np.ndarray:
        """Return every grid tile of a (C, H, W) array as an array of shape (N, C, tile_size, tile_size), row-major over the grid."""
        c_image, h_image, w_image = img_np.shape
        c_stride, h_stride, w_stride = img_np.strides
        tile = self.tile_size
        grid = as_strided(
            img_np,
            shape=((h_image - tile) // self.step + 1, (w_image - tile) // self.step + 1, c_image, tile, tile),
            strides=(h_stride * self.step, w_stride * self.step, c_stride, h_stride, w_stride),
            writeable=False,
        )
        return grid.reshape(-1, c_image, tile, tile)  # type: ignore

    def _select_tiles(self, tiles: np.ndarray) -> np.ndarray:
        """Filter, subsample or pad the (N, C, tile_size, tile_size) tiles according to ``tile_count`` and ``filter_mode``."""
        if self.tile_count is None:
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g. during inference.
            # A tile is "background" once its sum reaches 99.9% of a tile fully filled with background_val;
            # use the actual channel count (tiles.shape[1]) instead of assuming 3-channel RGB.
            thresh = 0.999 * tiles.shape[1] * self.background_val * self.tile_size * self.tile_size
            if self.filter_mode == "min":
                # default, keep non-background tiles (small values)
                idxs = np.argwhere(tiles.sum(axis=(1, 2, 3)) < thresh)
                tiles = tiles[idxs.reshape(-1)]
            elif self.filter_mode == "max":
                idxs = np.argwhere(tiles.sum(axis=(1, 2, 3)) >= thresh)
                tiles = tiles[idxs.reshape(-1)]
            return tiles

        n_tiles = len(tiles)
        if n_tiles > self.tile_count:
            if self.filter_mode == "min":
                # default, keep non-background tiles (smallest values)
                idxs = np.argsort(tiles.sum(axis=(1, 2, 3)))[: self.tile_count]
                tiles = tiles[idxs]
            elif self.filter_mode == "max":
                # slice from an explicit start index: ``[-tile_count:]`` would select everything when tile_count == 0
                idxs = np.argsort(tiles.sum(axis=(1, 2, 3)))[n_tiles - self.tile_count :]
                tiles = tiles[idxs]
            else:
                # random subset (more appropriate for WSIs without distinct background)
                if self.random_idxs is not None:
                    tiles = tiles[self.random_idxs]
        elif n_tiles < self.tile_count:
            # not enough tiles: append tiles filled with the background value
            tiles = np.pad(  # type: ignore
                tiles,
                [[0, self.tile_count - n_tiles], [0, 0], [0, 0], [0, 0]],
                constant_values=self.background_val,
            )
        return tiles

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """Tile ``image`` (C, H, W) and return the selected tiles as (N, C, tile_size, tile_size), same type as ``image``."""
        img_np, *_ = convert_data_type(image, np.ndarray)

        # add random offset
        self.randomize(img_size=img_np.shape)
        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            img_np = img_np[:, self.offset[0] :, self.offset[1] :]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            img_np = self._pad_to_full(img_np)

        # extract tiles, then filter / subsample / pad them
        tiles = self._extract_tiles(img_np)
        tiles = self._select_tiles(tiles)

        image, *_ = convert_to_dst_type(src=tiles, dst=image, dtype=image.dtype)
        return image
Metadata
Metadata
Assignees
Labels
refactor: Non-breaking feature enhancements
Type
Projects
Status
💯 Complete