From c024e591104e85370a902f6cfa6f1e92e361f592 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Tue, 2 Aug 2022 21:55:08 +0800 Subject: [PATCH 01/12] feat(transforms): Add RandSoftCopyPaste (only 3D currently) and its test. --- monai_ex/tests/test_RandSoftCopyPaste.py | 33 +++++++ monai_ex/transforms/io/array.py | 6 +- monai_ex/transforms/utility/array.py | 121 +++++++++++++++++++---- 3 files changed, 137 insertions(+), 23 deletions(-) create mode 100644 monai_ex/tests/test_RandSoftCopyPaste.py diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py new file mode 100644 index 0000000..1c94b81 --- /dev/null +++ b/monai_ex/tests/test_RandSoftCopyPaste.py @@ -0,0 +1,33 @@ +import pytest + +from pathlib import Path +import nibabel as nib + +import numpy as np +from monai_ex.transforms.utility.array import RandSoftCopyPaste +from monai.data.synthetic import create_test_image_3d +from monai_ex.transforms.io.array import GenerateSyntheticData + +def test_randsoftcopypaste(): + spatial_size = (100, 100, 100) + generator = GenerateSyntheticData( + *spatial_size, + num_objs=1, + rad_max=10, + rad_min=9, + noise_max=0.2, + num_seg_classes=1, + channel_dim=0, + ) + + src_image, src_mask = generator(None) + tar_image, tar_mask = generator(None) + + print("dummy data, mask shape:", src_image.shape, src_image.shape) + print("mask label: ", np.unique(src_mask)) + sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, tar_mask==0) + + assert sythetic_img.shape == (1, *spatial_size) + + # save_fpath = Path.home() / "sythetic_img.nii.gz" + # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) diff --git a/monai_ex/transforms/io/array.py b/monai_ex/transforms/io/array.py index 65039a9..1b61464 100644 --- a/monai_ex/transforms/io/array.py +++ b/monai_ex/transforms/io/array.py @@ -57,7 +57,7 @@ def __call__(self, data: Any): self.channel_dim, self.random_state, ) - + if self.num_seg_classes < 1: return img, img @@ -88,9 +88,7 @@ def __call__(self, data: Any): image = np.random.rand(self.width, self.height) if self.channel_dim is not None: - if not ( - isinstance(self.channel_dim, int) and self.channel_dim in (-1, 0, 3) - ): + if not (isinstance(self.channel_dim, int) and self.channel_dim in (-1, 0, 3)): raise AssertionError("invalid channel dim.") image = image[None] if self.channel_dim == 0 else image[..., None] diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index b2c1df5..95c8956 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -7,8 +7,10 @@ from monai.transforms.compose import Transform, Randomizable from monai.transforms import DataStats, SaveImage, CastToType +from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices from monai.config import NdarrayTensor, DtypeLike from monai_ex.utils import convert_data_type_ex +from strix.utilities.utils import bbox_3D class CastToTypeEx(CastToType): @@ -24,9 +26,7 @@ def __init__(self, dtype=np.float32) -> None: """ self.dtype = dtype - def __call__( - self, img: Any, dtype: Optional[Union[DtypeLike, torch.dtype]] = None - ) -> Any: + def __call__(self, img: Any, dtype: Optional[Union[DtypeLike, torch.dtype]] = None) -> Any: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -39,12 +39,8 @@ def __call__( """ type_list = (torch.Tensor, np.ndarray, int, bool, float, list, tuple) if not isinstance(img, type_list): - raise TypeError( - f"img must be one of ({type_list}) but is {type(img).__name__}." - ) - img_out, *_ = convert_data_type_ex( - img, output_type=type(img), dtype=dtype or self.dtype - ) + raise TypeError(f"img must be one of ({type_list}) but is {type(img).__name__}.") + img_out, *_ = convert_data_type_ex(img, output_type=type(img), dtype=dtype or self.dtype) return img_out @@ -53,7 +49,7 @@ class ToTensorEx(Transform): Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, img: NdarrayTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ @@ -105,9 +101,7 @@ def __call__( data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, ) -> NdarrayTensor: - img = super().__init__( - img, prefix, data_type, data_shape, value_range, data_value, additional_info - ) + img = super().__init__(img, prefix, data_type, data_shape, value_range, data_value, additional_info) if self.save_data: saver = SaveImage( @@ -204,12 +198,101 @@ def __call__( if img.shape[0] > 1: data = img[[self.select_label]] else: - data = np.where(np.in1d(img, self.select_label), True, False).reshape( - img.shape + data = np.where(np.in1d(img, self.select_label), True, False).reshape(img.shape) + + return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + + +class RandSoftCopyPaste(Randomizable, Transform): + """ + Perform Soft Copy&Paste augmentation. + Reference: `https://arxiv.org/ftp/arxiv/papers/2203/2203.10507.pdf` + """ + + def __init__( + self, + k_erode: int, + k_dilate: int, + alpha: float = 0.8, + label_idx: Optional[int] = None, + ) -> None: + super().__init__() + self.k_erode = k_erode + self.k_dilate = k_dilate + self.alpha = alpha + self.label_idx = label_idx + + def soften(self, src_mask): + struct = ndi.generate_binary_structure(src_mask.ndim, 2) + mask = ndi.binary_erosion(src_mask, struct, iterations=self.k_erode).astype(src_mask.dtype) + for j in range(self.k_dilate): + mask_binary = np.where(mask > 0, 1, 0) + mask_dilate = ndi.binary_dilation(mask_binary, struct, iterations=1).astype(mask.dtype) + mask_alpha = mask_dilate * (self.alpha ** (j + 1)) + mask = (1 - mask) * mask_alpha + mask + + return mask + + def paste( + self, + softed_image: NdarrayTensor, + softed_mask: NdarrayTensor, + target_image: NdarrayTensor, + target_mask: NdarrayTensor, + ): + if target_mask is None: + pass + else: + x1, x2, y1, y2, z1, z2 = bbox_3D(softed_mask[0, ...]) + x_sz, y_sz, z_sz = x2 - x1, y2 - y1, z2 - z1 + fg_indices_, bg_indices_ = map_binary_to_indices(target_mask, None, None) + centers = generate_pos_neg_label_crop_centers( + (x_sz, y_sz, z_sz), + 1, + 1, + softed_mask.shape[1:], + fg_indices_, + bg_indices_, + self.R, + False, ) - return ( - np.any(data, axis=0, keepdims=True) - if (merge_channels or self.merge_channels) - else data + shifted_src_image = np.zeros_like(target_image) + shifted_src_mask = np.zeros_like(target_image) + x_range = slice(int(centers[0][0] - x_sz // 2), int(centers[0][0] - x_sz // 2 + x_sz)) + y_range = slice(int(centers[0][1] - y_sz // 2), int(centers[0][1] - y_sz // 2 + y_sz)) + z_range = slice(int(centers[0][2] - z_sz // 2), int(centers[0][2] - z_sz // 2 + z_sz)) + + shifted_src_image[:, x_range, y_range, z_range] = softed_image[:, x1:x2, y1:y2, z1:z2] + shifted_src_mask[:, x_range, y_range, z_range] = softed_mask[:, x1:x2, y1:y2, z1:z2] + sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image + return sythetic_image + + def __call__( + self, + source_image: NdarrayTensor, + source_mask: NdarrayTensor, + target_image: NdarrayTensor, + target_mask: Optional[NdarrayTensor] = None, + ) -> NdarrayTensor: + if source_mask.shape[0] > 1: + if self.label_idx is None: + raise ValueError("Multi-channel label data need to specify label_idx") + else: + source_mask = source_mask[self.label_idx, ...] + elif source_mask.shape[0] == 1: + if self.label_idx is None: + source_mask = (source_mask > 0).squeeze(0) + else: + source_mask = (source_mask == self.label_idx).squeeze(0) + + softed_mask = self.soften(source_mask) + softed_mask = softed_mask[np.newaxis, ...] + if source_image.shape[0] > 1: + softed_mask = np.repeat(softed_mask, repeats=source_image.shape[0], axis=0) + + softed_image = source_image * softed_mask + sythetic_image = self.paste( + softed_image=softed_image, softed_mask=softed_mask, target_image=target_image, target_mask=target_mask ) + return sythetic_image From 6ab9e2545df51489c0efe2ef259c6ace3548e991 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Wed, 3 Aug 2022 13:01:36 +0800 Subject: [PATCH 02/12] feat(utils): feat(utils): Add bbox_nd function for ndarray. Support arbitary dimensional data --- monai_ex/tests/test_bbox_nd.py | 23 +++++++++++++++++ monai_ex/utils/misc.py | 45 +++++++++++++++++++++++++++------- 2 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 monai_ex/tests/test_bbox_nd.py diff --git a/monai_ex/tests/test_bbox_nd.py b/monai_ex/tests/test_bbox_nd.py new file mode 100644 index 0000000..0920e83 --- /dev/null +++ b/monai_ex/tests/test_bbox_nd.py @@ -0,0 +1,23 @@ +import pytest +from monai_ex.utils.misc import bbox_ND +import numpy as np + +dummy_data_3d = np.zeros([10, 10, 10]) +dummy_data_3d[4:7, 4:8, 4:9] = 1 + +dummy_data_2d = np.zeros([10, 10]) +dummy_data_2d[4:7, 4:8] = 1 + +@pytest.mark.parametrize('data', [dummy_data_2d, dummy_data_3d]) +def test_bbox_nd(data): + bounding = bbox_ND(data, False) + if len(data) == 3: + assert bounding == (4, 6, 4, 7, 4, 8) + elif len(data) == 2: + assert bounding == (4, 6, 4, 7) + + bbox_range = bbox_ND(data, True) + if len(data) == 3: + assert bbox_range == (2, 3, 4) + elif len(data) == 2: + assert bbox_range == (2, 3) diff --git a/monai_ex/utils/misc.py b/monai_ex/utils/misc.py index 9ae7b62..5ba3d9a 100644 --- a/monai_ex/utils/misc.py +++ b/monai_ex/utils/misc.py @@ -1,14 +1,36 @@ +import itertools +from functools import partial from typing import Any, List -import torch - +import numpy as np +import torch from monai.utils.misc import issequenceiterable -from functools import partial from utils_cw import catch_exception + from .exceptions import GenericException +trycatch = partial(catch_exception, handled_exception_type=GenericException, path_keywords=["strix", "monai_ex"]) + -trycatch = partial(catch_exception, handled_exception_type=GenericException, path_keywords=['strix','monai_ex']) +# Copied from strix +def bbox_ND(img: np.ndarray, return_range: bool = False): + """Compute boundingbox of n-dimensional data. + + Args: + img (np.ndarray): Input nd data. + return_range (bool, optional): if return size of each boundingbox. Defaults to False. + + Returns: + list: boundingbox coord/size list + """ + N = img.ndim + out = [] + for ax in itertools.combinations(reversed(range(N)), N - 1): + nonzero = np.any(img, axis=ax) + out.extend(np.where(nonzero)[0][[0, -1]]) + if return_range: + return tuple(out[2 * i + 1] - out[2 * i] for i in range(len(out) // 2)) + return tuple(out) def ensure_same_dim(tensor1, tensor2): @@ -30,7 +52,9 @@ def ensure_list(vals: Any): Returns a list of `vals`. """ if not issequenceiterable(vals) or isinstance(vals, dict): - vals = [vals, ] + vals = [ + vals, + ] return list(vals) @@ -43,7 +67,9 @@ def ensure_list_rep(vals: Any, dim: int) -> List[Any]: ValueError: When ``tup`` is a sequence and ``tup`` length is not ``dim``. """ if not issequenceiterable(vals): - return [vals,] * dim + return [ + vals, + ] * dim elif len(vals) == dim: return list(vals) @@ -56,7 +82,7 @@ def _register_generic(module_dict, module_name, module): class Registry(dict): - ''' + """ A helper class for managing registering modules, it extends a dictionary and provides a register functions. @@ -76,7 +102,8 @@ def foo(): Access of module is just like using a dictionary, eg: f = some_registry["foo_modeul"] - ''' + """ + def __init__(self, *args, **kwargs): super(Registry, self).__init__(*args, **kwargs) @@ -91,4 +118,4 @@ def register_fn(fn): _register_generic(self, module_name, fn) return fn - return register_fn \ No newline at end of file + return register_fn From 643c70aebc361a34c4dbf731ce68342212e4d334 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Wed, 3 Aug 2022 13:31:31 +0800 Subject: [PATCH 03/12] feat(transforms): Add ndarray support not only 3D --- monai_ex/tests/test_RandSoftCopyPaste.py | 9 +++++---- monai_ex/transforms/utility/array.py | 23 ++++++++++++----------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py index 1c94b81..ef4254a 100644 --- a/monai_ex/tests/test_RandSoftCopyPaste.py +++ b/monai_ex/tests/test_RandSoftCopyPaste.py @@ -8,8 +8,9 @@ from monai.data.synthetic import create_test_image_3d from monai_ex.transforms.io.array import GenerateSyntheticData -def test_randsoftcopypaste(): - spatial_size = (100, 100, 100) +@pytest.mark.parametrize('dim', [2, 3]) +def test_randsoftcopypaste(dim): + spatial_size = (100,) * dim generator = GenerateSyntheticData( *spatial_size, num_objs=1, @@ -29,5 +30,5 @@ def test_randsoftcopypaste(): assert sythetic_img.shape == (1, *spatial_size) - # save_fpath = Path.home() / "sythetic_img.nii.gz" - # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) + save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz" + nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 95c8956..2389d88 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -9,8 +9,7 @@ from monai.transforms import DataStats, SaveImage, CastToType from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices from monai.config import NdarrayTensor, DtypeLike -from monai_ex.utils import convert_data_type_ex -from strix.utilities.utils import bbox_3D +from monai_ex.utils import convert_data_type_ex, bbox_ND class CastToTypeEx(CastToType): @@ -243,11 +242,11 @@ def paste( if target_mask is None: pass else: - x1, x2, y1, y2, z1, z2 = bbox_3D(softed_mask[0, ...]) - x_sz, y_sz, z_sz = x2 - x1, y2 - y1, z2 - z1 + boundingbox = bbox_ND(softed_mask[0, ...]) + bbox_size = tuple(boundingbox[2 * i + 1] - boundingbox[2 * i] for i in range(len(boundingbox) // 2)) fg_indices_, bg_indices_ = map_binary_to_indices(target_mask, None, None) centers = generate_pos_neg_label_crop_centers( - (x_sz, y_sz, z_sz), + bbox_size, 1, 1, softed_mask.shape[1:], @@ -259,12 +258,14 @@ def paste( shifted_src_image = np.zeros_like(target_image) shifted_src_mask = np.zeros_like(target_image) - x_range = slice(int(centers[0][0] - x_sz // 2), int(centers[0][0] - x_sz // 2 + x_sz)) - y_range = slice(int(centers[0][1] - y_sz // 2), int(centers[0][1] - y_sz // 2 + y_sz)) - z_range = slice(int(centers[0][2] - z_sz // 2), int(centers[0][2] - z_sz // 2 + z_sz)) - - shifted_src_image[:, x_range, y_range, z_range] = softed_image[:, x1:x2, y1:y2, z1:z2] - shifted_src_mask[:, x_range, y_range, z_range] = softed_mask[:, x1:x2, y1:y2, z1:z2] + n_ch = shifted_src_mask.shape[0] + tar_ranges = tuple(slice(int(center - sz // 2), int(center - sz // 2 + sz)) for center, sz in zip(centers[0], bbox_size)) + tar_slices = [slice(n_ch), *tar_ranges] + src_ranges = tuple(slice(boundingbox[2 * i], boundingbox[2 * i + 1]) for i in range(len(boundingbox) // 2)) + src_slices = [slice(n_ch), *src_ranges] + + shifted_src_image[tar_slices] = softed_image[src_slices] + shifted_src_mask[tar_slices] = softed_mask[src_slices] sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image return sythetic_image From e02e72cb75a4816d78db36be32773dd89309614a Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Wed, 3 Aug 2022 13:48:29 +0800 Subject: [PATCH 04/12] If target_mask is not provide, paste to original position. --- monai_ex/tests/test_RandSoftCopyPaste.py | 15 +++++++------ monai_ex/transforms/utility/array.py | 27 ++++++++++++------------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py index ef4254a..f68a69c 100644 --- a/monai_ex/tests/test_RandSoftCopyPaste.py +++ b/monai_ex/tests/test_RandSoftCopyPaste.py @@ -3,12 +3,13 @@ from pathlib import Path import nibabel as nib -import numpy as np +import numpy as np from monai_ex.transforms.utility.array import RandSoftCopyPaste from monai.data.synthetic import create_test_image_3d from monai_ex.transforms.io.array import GenerateSyntheticData -@pytest.mark.parametrize('dim', [2, 3]) + +@pytest.mark.parametrize("dim", [2, 3]) def test_randsoftcopypaste(dim): spatial_size = (100,) * dim generator = GenerateSyntheticData( @@ -26,9 +27,11 @@ def test_randsoftcopypaste(dim): print("dummy data, mask shape:", src_image.shape, src_image.shape) print("mask label: ", np.unique(src_mask)) - sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, tar_mask==0) - + sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, tar_mask == 0) assert sythetic_img.shape == (1, *spatial_size) - save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz" - nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) + # save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz" + # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) + + sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, None) + assert sythetic_img.shape == (1, *spatial_size) diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 2389d88..3ab79d3 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -239,11 +239,16 @@ def paste( target_image: NdarrayTensor, target_mask: NdarrayTensor, ): + n_ch = target_image.shape[0] + boundingbox = bbox_ND(softed_mask[0, ...]) + bbox_size = tuple(boundingbox[2 * i + 1] - boundingbox[2 * i] for i in range(len(boundingbox) // 2)) + src_ranges = tuple(slice(boundingbox[2 * i], boundingbox[2 * i + 1]) for i in range(len(boundingbox) // 2)) + src_slices = [slice(n_ch), *src_ranges] + if target_mask is None: - pass + # ! if no target mask is provided, paste to orignal pos. + tar_slices = src_slices else: - boundingbox = bbox_ND(softed_mask[0, ...]) - bbox_size = tuple(boundingbox[2 * i + 1] - boundingbox[2 * i] for i in range(len(boundingbox) // 2)) fg_indices_, bg_indices_ = map_binary_to_indices(target_mask, None, None) centers = generate_pos_neg_label_crop_centers( bbox_size, @@ -255,19 +260,15 @@ def paste( self.R, False, ) - - shifted_src_image = np.zeros_like(target_image) - shifted_src_mask = np.zeros_like(target_image) - n_ch = shifted_src_mask.shape[0] tar_ranges = tuple(slice(int(center - sz // 2), int(center - sz // 2 + sz)) for center, sz in zip(centers[0], bbox_size)) tar_slices = [slice(n_ch), *tar_ranges] - src_ranges = tuple(slice(boundingbox[2 * i], boundingbox[2 * i + 1]) for i in range(len(boundingbox) // 2)) - src_slices = [slice(n_ch), *src_ranges] - shifted_src_image[tar_slices] = softed_image[src_slices] - shifted_src_mask[tar_slices] = softed_mask[src_slices] - sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image - return sythetic_image + shifted_src_image = np.zeros_like(target_image) + shifted_src_mask = np.zeros_like(target_image) + shifted_src_image[tar_slices] = softed_image[src_slices] + shifted_src_mask[tar_slices] = softed_mask[src_slices] + sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image + return sythetic_image def __call__( self, From 7f15d03b3fdbb48cae70ace6761b1d8ad21600e2 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Thu, 11 Aug 2022 23:11:54 +0800 Subject: [PATCH 05/12] feat(transforms): Added RandSoftCopyPasteD (not tested). --- monai_ex/transforms/compose.py | 2 - monai_ex/transforms/utility/dictionary.py | 53 +++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/monai_ex/transforms/compose.py b/monai_ex/transforms/compose.py index 5c3d42e..c4ab231 100644 --- a/monai_ex/transforms/compose.py +++ b/monai_ex/transforms/compose.py @@ -53,8 +53,6 @@ def __call__(self, input_): return apply_transform(self.selected_trans, input_) -ReturnType = TypeVar("ReturnType") - def _apply_transform( transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False ) -> ReturnType: diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index 7665771..e034eff 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -3,11 +3,13 @@ import numpy as np import torch +from torch.utils.data import Dataset from monai.config import KeysCollection, NdarrayTensor, NdarrayOrTensor from monai.transforms.compose import MapTransform, Randomizable from monai.utils import ensure_tuple_rep from monai_ex.utils import ensure_list +from monai_ex.utils.exceptions import TransformException from monai.transforms import SplitChannel from monai_ex.transforms.utility.array import ( @@ -16,6 +18,7 @@ DataStatsEx, DataLabelling, RandLabelToMask, + RandSoftCopyPaste ) from monai_ex.transforms import ( @@ -445,6 +448,56 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[key] = d[key][index] return d + +class RandSoftCopyPasted(Randomizable, MapTransform): + """Dictionary-based wrapper of :py:class:`monai_ex.transforms.RandSoftCopyPaste`. + + Args: + keys (KeysCollection): keys of the corresponding items to be transformed + mask_key (Optional[str]): key of the mask data + source_dataset (Dataset): a dataset return source image and mask + k_erode (int): erosion times. + k_dilate (int): dilation times. + alpha (float, optional): alpha. Defaults to 0.8. + label_idx (Optional[int], optional): the label of souce mask to be proceed. Defaults to None. + """ + def __init__( + self, + keys: KeysCollection, + mask_key: Optional[str], + source_dataset: Dataset, + k_erode: int, + k_dilate: int, + alpha: float = 0.8, + label_idx: Optional[int] = None, + ) -> None: + super().__init__(keys) + self.mask_key = mask_key + self.source_dataset = source_dataset + self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, label_idx) + + def randomize(self) -> None: + return self.R.randint(len(self.source_dataset)) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + idx = self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + target_image = d[key] + target_mask = d[self.mask_key] if self.mask_key else None + + try: + source_image, source_mask = self.source_dataset[idx] + except ValueError as e: + raise TransformException("Source dataset should return two data: source_image, source_mask.\nErr msg: {e}") + else: + d[key] = self.generator(source_image, source_mask, target_image, target_mask) + + return d + + ToTensorExD = ToTensorExDict = ToTensorExd CastToTypeExD = CastToTypeExDict = CastToTypeExd DataStatsExD = DataStatsExDict = DataStatsExd From 111352b64c75c4aecb21b284ac90707bfc0497ac Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Fri, 12 Aug 2022 15:22:24 +0800 Subject: [PATCH 06/12] Add default value for `GenerateSyntheticData`. Rename `label_idx` to `source_label_value`. Add `mask_select_fn`. --- monai_ex/transforms/io/array.py | 4 ++-- monai_ex/transforms/utility/array.py | 14 ++++++++------ monai_ex/transforms/utility/dictionary.py | 14 ++++++++++---- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/monai_ex/transforms/io/array.py b/monai_ex/transforms/io/array.py index 1b61464..1661d66 100644 --- a/monai_ex/transforms/io/array.py +++ b/monai_ex/transforms/io/array.py @@ -31,7 +31,7 @@ def __init__( self.channel_dim = channel_dim self.random_state = random_state - def __call__(self, data: Any): + def __call__(self, data: Optional[Any] = None): if self.depth: img, seg = create_test_image_3d( self.height, @@ -81,7 +81,7 @@ def __init__( self.channel_dim = channel_dim self.random_state = random_state - def __call__(self, data: Any): + def __call__(self, data: Optional[Any] = None): if self.depth: image = np.random.rand(self.width, self.height, self.depth) else: diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 3ab79d3..a1b8e90 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -213,13 +213,15 @@ def __init__( k_erode: int, k_dilate: int, alpha: float = 0.8, - label_idx: Optional[int] = None, + source_label_value: Optional[int] = None, + log_name: Optional[str] = None, ) -> None: super().__init__() self.k_erode = k_erode self.k_dilate = k_dilate self.alpha = alpha - self.label_idx = label_idx + self.source_label_value = source_label_value + self.logger = logging.getLogger(log_name) def soften(self, src_mask): struct = ndi.generate_binary_structure(src_mask.ndim, 2) @@ -278,15 +280,15 @@ def __call__( target_mask: Optional[NdarrayTensor] = None, ) -> NdarrayTensor: if source_mask.shape[0] > 1: - if self.label_idx is None: + if self.source_label_value is None: raise ValueError("Multi-channel label data need to specify label_idx") else: - source_mask = source_mask[self.label_idx, ...] + source_mask = source_mask[self.source_label_value, ...] elif source_mask.shape[0] == 1: - if self.label_idx is None: + if self.source_label_value is None: source_mask = (source_mask > 0).squeeze(0) else: - source_mask = (source_mask == self.label_idx).squeeze(0) + source_mask = (source_mask == self.source_label_value).squeeze(0) softed_mask = self.soften(source_mask) softed_mask = softed_mask[np.newaxis, ...] diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index e034eff..9832b32 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -11,7 +11,8 @@ from monai_ex.utils import ensure_list from monai_ex.utils.exceptions import TransformException -from monai.transforms import SplitChannel +from monai.transforms.utility.array import SplitChannel +from monai.transforms.utils import is_positive from monai_ex.transforms.utility.array import ( CastToTypeEx, ToTensorEx, @@ -469,12 +470,16 @@ def __init__( k_erode: int, k_dilate: int, alpha: float = 0.8, - label_idx: Optional[int] = None, + mask_select_fn: Callable = is_positive, + source_label_value: Optional[int] = None, + log_name: Optional[str] = None, ) -> None: super().__init__(keys) self.mask_key = mask_key self.source_dataset = source_dataset - self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, label_idx) + self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, source_label_value, log_name) + self.mask_select_fn = mask_select_fn + self.logger = logging.getLogger(log_name) def randomize(self) -> None: return self.R.randint(len(self.source_dataset)) @@ -493,7 +498,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda except ValueError as e: raise TransformException("Source dataset should return two data: source_image, source_mask.\nErr msg: {e}") else: - d[key] = self.generator(source_image, source_mask, target_image, target_mask) + d[key] = self.generator(source_image, source_mask, target_image, self.mask_select_fn(target_mask)) return d @@ -507,3 +512,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandCrop2dByPosNegLabelD = RandCrop2dByPosNegLabelDict = RandCrop2dByPosNegLabeld RandLabelToMaskD = RandLabelToMaskDict = RandLabelToMaskd GetItemD = GetItemDict = GetItemd +RandSoftCopyPasteD = RandSoftCopyPasteDict = RandSoftCopyPasted From 755657bed94ad4d8154f0e86980127be970f6c1e Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Fri, 12 Aug 2022 15:22:32 +0800 Subject: [PATCH 07/12] Add tests. --- monai_ex/tests/test_RandSoftCopyPaste.py | 4 +- monai_ex/tests/test_RandSoftCopyPasteD.py | 54 +++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 monai_ex/tests/test_RandSoftCopyPasteD.py diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py index f68a69c..3fcc25f 100644 --- a/monai_ex/tests/test_RandSoftCopyPaste.py +++ b/monai_ex/tests/test_RandSoftCopyPaste.py @@ -27,11 +27,11 @@ def test_randsoftcopypaste(dim): print("dummy data, mask shape:", src_image.shape, src_image.shape) print("mask label: ", np.unique(src_mask)) - sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, tar_mask == 0) + sythetic_img = RandSoftCopyPaste(2, 4, source_label_value=1)(src_image, src_mask, tar_image, tar_mask == 0) assert sythetic_img.shape == (1, *spatial_size) # save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz" # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) - sythetic_img = RandSoftCopyPaste(2, 4, label_idx=1)(src_image, src_mask, tar_image, None) + sythetic_img = RandSoftCopyPaste(2, 4, source_label_value=1)(src_image, src_mask, tar_image, None) assert sythetic_img.shape == (1, *spatial_size) diff --git a/monai_ex/tests/test_RandSoftCopyPasteD.py b/monai_ex/tests/test_RandSoftCopyPasteD.py new file mode 100644 index 0000000..7368006 --- /dev/null +++ b/monai_ex/tests/test_RandSoftCopyPasteD.py @@ -0,0 +1,54 @@ +import pytest + +from pathlib import Path +import nibabel as nib + +import numpy as np +from monai_ex.transforms.utility.dictionary import RandSoftCopyPasteD +from monai.data.synthetic import create_test_image_3d +from monai.data import Dataset +from monai_ex.transforms import GenerateSyntheticData, Compose, adaptor + + +@pytest.mark.parametrize("dim", [2, 3]) +def test_randsoftcopypaste(dim): + data_num = 3 + spatial_size = (100,) * dim + generator = GenerateSyntheticData( + *spatial_size, + num_objs=1, + rad_max=5, + rad_min=4, + noise_max=0, + num_seg_classes=1, + channel_dim=0, + ) + + img, msk = generator() + volume_size = np.count_nonzero(msk) + + dummy_fpath = [{"image": "d.nii", "label": "l.nii"} for i in range(data_num)] + + source_dataset = Dataset(['dummy.nii' for i in range(data_num)], transform=generator) + + dataset = Dataset( + dummy_fpath, transform=Compose([ + adaptor(generator, ["image", "label"]), + RandSoftCopyPasteD( + keys="image", mask_key="label", + source_dataset=source_dataset, # will generate image & mask + k_erode=2, + k_dilate=5, + alpha=0.8, + source_label_value=1, + mask_select_fn=lambda x: x == 0, + ) + ]) + ) + + for i, item in enumerate(dataset): + image, label = item["image"], item["label"] + assert np.count_nonzero(image) == 2 * volume_size + + save_fpath = Path.home() / f"sythetic_img_{i}.nii.gz" + nib.save(nib.Nifti1Image(image.squeeze(), np.eye(4)), save_fpath) From 8bc977c18b09065b13fc27f7f0031d25f5f8eacf Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Fri, 12 Aug 2022 18:48:32 +0800 Subject: [PATCH 08/12] Return both sythetic_image and sythetic_mask. --- monai_ex/tests/test_RandSoftCopyPasteD.py | 13 +++++---- monai_ex/transforms/utility/array.py | 32 +++++++++++++++++------ monai_ex/transforms/utility/dictionary.py | 11 ++++---- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/monai_ex/tests/test_RandSoftCopyPasteD.py b/monai_ex/tests/test_RandSoftCopyPasteD.py index 7368006..f5c3872 100644 --- a/monai_ex/tests/test_RandSoftCopyPasteD.py +++ b/monai_ex/tests/test_RandSoftCopyPasteD.py @@ -12,14 +12,14 @@ @pytest.mark.parametrize("dim", [2, 3]) def test_randsoftcopypaste(dim): - data_num = 3 + data_num = 2 spatial_size = (100,) * dim generator = GenerateSyntheticData( *spatial_size, num_objs=1, rad_max=5, rad_min=4, - noise_max=0, + noise_max=0.5, num_seg_classes=1, channel_dim=0, ) @@ -48,7 +48,10 @@ def test_randsoftcopypaste(dim): for i, item in enumerate(dataset): image, label = item["image"], item["label"] - assert np.count_nonzero(image) == 2 * volume_size - save_fpath = Path.home() / f"sythetic_img_{i}.nii.gz" - nib.save(nib.Nifti1Image(image.squeeze(), np.eye(4)), save_fpath) + # save_fpath = Path.home() / f"sythetic_{dim}Dimg_{i}.nii.gz" + # nib.save(nib.Nifti1Image(image.squeeze(), np.eye(4)), save_fpath) + # save_fpath = Path.home() / f"sythetic_{dim}Dlabel_{i}.nii.gz" + # nib.save(nib.Nifti1Image(label.squeeze(), np.eye(4)), save_fpath) + + assert volume_size < np.count_nonzero(label) <= 2 * volume_size \ No newline at end of file diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index a1b8e90..34dec36 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -7,7 +7,7 @@ from monai.transforms.compose import Transform, Randomizable from monai.transforms import DataStats, SaveImage, CastToType -from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices +from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, is_positive from monai.config import NdarrayTensor, DtypeLike from monai_ex.utils import convert_data_type_ex, bbox_ND @@ -213,6 +213,7 @@ def __init__( k_erode: int, k_dilate: int, alpha: float = 0.8, + mask_select_fn: Callable = is_positive, source_label_value: Optional[int] = None, log_name: Optional[str] = None, ) -> None: @@ -220,6 +221,7 @@ def __init__( self.k_erode = k_erode self.k_dilate = k_dilate self.alpha = alpha + self.mask_select_fn = mask_select_fn self.source_label_value = source_label_value self.logger = logging.getLogger(log_name) @@ -237,6 +239,7 @@ def soften(self, src_mask): def paste( self, softed_image: NdarrayTensor, + origin_mask: NdarrayTensor, softed_mask: NdarrayTensor, target_image: NdarrayTensor, target_mask: NdarrayTensor, @@ -245,13 +248,13 @@ def paste( boundingbox = bbox_ND(softed_mask[0, ...]) bbox_size = tuple(boundingbox[2 * i + 1] - boundingbox[2 * i] for i in range(len(boundingbox) // 2)) src_ranges = tuple(slice(boundingbox[2 * i], boundingbox[2 * i + 1]) for i in range(len(boundingbox) // 2)) - src_slices = [slice(n_ch), *src_ranges] + src_slices = (slice(n_ch), *src_ranges) if target_mask is None: # ! if no target mask is provided, paste to orignal pos. tar_slices = src_slices else: - fg_indices_, bg_indices_ = map_binary_to_indices(target_mask, None, None) + fg_indices_, bg_indices_ = map_binary_to_indices(self.mask_select_fn(target_mask), None, None) centers = generate_pos_neg_label_crop_centers( bbox_size, 1, @@ -263,14 +266,15 @@ def paste( False, ) tar_ranges = tuple(slice(int(center - sz // 2), int(center - sz // 2 + sz)) for center, sz in zip(centers[0], bbox_size)) - tar_slices = [slice(n_ch), *tar_ranges] + tar_slices = (slice(n_ch), *tar_ranges) shifted_src_image = np.zeros_like(target_image) shifted_src_mask = np.zeros_like(target_image) shifted_src_image[tar_slices] = softed_image[src_slices] shifted_src_mask[tar_slices] = softed_mask[src_slices] sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image - return sythetic_image + shifted_src_mask[tar_slices] = origin_mask[src_slices] + return sythetic_image, shifted_src_mask def __call__( self, @@ -296,7 +300,19 @@ def __call__( softed_mask = np.repeat(softed_mask, repeats=source_image.shape[0], axis=0) softed_image = source_image * softed_mask - sythetic_image = self.paste( - softed_image=softed_image, softed_mask=softed_mask, target_image=target_image, target_mask=target_mask + sythetic_image, shifted_src_mask = self.paste( + softed_image=softed_image, + origin_mask=source_mask[None], + softed_mask=softed_mask, + target_image=target_image, + target_mask=target_mask ) - return sythetic_image + + if target_mask is None: + shifted_src_mask[shifted_src_mask > 0] = self.source_label_value + sythetic_mask = shifted_src_mask + else: + target_mask[shifted_src_mask > 0] = self.source_label_value + sythetic_mask = target_mask + + return sythetic_image, sythetic_mask diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index 9832b32..60b360a 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -477,8 +477,8 @@ def __init__( super().__init__(keys) self.mask_key = mask_key self.source_dataset = source_dataset - self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, source_label_value, log_name) - self.mask_select_fn = mask_select_fn + self.source_label_value = source_label_value + self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, mask_select_fn, source_label_value, log_name) self.logger = logging.getLogger(log_name) def randomize(self) -> None: @@ -496,10 +496,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda try: source_image, source_mask = self.source_dataset[idx] except ValueError as e: - raise TransformException("Source dataset should return two data: source_image, source_mask.\nErr msg: {e}") + raise TransformException("Source dataset should return a tuple: (source_image, source_mask).\nErr msg: {e}") else: - d[key] = self.generator(source_image, source_mask, target_image, self.mask_select_fn(target_mask)) - + sythetic_image, sythetic_mask = self.generator(source_image, source_mask, target_image, target_mask) + d[key] = sythetic_image + d[self.mask_key] = sythetic_mask return d From 5dba8e925c0f85791bab5dda31103e67a209856f Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Fri, 12 Aug 2022 20:45:00 +0800 Subject: [PATCH 09/12] Add `prob`. --- monai_ex/tests/test_RandSoftCopyPaste.py | 26 ++++++++++++++++++----- monai_ex/tests/test_RandSoftCopyPasteD.py | 1 + monai_ex/transforms/utility/array.py | 25 +++++++++++++--------- monai_ex/transforms/utility/dictionary.py | 17 +++++++++++---- 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py index 3fcc25f..6d029ff 100644 --- a/monai_ex/tests/test_RandSoftCopyPaste.py +++ b/monai_ex/tests/test_RandSoftCopyPaste.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize("dim", [2, 3]) -def test_randsoftcopypaste(dim): +@pytest.mark.parametrize("prob", [0, 1]) +def test_randsoftcopypaste(dim, prob): spatial_size = (100,) * dim generator = GenerateSyntheticData( *spatial_size, @@ -24,14 +25,29 @@ def test_randsoftcopypaste(dim): src_image, src_mask = generator(None) tar_image, tar_mask = generator(None) + volume_size = np.count_nonzero(src_mask) + np.count_nonzero(tar_mask) print("dummy data, mask shape:", src_image.shape, src_image.shape) print("mask label: ", np.unique(src_mask)) - sythetic_img = RandSoftCopyPaste(2, 4, source_label_value=1)(src_image, src_mask, tar_image, tar_mask == 0) - assert sythetic_img.shape == (1, *spatial_size) + sythetic_img, sythetic_msk = RandSoftCopyPaste( + 2, 4, prob=prob, mask_select_fn=lambda x: x==0, source_label_value=1 + )(tar_image, tar_mask, src_image, src_mask) + if prob == 0: + assert np.all(sythetic_img == tar_image) + assert np.all(sythetic_msk == tar_mask) + else: + assert sythetic_img.shape == (1, *spatial_size) + assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size # save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz" # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath) - sythetic_img = RandSoftCopyPaste(2, 4, source_label_value=1)(src_image, src_mask, tar_image, None) - assert sythetic_img.shape == (1, *spatial_size) + sythetic_img, sythetic_msk = RandSoftCopyPaste( + 2, 4, prob=prob, source_label_value=1 + )(tar_image, None, src_image, src_mask) + if prob == 0: + assert np.all(sythetic_img == tar_image) + assert sythetic_msk is None + else: + assert sythetic_img.shape == (1, *spatial_size) + assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size diff --git a/monai_ex/tests/test_RandSoftCopyPasteD.py b/monai_ex/tests/test_RandSoftCopyPasteD.py index f5c3872..78390c2 100644 --- a/monai_ex/tests/test_RandSoftCopyPasteD.py +++ b/monai_ex/tests/test_RandSoftCopyPasteD.py @@ -40,6 +40,7 @@ def test_randsoftcopypaste(dim): k_erode=2, k_dilate=5, alpha=0.8, + prob=1, source_label_value=1, mask_select_fn=lambda x: x == 0, ) diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 34dec36..0e01ecb 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -5,7 +5,7 @@ import torch from scipy import ndimage as ndi -from monai.transforms.compose import Transform, Randomizable +from monai.transforms.compose import Transform, Randomizable, RandomizableTransform from monai.transforms import DataStats, SaveImage, CastToType from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, is_positive from monai.config import NdarrayTensor, DtypeLike @@ -202,7 +202,7 @@ def __call__( return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data -class RandSoftCopyPaste(Randomizable, Transform): +class RandSoftCopyPaste(RandomizableTransform): """ Perform Soft Copy&Paste augmentation. Reference: `https://arxiv.org/ftp/arxiv/papers/2203/2203.10507.pdf` @@ -213,11 +213,12 @@ def __init__( k_erode: int, k_dilate: int, alpha: float = 0.8, + prob: float = 0.1, mask_select_fn: Callable = is_positive, source_label_value: Optional[int] = None, log_name: Optional[str] = None, ) -> None: - super().__init__() + RandomizableTransform.__init__(self, prob) self.k_erode = k_erode self.k_dilate = k_dilate self.alpha = alpha @@ -278,11 +279,15 @@ def paste( def __call__( self, + image: NdarrayTensor, + mask: Optional[NdarrayTensor], source_image: NdarrayTensor, source_mask: NdarrayTensor, - target_image: NdarrayTensor, - target_mask: Optional[NdarrayTensor] = None, ) -> NdarrayTensor: + self.randomize(None) + if not self._do_transform: + return image, mask + if source_mask.shape[0] > 1: if self.source_label_value is None: raise ValueError("Multi-channel label data need to specify label_idx") @@ -304,15 +309,15 @@ def __call__( softed_image=softed_image, origin_mask=source_mask[None], softed_mask=softed_mask, - target_image=target_image, - target_mask=target_mask + target_image=image, + target_mask=mask ) - if target_mask is None: + if mask is None: shifted_src_mask[shifted_src_mask > 0] = self.source_label_value sythetic_mask = shifted_src_mask else: - target_mask[shifted_src_mask > 0] = self.source_label_value - sythetic_mask = target_mask + mask[shifted_src_mask > 0] = self.source_label_value + sythetic_mask = mask return sythetic_image, sythetic_mask diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index 60b360a..9db4ec7 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -470,6 +470,7 @@ def __init__( k_erode: int, k_dilate: int, alpha: float = 0.8, + prob: float = 0.1, mask_select_fn: Callable = is_positive, source_label_value: Optional[int] = None, log_name: Optional[str] = None, @@ -478,7 +479,15 @@ def __init__( self.mask_key = mask_key self.source_dataset = source_dataset self.source_label_value = source_label_value - self.generator = RandSoftCopyPaste(k_erode, k_dilate, alpha, mask_select_fn, source_label_value, log_name) + self.generator = RandSoftCopyPaste( + k_erode=k_erode, + k_dilate=k_dilate, + alpha=alpha, + prob=prob, + mask_select_fn=mask_select_fn, + source_label_value=source_label_value, + log_name=log_name + ) self.logger = logging.getLogger(log_name) def randomize(self) -> None: @@ -490,15 +499,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) for key in self.key_iterator(d): - target_image = d[key] - target_mask = d[self.mask_key] if self.mask_key else None + image = d[key] + mask = d[self.mask_key] if self.mask_key else None try: source_image, source_mask = self.source_dataset[idx] except ValueError as e: raise TransformException("Source dataset should return a tuple: (source_image, source_mask).\nErr msg: {e}") else: - sythetic_image, sythetic_mask = self.generator(source_image, source_mask, target_image, target_mask) + sythetic_image, sythetic_mask = self.generator(image, mask, source_image, source_mask) d[key] = sythetic_image d[self.mask_key] = sythetic_mask return d From d085d172a7161f30f2111e8c583a1f458bd728ba Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Tue, 23 Aug 2022 22:49:51 +0800 Subject: [PATCH 10/12] Add `strict_paste` and `tolerance`. Fixed `RandSoftCopyPasteD` for multiple keys. --- monai_ex/tests/test_RandSoftCopyPasteD.py | 111 ++++++++++++-- monai_ex/transforms/utility/array.py | 167 +++++++++++++------- monai_ex/transforms/utility/dictionary.py | 176 ++++++++++++---------- 3 files changed, 313 insertions(+), 141 deletions(-) diff --git a/monai_ex/tests/test_RandSoftCopyPasteD.py b/monai_ex/tests/test_RandSoftCopyPasteD.py index 78390c2..65d8dab 100644 --- a/monai_ex/tests/test_RandSoftCopyPasteD.py +++ b/monai_ex/tests/test_RandSoftCopyPasteD.py @@ -7,14 +7,60 @@ from monai_ex.transforms.utility.dictionary import RandSoftCopyPasteD from monai.data.synthetic import create_test_image_3d from monai.data import Dataset -from monai_ex.transforms import GenerateSyntheticData, Compose, adaptor +from monai_ex.transforms import MapTransform, GenerateSyntheticData, Compose, adaptor + + + +class GenerateSyntheticDataD(MapTransform): + def __init__( + self, + keys, + label_key, + height: int, + width: int, + depth: int = None, + num_objs: int = 12, + rad_max: int = 30, + rad_min: int = 5, + noise_max: float = 0.0, + num_seg_classes: int = 5, + channel_dim: int = None, + random_state: np.random.RandomState = None, + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.label_key = label_key + self.loader = GenerateSyntheticData( + height, + width, + depth, + num_objs, + rad_max, + rad_min, + noise_max, + num_seg_classes, + channel_dim, + random_state, + ) + + def __call__(self, filename: dict): + test_data = self.loader(None) + + data = {} + for key in self.keys: + data[key] = test_data[0] + data[self.label_key] = test_data[1] + return data @pytest.mark.parametrize("dim", [2, 3]) def test_randsoftcopypaste(dim): data_num = 2 spatial_size = (100,) * dim - generator = GenerateSyntheticData( + generator = GenerateSyntheticDataD( + "image", + "label", *spatial_size, num_objs=1, rad_max=5, @@ -24,24 +70,28 @@ def test_randsoftcopypaste(dim): channel_dim=0, ) - img, msk = generator() - volume_size = np.count_nonzero(msk) - dummy_fpath = [{"image": "d.nii", "label": "l.nii"} for i in range(data_num)] - source_dataset = Dataset(['dummy.nii' for i in range(data_num)], transform=generator) + output = generator(dummy_fpath[0]) + volume_size = np.count_nonzero(output["label"]) + + source_dataset = Dataset( + [{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(data_num)], + transform=generator + ) dataset = Dataset( dummy_fpath, transform=Compose([ - adaptor(generator, ["image", "label"]), + generator, RandSoftCopyPasteD( keys="image", mask_key="label", source_dataset=source_dataset, # will generate image & mask + source_fg_key="label", + source_fg_value=1, k_erode=2, k_dilate=5, alpha=0.8, prob=1, - source_label_value=1, mask_select_fn=lambda x: x == 0, ) ]) @@ -55,4 +105,47 @@ def test_randsoftcopypaste(dim): # save_fpath = Path.home() / f"sythetic_{dim}Dlabel_{i}.nii.gz" # nib.save(nib.Nifti1Image(label.squeeze(), np.eye(4)), save_fpath) - assert volume_size < np.count_nonzero(label) <= 2 * volume_size \ No newline at end of file + assert volume_size < np.count_nonzero(label) <= 2 * volume_size + + + +@pytest.mark.parametrize("dim", [2, 3]) +def test_randsoftcopypaste_multiimage(dim): + data_num = 2 + spatial_size = (100,) * dim + generator = GenerateSyntheticDataD( + ["image1", "image2"], + "label", + *spatial_size, + num_objs=1, + rad_max=5, + rad_min=4, + noise_max=0, + num_seg_classes=1, + channel_dim=0, + ) + + dummy_fpath = [{"image1": "d.nii", "image2": "d.nii", "label": "l.nii"} for i in range(data_num)] + source_dataset = Dataset( + [{"image1": '1', "image2": "2", "label": '1'} for i in range(data_num)], + transform=generator + ) + + outputs = generator({"image1": '1', "image2": '2', "label": '1'}) + + generator = RandSoftCopyPasteD( + keys=["image1", "image2"], mask_key="label", + source_dataset=source_dataset, + source_fg_key="label", + source_fg_value=1, + k_erode=2, + k_dilate=5, + alpha=0.8, + prob=1, + mask_select_fn=lambda x: x == 0, + ) + + generated_item = generator(outputs) + assert generated_item["image1"].shape == (1, *spatial_size) + assert generated_item["image2"].shape == (1, *spatial_size) + assert np.all(generated_item["image1"] == generated_item["image2"]) diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 0e01ecb..440dd23 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -206,6 +206,17 @@ class RandSoftCopyPaste(RandomizableTransform): """ Perform Soft Copy&Paste augmentation. Reference: `https://arxiv.org/ftp/arxiv/papers/2203/2203.10507.pdf` + + Args: + k_erode (int): erosion iteration num. + k_dilate (int): dilation iteration num. + alpha (float, optional): transparence ratio. Defaults to 0.8. + prob (float, optional): Probability to perform this aug. Defaults to 0.1. + mask_select_fn (Callable, optional): function to select expected foreground, default is to select values > 0. + source_label_value (Optional[int], optional): source foregound value. Defaults to None. + strict_paste (bool, optional): whether to strictly paste source mask inside of target mask region. Defaults to False. + tolerance (int, optional): even in strict_paste mode, there is a tolerance to allow paste to the edge. Defaults to 10. + log_name (Optional[str], optional): logger name. Defaults to None. """ def __init__( @@ -216,6 +227,8 @@ def __init__( prob: float = 0.1, mask_select_fn: Callable = is_positive, source_label_value: Optional[int] = None, + strict_paste: bool = False, + tolerance: int = 100, log_name: Optional[str] = None, ) -> None: RandomizableTransform.__init__(self, prob) @@ -224,10 +237,23 @@ def __init__( self.alpha = alpha self.mask_select_fn = mask_select_fn self.source_label_value = source_label_value + self.strict_paste = strict_paste + self.tolerance = tolerance self.logger = logging.getLogger(log_name) def soften(self, src_mask): - struct = ndi.generate_binary_structure(src_mask.ndim, 2) + if src_mask.shape[0] > 1: + if self.source_label_value is None: + raise ValueError("Multi-channel label data need to specify label_idx") + else: + src_mask = src_mask[self.source_label_value, ...] + elif src_mask.shape[0] == 1: + if self.source_label_value is None: + src_mask = (src_mask > 0).squeeze(0) + else: + src_mask = (src_mask == self.source_label_value).squeeze(0) + + struct = ndi.generate_binary_structure(src_mask.ndim, src_mask.ndim - 1) mask = ndi.binary_erosion(src_mask, struct, iterations=self.k_erode).astype(src_mask.dtype) for j in range(self.k_dilate): mask_binary = np.where(mask > 0, 1, 0) @@ -237,87 +263,126 @@ def soften(self, src_mask): return mask + def ensure_strict_center(self, softed_mask, target_mask): + src_ranges = tuple( + slice(self.boundingbox[2 * i], self.boundingbox[2 * i + 1]) for i in range(len(self.boundingbox) // 2) + ) + src_slices = (slice(softed_mask.shape[0]), *src_ranges) + shifted_src_mask = np.zeros_like(target_mask) + for tar_slice in self.target_slices: + shifted_src_mask[tar_slice] = softed_mask[src_slices] + if np.count_nonzero(target_mask[shifted_src_mask] == 0) <= self.tolerance: # full contain! + return [tar_slice] # only support sample num = 1 + return [] + + def compute_target_position(self, src_mask, softed_mask, target_image, target_mask) -> None: + n_ch = target_image.shape[0] + self.boundingbox = bbox_ND(softed_mask[0, ...]) + bbox_size = tuple( + self.boundingbox[2 * i + 1] - self.boundingbox[2 * i] for i in range(len(self.boundingbox) // 2) + ) + selected_target_mask = self.mask_select_fn(target_mask) + fg_indices_, bg_indices_ = map_binary_to_indices(selected_target_mask, None, None) + centers = generate_pos_neg_label_crop_centers( + bbox_size, + 10, # pick 10 candidates + 1, + softed_mask.shape[1:], + fg_indices_, + bg_indices_, + self.R, + False, + ) + target_ranges = [] + for center in centers: + target_ranges.append( + tuple(slice(int(center - sz // 2), int(center - sz // 2 + sz)) for center, sz in zip(center, bbox_size)) + ) + self.target_slices = [(slice(n_ch), *ranges) for ranges in target_ranges] + if self.strict_paste: + self.logger.debug("Enter strict filtering mode.") + self.target_slices = self.ensure_strict_center( + src_mask, selected_target_mask + ) + def paste( self, softed_image: NdarrayTensor, origin_mask: NdarrayTensor, softed_mask: NdarrayTensor, target_image: NdarrayTensor, - target_mask: NdarrayTensor, + target_bg_mask: NdarrayTensor, + randomize: True, ): n_ch = target_image.shape[0] - boundingbox = bbox_ND(softed_mask[0, ...]) - bbox_size = tuple(boundingbox[2 * i + 1] - boundingbox[2 * i] for i in range(len(boundingbox) // 2)) - src_ranges = tuple(slice(boundingbox[2 * i], boundingbox[2 * i + 1]) for i in range(len(boundingbox) // 2)) + if randomize: + self.compute_target_position(origin_mask, softed_mask, target_image, target_bg_mask) + src_ranges = tuple( + slice(self.boundingbox[2 * i], self.boundingbox[2 * i + 1]) for i in range(len(self.boundingbox) // 2) + ) src_slices = (slice(n_ch), *src_ranges) - if target_mask is None: - # ! if no target mask is provided, paste to orignal pos. - tar_slices = src_slices - else: - fg_indices_, bg_indices_ = map_binary_to_indices(self.mask_select_fn(target_mask), None, None) - centers = generate_pos_neg_label_crop_centers( - bbox_size, - 1, - 1, - softed_mask.shape[1:], - fg_indices_, - bg_indices_, - self.R, - False, - ) - tar_ranges = tuple(slice(int(center - sz // 2), int(center - sz // 2 + sz)) for center, sz in zip(centers[0], bbox_size)) - tar_slices = (slice(n_ch), *tar_ranges) + if not self.target_slices: + self.logger.debug("No position is found to paste strictly! Skip copy&paste") + return None shifted_src_image = np.zeros_like(target_image) shifted_src_mask = np.zeros_like(target_image) - shifted_src_image[tar_slices] = softed_image[src_slices] - shifted_src_mask[tar_slices] = softed_mask[src_slices] + shifted_src_image[self.target_slices[0]] = softed_image[src_slices] + shifted_src_mask[self.target_slices[0]] = softed_mask[src_slices] sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image - shifted_src_mask[tar_slices] = origin_mask[src_slices] + shifted_src_mask[self.target_slices[0]] = origin_mask[src_slices] return sythetic_image, shifted_src_mask def __call__( self, image: NdarrayTensor, - mask: Optional[NdarrayTensor], + fg_mask: Optional[NdarrayTensor], + bg_mask: NdarrayTensor, source_image: NdarrayTensor, - source_mask: NdarrayTensor, + source_fg_mask: NdarrayTensor, + softed_fg_mask: Optional[NdarrayTensor] = None, + randomize: bool = True, ) -> NdarrayTensor: - self.randomize(None) + if randomize: + self.randomize(None) + if not self._do_transform: - return image, mask + return image, fg_mask - if source_mask.shape[0] > 1: - if self.source_label_value is None: - raise ValueError("Multi-channel label data need to specify label_idx") - else: - source_mask = source_mask[self.source_label_value, ...] - elif source_mask.shape[0] == 1: - if self.source_label_value is None: - source_mask = (source_mask > 0).squeeze(0) - else: - source_mask = (source_mask == self.source_label_value).squeeze(0) + if self.strict_paste and np.count_nonzero(self.mask_select_fn(bg_mask)) < np.count_nonzero(source_fg_mask): + self.logger.debug("Target mask area is smaller than source foreground area. Skip copy&paste") + return image, fg_mask + + if softed_fg_mask is None: + softed_fg_mask = self.soften(source_fg_mask) + if np.count_nonzero(softed_fg_mask) == 0: + self.logger.debug("Source foreground area is too small to be soften. Skip copy&paste") + return image, fg_mask - softed_mask = self.soften(source_mask) - softed_mask = softed_mask[np.newaxis, ...] - if source_image.shape[0] > 1: - softed_mask = np.repeat(softed_mask, repeats=source_image.shape[0], axis=0) + softed_fg_mask = softed_fg_mask[np.newaxis, ...] + if source_image.shape[0] > 1: + softed_fg_mask = np.repeat(softed_fg_mask, repeats=source_image.shape[0], axis=0) - softed_image = source_image * softed_mask - sythetic_image, shifted_src_mask = self.paste( + softed_image = source_image * softed_fg_mask + processed_data = self.paste( softed_image=softed_image, - origin_mask=source_mask[None], - softed_mask=softed_mask, + origin_mask=source_fg_mask, + softed_mask=softed_fg_mask, target_image=image, - target_mask=mask + target_bg_mask=bg_mask, + randomize=randomize, ) + if processed_data is None: + return image, fg_mask + + sythetic_image, shifted_src_mask = processed_data - if mask is None: + if fg_mask is None: shifted_src_mask[shifted_src_mask > 0] = self.source_label_value sythetic_mask = shifted_src_mask else: - mask[shifted_src_mask > 0] = self.source_label_value - sythetic_mask = mask + fg_mask[shifted_src_mask > 0] = self.source_label_value + sythetic_mask = fg_mask return sythetic_image, sythetic_mask diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index 9db4ec7..ede044d 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -19,7 +19,7 @@ DataStatsEx, DataLabelling, RandLabelToMask, - RandSoftCopyPaste + RandSoftCopyPaste, ) from monai_ex.transforms import ( @@ -37,9 +37,7 @@ class CastToTypeExd(MapTransform): def __init__( self, keys: KeysCollection, - dtype: Union[ - Sequence[Union[np.dtype, torch.dtype, str]], np.dtype, torch.dtype, str - ] = np.float32, + dtype: Union[Sequence[Union[np.dtype, torch.dtype, str]], np.dtype, torch.dtype, str] = np.float32, ) -> None: """ Args: @@ -78,9 +76,7 @@ def __init__(self, keys: KeysCollection) -> None: super().__init__(keys) self.converter = ToTensorEx() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.keys: d[key] = self.converter(d[key]) @@ -116,19 +112,9 @@ def __init__( self.logger_handler = logger_handler self.printer = DataStatsEx(logger_handler=logger_handler) - def __call__( - self, data: Mapping[Hashable, NdarrayTensor] - ) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for ( - key, - prefix, - data_type, - data_shape, - value_range, - data_value, - additional_info, - ) in self.key_iterator( + for (key, prefix, data_type, data_shape, value_range, data_value, additional_info,) in self.key_iterator( d, self.prefix, self.data_type, @@ -153,7 +139,7 @@ def __call__( class SplitChannelExd(MapTransform): """ Extension of `monai.transforms.SplitChanneld`. - Extended: `output_names`: the names to construct keys to store split data if + Extended: `output_names`: the names to construct keys to store split data if you don't want postfixes. `remove_origin`: delete original data of given keys @@ -169,7 +155,7 @@ def __init__( channel_dim: int = 0, remove_origin: bool = False, allow_missing_keys: bool = False, - meta_key_postfix='meta_dict', + meta_key_postfix="meta_dict", ) -> None: """ Args: @@ -209,6 +195,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d.pop(f"{key}_{self.meta_key_postfix}") return d + class DataLabellingd(MapTransform): def __init__( self, @@ -217,16 +204,13 @@ def __init__( super().__init__(keys) self.converter = DataLabelling() - def __call__( - self, img: Mapping[Hashable, torch.Tensor] - ) -> Dict[Hashable, torch.Tensor]: + def __call__(self, img: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(img) for idx, key in enumerate(self.keys): d[key] = self.converter(d[key]) return d - class ConcatModalityd(MapTransform): """Concat multi-modality data by given keys.""" @@ -272,9 +256,7 @@ def __init__( self.n_layer = n_layer if pos < 0 or neg < 0: - raise ValueError( - f"pos and neg must be nonnegative, got pos={pos} neg={neg}." - ) + raise ValueError(f"pos and neg must be nonnegative, got pos={pos} neg={neg}.") if pos + neg == 0: raise ValueError("Incompatible values: pos=0 and neg=0.") self.pos_ratio = pos / (pos + neg) @@ -304,9 +286,7 @@ def randomize( image: Optional[np.ndarray] = None, ) -> None: if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices( - label, image, self.image_threshold - ) + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) else: fg_indices_ = fg_indices bg_indices_ = bg_indices @@ -321,38 +301,24 @@ def randomize( self.R, ) - def __call__( - self, data: Mapping[Hashable, np.ndarray] - ) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None - fg_indices = ( - d.get(self.fg_indices_key, None) - if self.fg_indices_key is not None - else None - ) - bg_indices = ( - d.get(self.bg_indices_key, None) - if self.bg_indices_key is not None - else None - ) + fg_indices = d.get(self.fg_indices_key, None) if self.fg_indices_key is not None else None + bg_indices = d.get(self.bg_indices_key, None) if self.bg_indices_key is not None else None self.randomize(label, fg_indices, bg_indices, image) assert isinstance(self.spatial_size, tuple) assert self.centers is not None - results: List[Dict[Hashable, np.ndarray]] = [ - dict() for _ in range(self.num_samples) - ] + results: List[Dict[Hashable, np.ndarray]] = [dict() for _ in range(self.num_samples)] for key in data.keys(): if key in self.keys: img = d[key] for i, center in enumerate(self.centers): if self.crop_mode in ["single", "parallel"]: size_ = self.get_new_spatial_size() - slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)( - img - ) + slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img) seg_sum = slice_.squeeze().sum() results[i][key] = np.moveaxis(slice_.squeeze(0), self.z_axis, 0) @@ -360,9 +326,7 @@ def __call__( cross_slices = np.zeros(shape=(3,) + self.spatial_size) for k in range(3): size_ = np.insert(self.spatial_size, k, 1) - slice_ = SpatialCrop( - roi_center=tuple(center), roi_size=size_ - )(img) + slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img) cross_slices[k] = slice_.squeeze() results[i][key] = cross_slices else: @@ -393,7 +357,7 @@ def __init__( # pytype: disable=annotation-type-mismatch select_labels: Union[Sequence[int], int], merge_channels: bool = False, cls_label_key: Optional[KeysCollection] = None, - select_msk_label: Optional[int] = None, #! for tmp debug + select_msk_label: Optional[int] = None, #! for tmp debug ) -> None: super().__init__(keys) self.select_labels = select_labels @@ -411,14 +375,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if self.cls_label_key is not None: label = d[self.cls_label_key] - assert len(label) == len(self.select_labels), 'length of cls_label_key must equal to length of mask select_labels' + assert len(label) == len( + self.select_labels + ), "length of cls_label_key must equal to length of mask select_labels" if isinstance(label, (list, tuple)): - label = { i:L for i, L in enumerate(label, 1)} + label = {i: L for i, L in enumerate(label, 1)} elif isinstance(label, (int, float)): - label = {1:label} - assert isinstance(label, dict), 'Only support dict type label' - + label = {1: label} + assert isinstance(label, dict), "Only support dict type label" + d[self.cls_label_key] = label[self.select_label] for key in self.keys: @@ -435,6 +401,7 @@ class GetItemd(MapTransform): keys (KeysCollection): keys of the corresponding items to be transformed. index (Union[Sequence[int], int]): i-th item you want to select. """ + def __init__( # pytype: disable=annotation-type-mismatch self, keys: KeysCollection, @@ -454,62 +421,109 @@ class RandSoftCopyPasted(Randomizable, MapTransform): """Dictionary-based wrapper of :py:class:`monai_ex.transforms.RandSoftCopyPaste`. Args: - keys (KeysCollection): keys of the corresponding items to be transformed - mask_key (Optional[str]): key of the mask data - source_dataset (Dataset): a dataset return source image and mask - k_erode (int): erosion times. - k_dilate (int): dilation times. - alpha (float, optional): alpha. Defaults to 0.8. - label_idx (Optional[int], optional): the label of souce mask to be proceed. Defaults to None. + keys (KeysCollection): keys of the corresponding items to be transformed. + mask_key (Optional[str]): key of the mask data. + source_dataset (Dataset): a dataset for process source data, return dict data. + source_fg_key (str): key of source foreground data. + source_fg_value (Optional[int]): source foregound value. + k_erode (int): erosion iteration num. + k_dilate (int): dilation iteration num. + alpha (float, optional): transparence ratio. Defaults to 0.8. + prob (float, optional): Probability to perform this aug. Defaults to 0.1. + mask_select_fn (Callable, optional): function to select expected foreground, default is to select values > 0. + strict_paste (bool, optional): whether to strictly paste source mask inside of target mask region. Defaults to False. + tolerance (int, optional): even in strict_paste mode, there is a tolerance to allow paste to the edge. Defaults to 10. + log_name (Optional[str], optional): logger name. Defaults to None. """ + def __init__( self, keys: KeysCollection, mask_key: Optional[str], source_dataset: Dataset, + source_fg_key: str, + source_fg_value: Optional[int], k_erode: int, k_dilate: int, alpha: float = 0.8, prob: float = 0.1, mask_select_fn: Callable = is_positive, - source_label_value: Optional[int] = None, + strict_paste: bool = False, + tolerance: int = 100, log_name: Optional[str] = None, ) -> None: super().__init__(keys) self.mask_key = mask_key self.source_dataset = source_dataset - self.source_label_value = source_label_value + self.source_fg_value = source_fg_value + self.source_fg_key = source_fg_key self.generator = RandSoftCopyPaste( k_erode=k_erode, k_dilate=k_dilate, alpha=alpha, prob=prob, mask_select_fn=mask_select_fn, - source_label_value=source_label_value, - log_name=log_name + source_label_value=source_fg_value, + strict_paste=strict_paste, + tolerance=tolerance, + log_name=log_name, ) self.logger = logging.getLogger(log_name) def randomize(self) -> None: return self.R.randint(len(self.source_dataset)) + def compute_target_position(self, src_mask, softed_mask, target_image, target_mask): + self.generator.compute_target_position(src_mask, softed_mask, target_image, target_mask) + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) idx = self.randomize() - d = dict(data) + try: + source = self.source_dataset[idx] + except Exception as e: + raise TransformException("Source dataset crashed.\nErr msg: {e}") + + if self.source_fg_key not in source: + raise TransformException(f"Source dataset did not contain foregound mask key: {self.source_fg_key}") + + if self.generator.strict_paste and ( + np.count_nonzero(self.generator.mask_select_fn(d[self.mask_key])) + < np.count_nonzero(source[self.source_fg_key]) + ): + self.logger.debug("Target mask area is smaller than source foreground area. Skip copy&paste") + return d + + softed_mask = self.generator.soften(source[self.source_fg_key]) + if np.count_nonzero(softed_mask) == 0: + self.logger.debug("Source foreground area is too small to be soften. Skip copy&paste") + return d + + softed_mask = softed_mask[np.newaxis, ...] + + first_img_key = self.first_key(d) + if source[first_img_key].shape[0] > 1: + softed_mask = np.repeat(softed_mask, repeats=source_image.shape[0], axis=0) + + self.compute_target_position(source[self.source_fg_key], softed_mask, d[first_img_key], d[self.mask_key]) + for key in self.key_iterator(d): image = d[key] - mask = d[self.mask_key] if self.mask_key else None - - try: - source_image, source_mask = self.source_dataset[idx] - except ValueError as e: - raise TransformException("Source dataset should return a tuple: (source_image, source_mask).\nErr msg: {e}") - else: - sythetic_image, sythetic_mask = self.generator(image, mask, source_image, source_mask) - d[key] = sythetic_image - d[self.mask_key] = sythetic_mask + bg_mask = d[self.mask_key] if self.mask_key else None + source_image, source_fg = source[key], source[self.source_fg_key] + + sythetic_image, sythetic_mask = self.generator( + image, + fg_mask=d.get(self.source_fg_key), + bg_mask=bg_mask, + source_image=source_image, + source_fg_mask=source_fg, + softed_fg_mask=softed_mask, + randomize=False, + ) + d[key] = sythetic_image + d[self.source_fg_key] = sythetic_mask return d From 1e38d236332c2b607cc25d3475a3c6be3f02147b Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Tue, 23 Aug 2022 22:51:12 +0800 Subject: [PATCH 11/12] Add `SelectSlicesByMask` and `SelectSlicesByMaskD`. Add tests. --- monai_ex/tests/test_CenterMask2DSliceCropD.py | 42 +++++++++++ monai_ex/tests/test_SelectSlicesByMask.py | 52 ++++++++++++++ monai_ex/transforms/croppad/array.py | 70 +++++++++++++++++-- monai_ex/transforms/croppad/dictionary.py | 39 +++++++++-- 4 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 monai_ex/tests/test_CenterMask2DSliceCropD.py create mode 100644 monai_ex/tests/test_SelectSlicesByMask.py diff --git a/monai_ex/tests/test_CenterMask2DSliceCropD.py b/monai_ex/tests/test_CenterMask2DSliceCropD.py new file mode 100644 index 0000000..b52e21e --- /dev/null +++ b/monai_ex/tests/test_CenterMask2DSliceCropD.py @@ -0,0 +1,42 @@ +import pytest + +from monai_ex.transforms.croppad.dictionary import CenterMask2DSliceCropD +from monai.data import Dataset +from monai_ex.transforms import GenerateSyntheticDataD, Compose + + +@pytest.mark.parametrize("crop_size,crop_mode,expected", [((50,50), "single", (1,50,50)), ((50,50), "parallel", (3,50,50))]) +def test_fullimiage2dslicecropd(crop_size, crop_mode, expected): + dim = 3 + spatial_size = (100,) * dim + + generator = GenerateSyntheticDataD( + ["image", "label"], + *spatial_size, + num_objs=1, + rad_max=5, + rad_min=4, + noise_max=0.5, + num_seg_classes=1, + channel_dim=0, + ) + + source_dataset = Dataset( + [{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(2)], + transform=Compose([ + generator, + CenterMask2DSliceCropD( + keys="image", + mask_key="label", + roi_size=crop_size, + crop_mode=crop_mode, + center_mode="center", + z_axis=2, + n_slices=3 + ) + ]) + ) + + output_item = source_dataset[0] + + assert output_item['image'].shape == expected diff --git a/monai_ex/tests/test_SelectSlicesByMask.py b/monai_ex/tests/test_SelectSlicesByMask.py new file mode 100644 index 0000000..7bc7bc9 --- /dev/null +++ b/monai_ex/tests/test_SelectSlicesByMask.py @@ -0,0 +1,52 @@ +import pytest + +import numpy as np +from monai_ex.transforms.croppad.array import SelectSlicesByMask +from monai_ex.transforms.croppad.dictionary import SelectSlicesByMaskD +from monai_ex.transforms import GenerateSyntheticData, GenerateSyntheticDataD + +def test_selectslicesbymask(): + dim = 3 + spatial_size = (100,) * dim + + generator = GenerateSyntheticData( + *spatial_size, + num_objs=1, + rad_max=5, + rad_min=4, + noise_max=0, + num_seg_classes=1, + channel_dim=0, + ) + + image, label = generator(None) + cropper = SelectSlicesByMask(z_axis=2, center_mode='center', mask_data=label) + img_slice = cropper(image) + + assert img_slice.shape == (1, 100, 100) + assert np.count_nonzero(img_slice) > 0 + + +def test_selectslicesbymaskdict(): + dim = 3 + spatial_size = (100,) * dim + + generator = GenerateSyntheticDataD( + ["image", "label"], + *spatial_size, + num_objs=1, + rad_max=5, + rad_min=4, + noise_max=0, + num_seg_classes=1, + channel_dim=0, + ) + + outputs = generator({"image": "1", "label": "1"}) + cropper = SelectSlicesByMaskD(keys=["image", "label"], mask_key="label", z_axis=2, center_mode='center') + img_slice = cropper(outputs) + + assert img_slice["image"].shape == (1, 100, 100) + assert img_slice["label"].shape == (1, 100, 100) + assert np.count_nonzero(img_slice["image"]) > 0 + assert np.count_nonzero(img_slice["label"]) > 0 diff --git a/monai_ex/transforms/croppad/array.py b/monai_ex/transforms/croppad/array.py index 7f5b7d4..f8204a8 100644 --- a/monai_ex/transforms/croppad/array.py +++ b/monai_ex/transforms/croppad/array.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Sequence, Union, Any +from typing import List, Optional, Sequence, Union, Any, Callable import numpy as np import torch @@ -9,6 +9,7 @@ fall_back_tuple, ) from monai.transforms.utils import ( + is_positive, map_binary_to_indices, generate_spatial_bounding_box, generate_pos_neg_label_crop_centers, @@ -25,15 +26,21 @@ class CenterMask2DSliceCrop(Transform): - """ - Extract 2D slices from the image at the - center of mask with specified ROI size. + """Extract 2D slices from the image at the + center of mask with specified ROI size. Args: - roi_size: the spatial size of the crop region e.g. [224,224,128] - If its components have non-positive values, the corresponding size of input image will be used. - """ + roi_size (Union[Sequence[int], int]): the 2D spatial size of the crop region e.g. [224,224] + crop_mode (str): 2D slice crop mode: "single", "cross", "parallel" + z_axis (int): the index of z axis (channel dim not counted) + center_mode (Optional[str], optional): center point calculation mode: "center", "maximum". Defaults to "center". + mask_data (Optional[np.ndarray], optional): mask data. Defaults to None. + n_slices (int, optional): the slice# will be croped, if crop_mode is "cross" or "parallel". Defaults to 3. + Raises: + ValueError: _description_ + ValueError: _description_ + """ def __init__( self, roi_size: Union[Sequence[int], int], @@ -471,6 +478,55 @@ def __call__( return results +class SelectSlicesByMask(CenterMask2DSliceCrop): + backend = SpatialCrop.backend + + def __init__( + self, + z_axis: int, + center_mode: Optional[str] = "center", + mask_data: Optional[np.ndarray] = None, + mask_select_fn: Callable = is_positive, + ) -> None: + """Select one slice based on mask data + + Args: + roi_size (Union[Sequence[int], int]): the 2D spatial size of the crop region e.g. [224,224] + z_axis (int): the index of z axis (channel dim not counted) + center_mode (Optional[str], optional): center point calculation mode: "center", "maximum". Defaults to "center". + mask_data (Optional[np.ndarray], optional): mask data. Defaults to None. + """ + self.mask_select_fn = mask_select_fn + super().__init__(None, "single", z_axis, center_mode, mask_data, 1) + + def __call__( + self, + img: np.ndarray, + msk: Optional[np.ndarray] = None, + ) -> Any: + if self.mask_data is None and msk is None: + raise ValueError("Unknown mask_data.") + + mask_data_ = msk if msk is not None else self.mask_data + mask_data_ = self.mask_select_fn(mask_data_) + + if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: + raise ValueError( + "When mask_data is not single channel, mask_data channels must match img, " + f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." + ) + + center = self.get_center_pos(mask_data_, self.z_axis) + slice_idx = center[self.z_axis] + + if isinstance(img, np.ndarray): + return np.take(img, slice_idx, axis=self.z_axis + 1) + elif isinstance(img, torch.Tensor): + return torch.index_select(img, dim=self.z_axis + 1, index=torch.tensor(slice_idx)).squeeze(self.z_axis + 1) + else: + raise NotImplementedError(f"Only support np.array and torch.Tensor, but got {type(img)}") + + class RandSelectSlicesFromImage(Randomizable): backend = SpatialCrop.backend diff --git a/monai_ex/transforms/croppad/dictionary.py b/monai_ex/transforms/croppad/dictionary.py index a9963f1..e671c4a 100644 --- a/monai_ex/transforms/croppad/dictionary.py +++ b/monai_ex/transforms/croppad/dictionary.py @@ -1,4 +1,4 @@ -from typing import Dict, Hashable, Mapping, Optional, Sequence, Union, List +from typing import Dict, Hashable, Mapping, Optional, Sequence, Union, List, Callable import torch import numpy as np @@ -10,6 +10,7 @@ generate_pos_neg_label_crop_centers, ) from monai.transforms import ( + is_positive, RandCropByPosNegLabeld, SpatialCrop, ResizeWithPadOrCrop, @@ -27,6 +28,7 @@ FullMask2DSliceCrop, GetMaxSlices3direcCrop, RandSelectSlicesFromImage, + SelectSlicesByMask, ) @@ -406,15 +408,39 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class RandSelectSlicesFromImaged(Randomizable, MapTransform): - backend = RandSelectSlicesFromImage.backend +class SelectSlicesByMaskd(MapTransform): + backend = SelectSlicesByMask.backend def __init__( self, keys: KeysCollection, - dim: int = 0, - num_samples: int = 1, - allow_missing_keys: bool = False + mask_key: str, + z_axis: int, + center_mode: Optional[str] = "center", + mask_select_fn: Callable = is_positive, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mask_key = mask_key + self.cropper = SelectSlicesByMask( + z_axis=z_axis, center_mode=center_mode, mask_data=None, mask_select_fn=mask_select_fn + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + data = d[key] + mask = d[self.mask_key] + d[key] = self.cropper(data, mask) + + return d + + +class RandSelectSlicesFromImaged(Randomizable, MapTransform): + backend = RandSelectSlicesFromImage.backend + + def __init__( + self, keys: KeysCollection, dim: int = 0, num_samples: int = 1, allow_missing_keys: bool = False ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.dim = dim @@ -451,3 +477,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RandCropByPosNegLabelExD = RandCropByPosNegLabelExDict = RandCropByPosNegLabelExd RandCrop2dByPosNegLabelD = RandCrop2dByPosNegLabelDict = RandCrop2dByPosNegLabeld RandSelectSlicesFromImageD = RandSelectSlicesFromImageDict = RandSelectSlicesFromImaged +SelectSlicesByMaskD = SelectSlicesByMaskDict = SelectSlicesByMaskd From 4e90bec8490e468536d3f5fef2152e8ab858d110 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Mon, 29 Aug 2022 18:59:58 +0800 Subject: [PATCH 12/12] Minor fixed. --- monai_ex/transforms/utility/array.py | 43 +++++++++++++++++------ monai_ex/transforms/utility/dictionary.py | 4 ++- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py index 440dd23..7f08a5b 100644 --- a/monai_ex/transforms/utility/array.py +++ b/monai_ex/transforms/utility/array.py @@ -9,7 +9,7 @@ from monai.transforms import DataStats, SaveImage, CastToType from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, is_positive from monai.config import NdarrayTensor, DtypeLike -from monai_ex.utils import convert_data_type_ex, bbox_ND +from monai_ex.utils import convert_data_type_ex, bbox_ND, ensure_list class CastToTypeEx(CastToType): @@ -169,7 +169,7 @@ def __init__( # pytype: disable=annotation-type-mismatch select_labels: Union[Sequence[int], int], merge_channels: bool = False, ) -> None: # pytype: disable=annotation-type-mismatch - self.select_labels = ensure_tuple(select_labels) + self.select_labels = ensure_list(select_labels) self.merge_channels = merge_channels def randomize(self): @@ -208,7 +208,8 @@ class RandSoftCopyPaste(RandomizableTransform): Reference: `https://arxiv.org/ftp/arxiv/papers/2203/2203.10507.pdf` Args: - k_erode (int): erosion iteration num. + k_erode (int | float): erosion iteration num. + Float value denotes the percentage from edge to center. k_dilate (int): dilation iteration num. alpha (float, optional): transparence ratio. Defaults to 0.8. prob (float, optional): Probability to perform this aug. Defaults to 0.1. @@ -221,7 +222,7 @@ class RandSoftCopyPaste(RandomizableTransform): def __init__( self, - k_erode: int, + k_erode: Union[int, float], k_dilate: int, alpha: float = 0.8, prob: float = 0.1, @@ -229,6 +230,7 @@ def __init__( source_label_value: Optional[int] = None, strict_paste: bool = False, tolerance: int = 100, + shift_source_intensity: bool = False, log_name: Optional[str] = None, ) -> None: RandomizableTransform.__init__(self, prob) @@ -239,6 +241,7 @@ def __init__( self.source_label_value = source_label_value self.strict_paste = strict_paste self.tolerance = tolerance + self.shift_intensity = shift_source_intensity self.logger = logging.getLogger(log_name) def soften(self, src_mask): @@ -253,8 +256,15 @@ def soften(self, src_mask): else: src_mask = (src_mask == self.source_label_value).squeeze(0) + def _minmax_norm(x): + minValue, maxValue = np.min(x), np.max(x) + return (x - minValue) / (maxValue - minValue) + struct = ndi.generate_binary_structure(src_mask.ndim, src_mask.ndim - 1) - mask = ndi.binary_erosion(src_mask, struct, iterations=self.k_erode).astype(src_mask.dtype) + if 0 < self.k_erode < 1: + mask = _minmax_norm(ndi.distance_transform_edt(src_mask)) > self.k_erode + else: + mask = ndi.binary_erosion(src_mask, struct, iterations=self.k_erode).astype(src_mask.dtype) for j in range(self.k_dilate): mask_binary = np.where(mask > 0, 1, 0) mask_dilate = ndi.binary_dilation(mask_binary, struct, iterations=1).astype(mask.dtype) @@ -307,7 +317,7 @@ def compute_target_position(self, src_mask, softed_mask, target_image, target_ma def paste( self, - softed_image: NdarrayTensor, + source_image: NdarrayTensor, origin_mask: NdarrayTensor, softed_mask: NdarrayTensor, target_image: NdarrayTensor, @@ -328,8 +338,20 @@ def paste( shifted_src_image = np.zeros_like(target_image) shifted_src_mask = np.zeros_like(target_image) - shifted_src_image[self.target_slices[0]] = softed_image[src_slices] + + offset = 0 + if self.shift_intensity: + src_mean_intensity = np.mean(source_image[src_slices]) + tar_mean_intensity = np.mean(target_image[self.target_slices[0]]) + offset = tar_mean_intensity - src_mean_intensity + + # softed_image = source_image * softed_mask + + src_region = source_image[src_slices] + offset + shifted_src_image[self.target_slices[0]] = src_region shifted_src_mask[self.target_slices[0]] = softed_mask[src_slices] + shifted_src_image *= shifted_src_mask + # shifted_src_image = np.clip(shifted_src_image, np.min(target_image), np.max(target_image)) sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image shifted_src_mask[self.target_slices[0]] = origin_mask[src_slices] return sythetic_image, shifted_src_mask @@ -364,9 +386,8 @@ def __call__( if source_image.shape[0] > 1: softed_fg_mask = np.repeat(softed_fg_mask, repeats=source_image.shape[0], axis=0) - softed_image = source_image * softed_fg_mask processed_data = self.paste( - softed_image=softed_image, + source_image=source_image, origin_mask=source_fg_mask, softed_mask=softed_fg_mask, target_image=image, @@ -382,7 +403,7 @@ def __call__( shifted_src_mask[shifted_src_mask > 0] = self.source_label_value sythetic_mask = shifted_src_mask else: - fg_mask[shifted_src_mask > 0] = self.source_label_value - sythetic_mask = fg_mask + sythetic_mask = fg_mask.copy() + sythetic_mask[shifted_src_mask > 0] = self.source_label_value return sythetic_image, sythetic_mask diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py index ede044d..dda089b 100644 --- a/monai_ex/transforms/utility/dictionary.py +++ b/monai_ex/transforms/utility/dictionary.py @@ -1,6 +1,6 @@ import logging from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union, List - +import copy import numpy as np import torch from torch.utils.data import Dataset @@ -450,6 +450,7 @@ def __init__( mask_select_fn: Callable = is_positive, strict_paste: bool = False, tolerance: int = 100, + shift_source_intensity: bool = False, log_name: Optional[str] = None, ) -> None: super().__init__(keys) @@ -466,6 +467,7 @@ def __init__( source_label_value=source_fg_value, strict_paste=strict_paste, tolerance=tolerance, + shift_source_intensity=shift_source_intensity, log_name=log_name, ) self.logger = logging.getLogger(log_name)