-
Notifications
You must be signed in to change notification settings - Fork 1.4k
TileOnGrid support for Tensor input
#3384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2202e95
bed4ef5
e752f3e
ba03758
722bc65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| from monai.config.type_definitions import NdarrayOrTensor | ||
| from monai.transforms.transform import Randomizable, Transform | ||
| from monai.utils import convert_data_type, convert_to_dst_type | ||
| from monai.utils.enums import TransformBackends | ||
|
|
||
| __all__ = ["SplitOnGrid", "TileOnGrid"] | ||
|
|
@@ -129,6 +130,8 @@ class TileOnGrid(Randomizable, Transform): | |
|
|
||
| """ | ||
|
|
||
| backend = [TransformBackends.NUMPY] | ||
|
|
||
| def __init__( | ||
| self, | ||
| tile_count: Optional[int] = None, | ||
|
|
@@ -185,37 +188,39 @@ def randomize(self, img_size: Sequence[int]) -> None: | |
| else: | ||
| self.random_idxs = np.array((0,)) | ||
|
|
||
| def __call__(self, image: np.ndarray) -> np.ndarray: | ||
| def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: | ||
| img_np: np.ndarray | ||
| img_np, *_ = convert_data_type(image, np.ndarray) # type: ignore | ||
|
|
||
| # add random offset | ||
| self.randomize(img_size=image.shape) | ||
| self.randomize(img_size=img_np.shape) | ||
|
|
||
| if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): | ||
| image = image[:, self.offset[0] :, self.offset[1] :] | ||
| img_np = img_np[:, self.offset[0] :, self.offset[1] :] | ||
|
|
||
| # pad to full size, divisible by tile_size | ||
| if self.pad_full: | ||
| c, h, w = image.shape | ||
| c, h, w = img_np.shape | ||
| pad_h = (self.tile_size - h % self.tile_size) % self.tile_size | ||
| pad_w = (self.tile_size - w % self.tile_size) % self.tile_size | ||
| image = np.pad( | ||
| image, | ||
| img_np = np.pad( | ||
| img_np, | ||
| [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], | ||
| constant_values=self.background_val, | ||
| ) | ||
|
|
||
| # extact tiles | ||
| x_step, y_step = self.step, self.step | ||
| x_size, y_size = self.tile_size, self.tile_size | ||
| c_len, x_len, y_len = image.shape | ||
| c_stride, x_stride, y_stride = image.strides | ||
| h_step, w_step = self.step, self.step | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel the previous variable
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I can fix this too. sorry that I didn't see your comments before merging.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| h_size, w_size = self.tile_size, self.tile_size | ||
| c_len, h_len, w_len = img_np.shape | ||
| c_stride, h_stride, w_stride = img_np.strides | ||
| llw = as_strided( | ||
| image, | ||
| shape=((x_len - x_size) // x_step + 1, (y_len - y_size) // y_step + 1, c_len, x_size, y_size), | ||
| strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), | ||
| img_np, | ||
| shape=((h_len - h_size) // h_step + 1, (w_len - w_size) // w_step + 1, c_len, h_size, w_size), | ||
| strides=(h_stride * h_step, w_stride * w_step, c_stride, h_stride, w_stride), | ||
| writeable=False, | ||
| ) | ||
| image = llw.reshape(-1, c_len, x_size, y_size) | ||
| img_np = llw.reshape(-1, c_len, h_size, w_size) | ||
|
|
||
| # if keeping all patches | ||
| if self.tile_count is None: | ||
|
|
@@ -224,32 +229,34 @@ def __call__(self, image: np.ndarray) -> np.ndarray: | |
| thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size | ||
| if self.filter_mode == "min": | ||
| # default, keep non-background tiles (small values) | ||
| idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) | ||
| image = image[idxs.reshape(-1)] | ||
| idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh) | ||
| img_np = img_np[idxs.reshape(-1)] | ||
| elif self.filter_mode == "max": | ||
| idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) | ||
| image = image[idxs.reshape(-1)] | ||
| idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh) | ||
| img_np = img_np[idxs.reshape(-1)] | ||
|
|
||
| else: | ||
| if len(image) > self.tile_count: | ||
| if len(img_np) > self.tile_count: | ||
|
|
||
| if self.filter_mode == "min": | ||
| # default, keep non-background tiles (smallest values) | ||
| idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] | ||
| image = image[idxs] | ||
| idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count] | ||
| img_np = img_np[idxs] | ||
| elif self.filter_mode == "max": | ||
| idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] | ||
| image = image[idxs] | ||
| idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :] | ||
| img_np = img_np[idxs] | ||
| else: | ||
| # random subset (more appropriate for WSIs without distinct background) | ||
| if self.random_idxs is not None: | ||
| image = image[self.random_idxs] | ||
| img_np = img_np[self.random_idxs] | ||
|
|
||
| elif len(image) < self.tile_count: | ||
| image = np.pad( | ||
| image, | ||
| [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], | ||
| elif len(img_np) < self.tile_count: | ||
| img_np = np.pad( | ||
| img_np, | ||
| [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]], | ||
| constant_values=self.background_val, | ||
| ) | ||
|
|
||
| image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype) | ||
|
|
||
| return image | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like
SplitOnGrid(2, 2)(torch.arange(12).reshape(1, 3, 4))works fine butSplitOnGrid(2, 2)(np.arange(12).reshape(1, 3, 4))doesn't work, could you help fix it here? #3378 was merged too quicklyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wyli, sorry I missed this. I will fix it now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#3386