-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Closed
Labels
refactor: Non-breaking feature enhancements
Description
Refactor, and possibly generalize, the `SplitOnGrid` transform in the pathology app so that it can be blended into core MONAI, as laid out in #4005.
MONAI/monai/apps/pathology/transforms/spatial/array.py
Lines 26 to 107 in a676e38
class SplitOnGrid(Transform):
    """
    Split a channel-first 2D image into patches laid out on a regular grid.

    Both ``torch.Tensor`` and ``numpy.ndarray`` inputs are supported; the output has the
    same container type as the input. The input is indexed as ``(channel, dim0, dim1)``
    and the output is ``(n_patches, channel, patch_size[0], patch_size[1])`` with
    ``n_patches == grid_size[0] * grid_size[1]``.

    Args:
        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
            If it's an integer, the value will be repeated for each dimension. Default is 2x2.
        patch_size: a tuple or an integer that defines the output patch sizes.
            If it's an integer, the value will be repeated for each dimension.
            The default is None, where the patch size will be inferred from the grid shape.

    Note: the patch size and the steps are recomputed from the spatial shape of every input image.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None
    ):
        # An integer is shorthand for the same value along both spatial dimensions.
        self.grid_size = (grid_size, grid_size) if isinstance(grid_size, int) else grid_size
        # ``None`` is kept as-is and resolved against the image shape in ``get_params``.
        self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Extract the grid of patches from ``image``.

        Args:
            image: channel-first 2D image, indexed as ``(channel, dim0, dim1)``.

        Returns:
            The patches stacked along a new first dimension, as a contiguous tensor/array.

        Raises:
            ValueError: if ``image`` is neither a ``torch.Tensor`` nor a ``numpy.ndarray``, or
                (NumPy input) if the patch size exceeds the spatial size of the image.
        """
        is_tensor = isinstance(image, torch.Tensor)
        if not is_tensor and not isinstance(image, np.ndarray):
            raise ValueError(f"Input type [{type(image)}] is not supported.")

        # A 1x1 grid without an explicit patch size is the whole image as a single patch.
        if self.grid_size == (1, 1) and self.patch_size is None:
            return torch.stack([image]) if is_tensor else np.stack([image])  # type: ignore

        patch_size, steps = self.get_params(image.shape[1:])
        n_rows, n_cols = self.grid_size

        patches: NdarrayOrTensor
        if is_tensor:
            windows = image.unfold(1, patch_size[0], steps[0]).unfold(2, patch_size[1], steps[1])
            # ``unfold`` yields floor((size - patch) / step) + 1 windows per dimension, which can be
            # more than the requested grid when the step does not divide evenly (e.g. size 10,
            # patch 2, grid 4 -> 5 windows). Keep only the first ``grid_size`` windows so the
            # result matches the NumPy branch.
            windows = windows[:, :n_rows, :n_cols]
            patches = windows.flatten(1, 2).transpose(0, 1).contiguous()
        else:
            # ``as_strided`` performs no bounds checking: a patch larger than the image would
            # read memory outside the array buffer, so reject it explicitly.
            if any(patch_size[i] > image.shape[i + 1] for i in range(2)):
                raise ValueError(
                    f"patch_size {tuple(patch_size)} must not exceed the image spatial size {tuple(image.shape[1:])}."
                )
            x_step, y_step = steps
            c_stride, x_stride, y_stride = image.strides
            n_channels = image.shape[0]
            patches = as_strided(
                image,
                shape=(n_rows, n_cols, n_channels, patch_size[0], patch_size[1]),
                strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
                writeable=False,
            )
            # flatten the two grid dimensions into a single patch dimension
            patches = patches.reshape(-1, *patches.shape[2:])
            # make it a contiguous array
            patches = np.ascontiguousarray(patches)

        return patches

    def get_params(self, image_size):
        """
        Compute the patch size and the step between neighbouring patches.

        Args:
            image_size: spatial size of the image (at least two entries; the first two are used).

        Returns:
            A ``(patch_size, steps)`` pair. When no patch size was configured it is
            ``image_size // grid_size`` per dimension. The step spreads ``grid_size`` patches
            from the first to the last valid position; for a grid of one along a dimension
            the step is the full image size along that dimension.
        """
        if self.patch_size is None:
            patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
        else:
            patch_size = self.patch_size

        steps = tuple(
            (image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
            for i in range(2)
        )

        return patch_size, steps
Metadata
Metadata
Assignees
Labels
refactor: Non-breaking feature enhancements
Type
Projects
Status
💯 Complete