
Refactor Split transform #4013

@bhashemian

Description

Refactor and possibly generalize the SplitOnGrid transform in pathology so it can be merged into core MONAI, as laid out in #4005.

```python
from typing import Optional, Tuple, Union

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config import NdarrayOrTensor
from monai.transforms.transform import Transform
from monai.utils import TransformBackends


class SplitOnGrid(Transform):
    """
    Split the image into patches based on the provided grid shape.
    This transform works with both torch.Tensor and np.ndarray inputs.

    Args:
        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
            If it's an integer, the value will be repeated for each dimension. Default is 2x2.
        patch_size: a tuple or an integer that defines the output patch sizes.
            If it's an integer, the value will be repeated for each dimension.
            The default is None, where the patch size will be inferred from the grid shape.

    Note: the shape of the input image is inferred based on the first image used.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None
    ):
        # Grid size: repeat an integer for both spatial dimensions
        if isinstance(grid_size, int):
            self.grid_size = (grid_size, grid_size)
        else:
            self.grid_size = grid_size
        # Patch size: None means infer from the grid shape at call time
        if isinstance(patch_size, int):
            self.patch_size: Optional[Tuple[int, int]] = (patch_size, patch_size)
        else:
            self.patch_size = patch_size

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        # A 1x1 grid with no explicit patch size returns the whole image as a single patch
        if self.grid_size == (1, 1) and self.patch_size is None:
            if isinstance(image, torch.Tensor):
                return torch.stack([image])
            if isinstance(image, np.ndarray):
                return np.stack([image])  # type: ignore
            raise ValueError(f"Input type [{type(image)}] is not supported.")

        patch_size, steps = self.get_params(image.shape[1:])
        patches: NdarrayOrTensor
        if isinstance(image, torch.Tensor):
            # Extract sliding windows along both spatial dims, flatten the grid
            # dimensions, and move the patch dimension to the front.
            patches = (
                image.unfold(1, patch_size[0], steps[0])
                .unfold(2, patch_size[1], steps[1])
                .flatten(1, 2)
                .transpose(0, 1)
                .contiguous()
            )
        elif isinstance(image, np.ndarray):
            # Build a read-only strided view over the grid of patches
            x_step, y_step = steps
            c_stride, x_stride, y_stride = image.strides
            n_channels = image.shape[0]
            patches = as_strided(
                image,
                shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]),
                strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
                writeable=False,
            )
            # Flatten the first two (grid) dimensions into a single patch dimension
            patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:])
            # Copy into a contiguous (and writeable) array
            patches = np.ascontiguousarray(patches)
        else:
            raise ValueError(f"Input type [{type(image)}] is not supported.")

        return patches

    def get_params(self, image_size):
        # Infer the patch size from the grid shape if it was not provided
        if self.patch_size is None:
            patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
        else:
            patch_size = self.patch_size
        # Space the patches so the first starts at 0 and the last ends at the image border
        steps = tuple(
            (image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
            for i in range(2)
        )
        return patch_size, steps
```
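
For reference, a minimal usage sketch (not part of the issue; the shapes and values are illustrative assumptions): splitting a 1x4x4 CHW image on the default 2x2 grid yields four 1x2x2 patches, and the torch and numpy code paths produce the same layout.

```python
import numpy as np
import torch

# Hypothetical usage sketch: split a CHW image on a 2x2 grid.
splitter = SplitOnGrid(grid_size=2)

image_t = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)
patches_t = splitter(image_t)
print(patches_t.shape)  # torch.Size([4, 1, 2, 2]) -- 4 patches of 1x2x2

# The numpy path goes through as_strided and yields the same patches.
image_n = image_t.numpy()
patches_n = splitter(image_n)
print(patches_n.shape)  # (4, 1, 2, 2)
assert np.allclose(patches_t.numpy(), patches_n)
```

Note the design trade-off in the numpy branch: `as_strided` creates a zero-copy view, but the final `np.ascontiguousarray` copies it so callers get a regular, writeable array, matching the `.contiguous()` call on the torch side.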

Metadata

Labels

refactor (Non-breaking feature enhancements)


Projects

Status: 💯 Complete
