diff --git a/monai_ex/tests/test_CenterMask2DSliceCropD.py b/monai_ex/tests/test_CenterMask2DSliceCropD.py
new file mode 100644
index 0000000..b52e21e
--- /dev/null
+++ b/monai_ex/tests/test_CenterMask2DSliceCropD.py
@@ -0,0 +1,42 @@
+import pytest
+
+from monai_ex.transforms.croppad.dictionary import CenterMask2DSliceCropD
+from monai.data import Dataset
+from monai_ex.transforms import GenerateSyntheticDataD, Compose
+
+
+@pytest.mark.parametrize("crop_size,crop_mode,expected", [((50,50), "single", (1,50,50)), ((50,50), "parallel", (3,50,50))])
+def test_centermask2dslicecropd(crop_size, crop_mode, expected):
+    dim = 3
+    spatial_size = (100,) * dim
+
+    generator = GenerateSyntheticDataD(
+        ["image", "label"],
+        *spatial_size,
+        num_objs=1,
+        rad_max=5,
+        rad_min=4,
+        noise_max=0.5,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    source_dataset = Dataset(
+        [{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(2)],
+        transform=Compose([
+            generator,
+            CenterMask2DSliceCropD(
+                keys="image",
+                mask_key="label",
+                roi_size=crop_size,
+                crop_mode=crop_mode,
+                center_mode="center",
+                z_axis=2,
+                n_slices=3
+            )
+        ])
+    )
+
+    output_item = source_dataset[0]
+
+    assert output_item['image'].shape == expected
diff --git a/monai_ex/tests/test_RandSoftCopyPaste.py b/monai_ex/tests/test_RandSoftCopyPaste.py
new file mode 100644
index 0000000..6d029ff
--- /dev/null
+++ b/monai_ex/tests/test_RandSoftCopyPaste.py
@@ -0,0 +1,53 @@
+import pytest
+
+from pathlib import Path
+import nibabel as nib
+
+import numpy as np
+from monai_ex.transforms.utility.array import RandSoftCopyPaste
+from monai.data.synthetic import create_test_image_3d
+from monai_ex.transforms.io.array import GenerateSyntheticData
+
+
+@pytest.mark.parametrize("dim", [2, 3])
+@pytest.mark.parametrize("prob", [0, 1])
+def test_randsoftcopypaste(dim, prob):
+    spatial_size = (100,) * dim
+    generator = GenerateSyntheticData(
+        *spatial_size,
+        num_objs=1,
+        rad_max=10,
+        rad_min=9,
+        noise_max=0.2,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    src_image, src_mask = generator(None)
+    tar_image, tar_mask = generator(None)
+    volume_size = np.count_nonzero(src_mask) + np.count_nonzero(tar_mask)
+
+    print("dummy data, mask shape:", src_image.shape, src_mask.shape)
+    print("mask label: ", np.unique(src_mask))
+    sythetic_img, sythetic_msk = RandSoftCopyPaste(
+        2, 4, prob=prob, mask_select_fn=lambda x: x==0, source_label_value=1
+    )(tar_image, tar_mask, src_image, src_mask)
+    if prob == 0:
+        assert np.all(sythetic_img == tar_image)
+        assert np.all(sythetic_msk == tar_mask)
+    else:
+        assert sythetic_img.shape == (1, *spatial_size)
+        assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size
+
+    # save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz"
+    # nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath)
+
+    sythetic_img, sythetic_msk = RandSoftCopyPaste(
+        2, 4, prob=prob, source_label_value=1
+    )(tar_image, None, src_image, src_mask)
+    if prob == 0:
+        assert np.all(sythetic_img == tar_image)
+        assert sythetic_msk is None
+    else:
+        assert sythetic_img.shape == (1, *spatial_size)
+        assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size
diff --git a/monai_ex/tests/test_RandSoftCopyPasteD.py b/monai_ex/tests/test_RandSoftCopyPasteD.py
new file mode 100644
index 0000000..65d8dab
--- /dev/null
+++ b/monai_ex/tests/test_RandSoftCopyPasteD.py
@@ -0,0 +1,151 @@
+import pytest
+
+from pathlib import Path
+import nibabel as nib
+
+import numpy as np
+from monai_ex.transforms.utility.dictionary import RandSoftCopyPasteD
+from monai.data.synthetic import create_test_image_3d
+from monai.data import Dataset
+from monai_ex.transforms import MapTransform, GenerateSyntheticData, Compose, adaptor
+
+
+class GenerateSyntheticDataD(MapTransform):
+    def __init__(
+        self,
+        keys,
+        label_key,
+        height: int,
+        width: int,
+        depth: int = None,
+        num_objs: int = 12,
+        rad_max: int = 30,
+        rad_min: int = 5,
+        noise_max: float = 0.0,
+        num_seg_classes: int = 5,
+        channel_dim: int = None,
+        random_state: np.random.RandomState = None,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
+
+        self.label_key = label_key
+        self.loader = GenerateSyntheticData(
+            height,
+            width,
+            depth,
+            num_objs,
+            rad_max,
+            rad_min,
+            noise_max,
+            num_seg_classes,
+            channel_dim,
+            random_state,
+        )
+
+    def __call__(self, filename: dict):
+        test_data = self.loader(None)
+
+        data = {}
+        for key in self.keys:
+            data[key] = test_data[0]
+        data[self.label_key] = test_data[1]
+        return data
+
+
+@pytest.mark.parametrize("dim", [2, 3])
+def test_randsoftcopypaste(dim):
+    data_num = 2
+    spatial_size = (100,) * dim
+    generator = GenerateSyntheticDataD(
+        "image",
+        "label",
+        *spatial_size,
+        num_objs=1,
+        rad_max=5,
+        rad_min=4,
+        noise_max=0.5,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    dummy_fpath = [{"image": "d.nii", "label": "l.nii"} for i in range(data_num)]
+
+    output = generator(dummy_fpath[0])
+    volume_size = np.count_nonzero(output["label"])
+
+    source_dataset = Dataset(
+        [{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(data_num)],
+        transform=generator
+    )
+
+    dataset = Dataset(
+        dummy_fpath, transform=Compose([
+            generator,
+            RandSoftCopyPasteD(
+                keys="image", mask_key="label",
+                source_dataset=source_dataset,  # will generate image & mask
+                source_fg_key="label",
+                source_fg_value=1,
+                k_erode=2,
+                k_dilate=5,
+                alpha=0.8,
+                prob=1,
+                mask_select_fn=lambda x: x == 0,
+            )
+        ])
+    )
+
+    for i, item in enumerate(dataset):
+        image, label = item["image"], item["label"]
+
+        # save_fpath = Path.home() / f"sythetic_{dim}Dimg_{i}.nii.gz"
+        # nib.save(nib.Nifti1Image(image.squeeze(), np.eye(4)), save_fpath)
+        # save_fpath = Path.home() / f"sythetic_{dim}Dlabel_{i}.nii.gz"
+        # nib.save(nib.Nifti1Image(label.squeeze(), np.eye(4)), save_fpath)
+
+        assert volume_size < np.count_nonzero(label) <= 2 * volume_size
+
+
+@pytest.mark.parametrize("dim", [2, 3])
+def test_randsoftcopypaste_multiimage(dim):
+    data_num = 2
+    spatial_size = (100,) * dim
+    generator = GenerateSyntheticDataD(
+        ["image1", "image2"],
+        "label",
+        *spatial_size,
+        num_objs=1,
+        rad_max=5,
+        rad_min=4,
+        noise_max=0,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    dummy_fpath = [{"image1": "d.nii", "image2": "d.nii", "label": "l.nii"} for i in range(data_num)]
+    source_dataset = Dataset(
+        [{"image1": '1', "image2": "2", "label": '1'} for i in range(data_num)],
+        transform=generator
+    )
+
+    outputs = generator({"image1": '1', "image2": '2', "label": '1'})
+
+    generator = RandSoftCopyPasteD(
+        keys=["image1", "image2"], mask_key="label",
+        source_dataset=source_dataset,
+        source_fg_key="label",
+        source_fg_value=1,
+        k_erode=2,
+        k_dilate=5,
+        alpha=0.8,
+        prob=1,
+        mask_select_fn=lambda x: x == 0,
+    )
+
+    generated_item = generator(outputs)
+    assert generated_item["image1"].shape == (1, *spatial_size)
+    assert generated_item["image2"].shape == (1, *spatial_size)
+    assert np.all(generated_item["image1"] == generated_item["image2"])
diff --git a/monai_ex/tests/test_SelectSlicesByMask.py b/monai_ex/tests/test_SelectSlicesByMask.py
new file mode 100644
index 0000000..7bc7bc9
--- /dev/null
+++ b/monai_ex/tests/test_SelectSlicesByMask.py
@@ -0,0 +1,52 @@
+import pytest
+
+import numpy as np
+from monai_ex.transforms.croppad.array import SelectSlicesByMask
+from monai_ex.transforms.croppad.dictionary import SelectSlicesByMaskD
+from monai_ex.transforms import GenerateSyntheticData, GenerateSyntheticDataD
+
+
+def test_selectslicesbymask():
+    dim = 3
+    spatial_size = (100,) * dim
+
+    generator = GenerateSyntheticData(
+        *spatial_size,
+        num_objs=1,
+        rad_max=5,
+        rad_min=4,
+        noise_max=0,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    image, label = generator(None)
+    cropper = SelectSlicesByMask(z_axis=2, center_mode='center', mask_data=label)
+    img_slice = cropper(image)
+
+    assert img_slice.shape == (1, 100, 100)
+    assert np.count_nonzero(img_slice) > 0
+
+
+def test_selectslicesbymaskdict():
+    dim = 3
+    spatial_size = (100,) * dim
+
+    generator = GenerateSyntheticDataD(
+        ["image", "label"],
+        *spatial_size,
+        num_objs=1,
+        rad_max=5,
+        rad_min=4,
+        noise_max=0,
+        num_seg_classes=1,
+        channel_dim=0,
+    )
+
+    outputs = generator({"image": "1", "label": "1"})
+    cropper = SelectSlicesByMaskD(keys=["image", "label"], mask_key="label", z_axis=2, center_mode='center')
+    img_slice = cropper(outputs)
+
+    assert img_slice["image"].shape == (1, 100, 100)
+    assert img_slice["label"].shape == (1, 100, 100)
+    assert np.count_nonzero(img_slice["image"]) > 0
+    assert np.count_nonzero(img_slice["label"]) > 0
diff --git a/monai_ex/tests/test_bbox_nd.py b/monai_ex/tests/test_bbox_nd.py
new file mode 100644
index 0000000..0920e83
--- /dev/null
+++ b/monai_ex/tests/test_bbox_nd.py
@@ -0,0 +1,23 @@
+import pytest
+from monai_ex.utils.misc import bbox_ND
+import numpy as np
+
+dummy_data_3d = np.zeros([10, 10, 10])
+dummy_data_3d[4:7, 4:8, 4:9] = 1
+
+dummy_data_2d = np.zeros([10, 10])
+dummy_data_2d[4:7, 4:8] = 1
+
+@pytest.mark.parametrize('data', [dummy_data_2d, dummy_data_3d])
+def test_bbox_nd(data):
+    bounding = bbox_ND(data, False)
+    if data.ndim == 3:
+        assert bounding == (4, 6, 4, 7, 4, 8)
+    elif data.ndim == 2:
+        assert bounding == (4, 6, 4, 7)
+
+    bbox_range = bbox_ND(data, True)
+    if data.ndim == 3:
+        assert bbox_range == (2, 3, 4)
+    elif data.ndim == 2:
+        assert bbox_range == (2, 3)
diff --git a/monai_ex/transforms/compose.py b/monai_ex/transforms/compose.py
index 5c3d42e..c4ab231 100644
--- a/monai_ex/transforms/compose.py
+++ b/monai_ex/transforms/compose.py
@@ -53,8 +53,6 @@ def __call__(self, input_):
         return apply_transform(self.selected_trans, input_)
 
 
-ReturnType = TypeVar("ReturnType")
-
 def _apply_transform(
     transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False
 ) -> ReturnType:
diff --git a/monai_ex/transforms/croppad/array.py b/monai_ex/transforms/croppad/array.py
index 7f5b7d4..f8204a8 100644
--- a/monai_ex/transforms/croppad/array.py
+++ b/monai_ex/transforms/croppad/array.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Sequence, Union, Any
+from typing import List, Optional, Sequence, Union, Any, Callable
 
 import numpy as np
 import torch
@@ -9,6 +9,7 @@
     fall_back_tuple,
 )
 from monai.transforms.utils import (
+    is_positive,
     map_binary_to_indices,
     generate_spatial_bounding_box,
     generate_pos_neg_label_crop_centers,
@@ -25,15 +26,21 @@ class CenterMask2DSliceCrop(Transform):
-    """
-    Extract 2D slices from the image at the
-    center of mask with specified ROI size.
+    """Extract 2D slices from the image at the
+    center of mask with specified ROI size.
 
     Args:
-        roi_size: the spatial size of the crop region e.g. [224,224,128]
-        If its components have non-positive values, the corresponding size of input image will be used.
-    """
+        roi_size (Union[Sequence[int], int]): the 2D spatial size of the crop region e.g. [224,224]
+        crop_mode (str): 2D slice crop mode: "single", "cross", "parallel"
+        z_axis (int): the index of the z axis (channel dim not counted)
+        center_mode (Optional[str], optional): center point calculation mode: "center", "maximum". Defaults to "center".
+        mask_data (Optional[np.ndarray], optional): mask data. Defaults to None.
+        n_slices (int, optional): the number of slices to crop when crop_mode is "cross" or "parallel". Defaults to 3.
+
+    Raises:
+        ValueError: when an unsupported ``crop_mode`` is given.
+        ValueError: when an unsupported ``center_mode`` is given.
+    """
     def __init__(
         self,
         roi_size: Union[Sequence[int], int],
         crop_mode: str,
@@ -471,6 +478,55 @@ def __call__(
         return results
 
 
+class SelectSlicesByMask(CenterMask2DSliceCrop):
+    backend = SpatialCrop.backend
+
+    def __init__(
+        self,
+        z_axis: int,
+        center_mode: Optional[str] = "center",
+        mask_data: Optional[np.ndarray] = None,
+        mask_select_fn: Callable = is_positive,
+    ) -> None:
+        """Select one slice based on mask data
+
+        Args:
+            z_axis (int): the index of the z axis (channel dim not counted)
+            center_mode (Optional[str], optional): center point calculation mode: "center", "maximum". Defaults to "center".
+            mask_data (Optional[np.ndarray], optional): mask data. Defaults to None.
+            mask_select_fn (Callable, optional): function to select the mask foreground. Defaults to ``is_positive``.
+        """
+        self.mask_select_fn = mask_select_fn
+        super().__init__(None, "single", z_axis, center_mode, mask_data, 1)
+
+    def __call__(
+        self,
+        img: np.ndarray,
+        msk: Optional[np.ndarray] = None,
+    ) -> Any:
+        if self.mask_data is None and msk is None:
+            raise ValueError("mask_data is not set and no mask was passed to __call__.")
+
+        mask_data_ = msk if msk is not None else self.mask_data
+        mask_data_ = self.mask_select_fn(mask_data_)
+
+        if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
+            raise ValueError(
+                "When mask_data is not single channel, mask_data channels must match img, "
+                f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}."
+            )
+
+        center = self.get_center_pos(mask_data_, self.z_axis)
+        slice_idx = center[self.z_axis]
+
+        if isinstance(img, np.ndarray):
+            return np.take(img, slice_idx, axis=self.z_axis + 1)
+        elif isinstance(img, torch.Tensor):
+            return torch.index_select(img, dim=self.z_axis + 1, index=torch.tensor([slice_idx])).squeeze(self.z_axis + 1)
+        else:
+            raise NotImplementedError(f"Only np.ndarray and torch.Tensor are supported, but got {type(img)}")
+
+
 class RandSelectSlicesFromImage(Randomizable):
     backend = SpatialCrop.backend
diff --git a/monai_ex/transforms/croppad/dictionary.py b/monai_ex/transforms/croppad/dictionary.py
index a9963f1..e671c4a 100644
--- a/monai_ex/transforms/croppad/dictionary.py
+++ b/monai_ex/transforms/croppad/dictionary.py
@@ -1,4 +1,4 @@
-from typing import Dict, Hashable, Mapping, Optional, Sequence, Union, List
+from typing import Dict, Hashable, Mapping, Optional, Sequence, Union, List, Callable
 
 import torch
 import numpy as np
@@ -10,6 +10,7 @@
     generate_pos_neg_label_crop_centers,
 )
 from monai.transforms import (
+    is_positive,
     RandCropByPosNegLabeld,
     SpatialCrop,
     ResizeWithPadOrCrop,
@@ -27,6 +28,7 @@
     FullMask2DSliceCrop,
     GetMaxSlices3direcCrop,
     RandSelectSlicesFromImage,
+    SelectSlicesByMask,
 )
 
 
@@ -406,15 +408,39 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
         return results
 
 
-class RandSelectSlicesFromImaged(Randomizable, MapTransform):
-    backend = RandSelectSlicesFromImage.backend
+class SelectSlicesByMaskd(MapTransform):
+    backend = SelectSlicesByMask.backend
 
     def __init__(
         self,
         keys: KeysCollection,
-        dim: int = 0,
-        num_samples: int = 1,
-        allow_missing_keys: bool = False
+        mask_key: str,
+        z_axis: int,
+        center_mode: Optional[str] = "center",
+        mask_select_fn: Callable = is_positive,
+        allow_missing_keys: bool = False,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.mask_key = mask_key
+        self.cropper = SelectSlicesByMask(
+            z_axis=z_axis, center_mode=center_mode, mask_data=None, mask_select_fn=mask_select_fn
+        )
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            data = d[key]
+            mask = d[self.mask_key]
+            d[key] = self.cropper(data, mask)
+
+        return d
+
+
+class RandSelectSlicesFromImaged(Randomizable, MapTransform):
+    backend = RandSelectSlicesFromImage.backend
+
+    def __init__(
+        self, keys: KeysCollection, dim: int = 0, num_samples: int = 1, allow_missing_keys: bool = False
     ) -> None:
         MapTransform.__init__(self, keys, allow_missing_keys)
         self.dim = dim
@@ -451,3 +477,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
 RandCropByPosNegLabelExD = RandCropByPosNegLabelExDict = RandCropByPosNegLabelExd
 RandCrop2dByPosNegLabelD = RandCrop2dByPosNegLabelDict = RandCrop2dByPosNegLabeld
 RandSelectSlicesFromImageD = RandSelectSlicesFromImageDict = RandSelectSlicesFromImaged
+SelectSlicesByMaskD = SelectSlicesByMaskDict = SelectSlicesByMaskd
diff --git a/monai_ex/transforms/io/array.py b/monai_ex/transforms/io/array.py
index 65039a9..1661d66 100644
--- a/monai_ex/transforms/io/array.py
+++ b/monai_ex/transforms/io/array.py
@@ -31,7 +31,7 @@
         self.channel_dim = channel_dim
         self.random_state = random_state
 
-    def __call__(self, data: Any):
+    def __call__(self, data: Optional[Any] = None):
         if self.depth:
             img, seg = create_test_image_3d(
                 self.height,
@@ -57,7 +57,7 @@
             self.channel_dim,
             self.random_state,
         )
-        
+
         if self.num_seg_classes < 1:
             return img, img
@@ -81,16 +81,14 @@ def __init__(
         self.channel_dim = channel_dim
         self.random_state = random_state
 
-    def __call__(self, data: Any):
+    def __call__(self, data: Optional[Any] = None):
         if self.depth:
             image = np.random.rand(self.width, self.height, self.depth)
         else:
             image = np.random.rand(self.width, self.height)
 
         if self.channel_dim is not None:
-            if not (
-                isinstance(self.channel_dim, int) and self.channel_dim in (-1, 0, 3)
-            ):
+            if not (isinstance(self.channel_dim, int) and self.channel_dim in (-1, 0, 3)):
                 raise AssertionError("invalid channel dim.")
             image = image[None] if self.channel_dim == 0 else image[..., None]
diff --git a/monai_ex/transforms/utility/array.py b/monai_ex/transforms/utility/array.py
index b2c1df5..7f08a5b 100644
--- a/monai_ex/transforms/utility/array.py
+++ b/monai_ex/transforms/utility/array.py
@@ -5,10 +5,11 @@
 import torch
 from scipy import ndimage as ndi
 
-from monai.transforms.compose import Transform, Randomizable
+from monai.transforms.compose import Transform, Randomizable, RandomizableTransform
 from monai.transforms import DataStats, SaveImage, CastToType
+from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, is_positive
 from monai.config import NdarrayTensor, DtypeLike
-from monai_ex.utils import convert_data_type_ex
+from monai_ex.utils import convert_data_type_ex, bbox_ND, ensure_list
 
 
 class CastToTypeEx(CastToType):
@@ -24,9 +25,7 @@ def __init__(self, dtype=np.float32) -> None:
         """
         self.dtype = dtype
 
-    def __call__(
-        self, img: Any, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
-    ) -> Any:
+    def __call__(self, img: Any, dtype: Optional[Union[DtypeLike, torch.dtype]] = None) -> Any:
         """
         Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.
 
@@ -39,12 +38,8 @@ def __call__(
         """
         type_list = (torch.Tensor, np.ndarray, int, bool, float, list, tuple)
         if not isinstance(img, type_list):
-            raise TypeError(
-                f"img must be one of ({type_list}) but is {type(img).__name__}."
-            )
-        img_out, *_ = convert_data_type_ex(
-            img, output_type=type(img), dtype=dtype or self.dtype
-        )
+            raise TypeError(f"img must be one of ({type_list}) but is {type(img).__name__}.")
+        img_out, *_ = convert_data_type_ex(img, output_type=type(img), dtype=dtype or self.dtype)
         return img_out
 
 
@@ -53,7 +48,7 @@ class ToTensorEx(Transform):
     Converts the input image to a tensor without applying any other transformations.
     """
 
-    def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+    def __call__(self, img: NdarrayTensor) -> torch.Tensor:
         """
         Apply the transform to `img` and make it contiguous.
""" @@ -105,9 +100,7 @@ def __call__( data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, ) -> NdarrayTensor: - img = super().__init__( - img, prefix, data_type, data_shape, value_range, data_value, additional_info - ) + img = super().__init__(img, prefix, data_type, data_shape, value_range, data_value, additional_info) if self.save_data: saver = SaveImage( @@ -176,7 +169,7 @@ def __init__( # pytype: disable=annotation-type-mismatch select_labels: Union[Sequence[int], int], merge_channels: bool = False, ) -> None: # pytype: disable=annotation-type-mismatch - self.select_labels = ensure_tuple(select_labels) + self.select_labels = ensure_list(select_labels) self.merge_channels = merge_channels def randomize(self): @@ -204,12 +197,213 @@ def __call__( if img.shape[0] > 1: data = img[[self.select_label]] else: - data = np.where(np.in1d(img, self.select_label), True, False).reshape( - img.shape + data = np.where(np.in1d(img, self.select_label), True, False).reshape(img.shape) + + return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + + +class RandSoftCopyPaste(RandomizableTransform): + """ + Perform Soft Copy&Paste augmentation. + Reference: `https://arxiv.org/ftp/arxiv/papers/2203/2203.10507.pdf` + + Args: + k_erode (int | float): erosion iteration num. + Float value denotes the percentage from edge to center. + k_dilate (int): dilation iteration num. + alpha (float, optional): transparence ratio. Defaults to 0.8. + prob (float, optional): Probability to perform this aug. Defaults to 0.1. + mask_select_fn (Callable, optional): function to select expected foreground, default is to select values > 0. + source_label_value (Optional[int], optional): source foregound value. Defaults to None. + strict_paste (bool, optional): whether to strictly paste source mask inside of target mask region. Defaults to False. + tolerance (int, optional): even in strict_paste mode, there is a tolerance to allow paste to the edge. Defaults to 10. + log_name (Optional[str], optional): logger name. Defaults to None. + """ + + def __init__( + self, + k_erode: Union[int, float], + k_dilate: int, + alpha: float = 0.8, + prob: float = 0.1, + mask_select_fn: Callable = is_positive, + source_label_value: Optional[int] = None, + strict_paste: bool = False, + tolerance: int = 100, + shift_source_intensity: bool = False, + log_name: Optional[str] = None, + ) -> None: + RandomizableTransform.__init__(self, prob) + self.k_erode = k_erode + self.k_dilate = k_dilate + self.alpha = alpha + self.mask_select_fn = mask_select_fn + self.source_label_value = source_label_value + self.strict_paste = strict_paste + self.tolerance = tolerance + self.shift_intensity = shift_source_intensity + self.logger = logging.getLogger(log_name) + + def soften(self, src_mask): + if src_mask.shape[0] > 1: + if self.source_label_value is None: + raise ValueError("Multi-channel label data need to specify label_idx") + else: + src_mask = src_mask[self.source_label_value, ...] 
+        elif src_mask.shape[0] == 1:
+            if self.source_label_value is None:
+                src_mask = (src_mask > 0).squeeze(0)
+            else:
+                src_mask = (src_mask == self.source_label_value).squeeze(0)
+
+        def _minmax_norm(x):
+            minValue, maxValue = np.min(x), np.max(x)
+            return (x - minValue) / (maxValue - minValue)
+
+        struct = ndi.generate_binary_structure(src_mask.ndim, src_mask.ndim - 1)
+        if 0 < self.k_erode < 1:
+            mask = _minmax_norm(ndi.distance_transform_edt(src_mask)) > self.k_erode
+        else:
+            mask = ndi.binary_erosion(src_mask, struct, iterations=self.k_erode).astype(src_mask.dtype)
+        for j in range(self.k_dilate):
+            mask_binary = np.where(mask > 0, 1, 0)
+            mask_dilate = ndi.binary_dilation(mask_binary, struct, iterations=1).astype(mask.dtype)
+            mask_alpha = mask_dilate * (self.alpha ** (j + 1))
+            mask = (1 - mask) * mask_alpha + mask
+
+        return mask
+
+    def ensure_strict_center(self, softed_mask, target_mask):
+        src_ranges = tuple(
+            slice(self.boundingbox[2 * i], self.boundingbox[2 * i + 1]) for i in range(len(self.boundingbox) // 2)
+        )
+        src_slices = (slice(softed_mask.shape[0]), *src_ranges)
+        shifted_src_mask = np.zeros_like(target_mask)
+        for tar_slice in self.target_slices:
+            shifted_src_mask[tar_slice] = softed_mask[src_slices]
+            if np.count_nonzero(target_mask[shifted_src_mask] == 0) <= self.tolerance:  # fully contained!
+                return [tar_slice]  # only support sample num = 1
+        return []
+
+    def compute_target_position(self, src_mask, softed_mask, target_image, target_mask) -> None:
+        n_ch = target_image.shape[0]
+        self.boundingbox = bbox_ND(softed_mask[0, ...])
+        bbox_size = tuple(
+            self.boundingbox[2 * i + 1] - self.boundingbox[2 * i] for i in range(len(self.boundingbox) // 2)
+        )
+        selected_target_mask = self.mask_select_fn(target_mask)
+        fg_indices_, bg_indices_ = map_binary_to_indices(selected_target_mask, None, None)
+        centers = generate_pos_neg_label_crop_centers(
+            bbox_size,
+            10,  # pick 10 candidate centers
+            1,
+            softed_mask.shape[1:],
+            fg_indices_,
+            bg_indices_,
+            self.R,
+            False,
+        )
+        target_ranges = []
+        for center in centers:
+            target_ranges.append(
+                tuple(slice(int(c - sz // 2), int(c - sz // 2 + sz)) for c, sz in zip(center, bbox_size))
+            )
+        self.target_slices = [(slice(n_ch), *ranges) for ranges in target_ranges]
+        if self.strict_paste:
+            self.logger.debug("Enter strict filtering mode.")
+            self.target_slices = self.ensure_strict_center(
+                src_mask, selected_target_mask
+            )
-        return (
-            np.any(data, axis=0, keepdims=True)
-            if (merge_channels or self.merge_channels)
-            else data
-        )
+
+    def paste(
+        self,
+        source_image: NdarrayTensor,
+        origin_mask: NdarrayTensor,
+        softed_mask: NdarrayTensor,
+        target_image: NdarrayTensor,
+        target_bg_mask: NdarrayTensor,
+        randomize: bool = True,
+    ):
+        n_ch = target_image.shape[0]
+        if randomize:
+            self.compute_target_position(origin_mask, softed_mask, target_image, target_bg_mask)
+        src_ranges = tuple(
+            slice(self.boundingbox[2 * i], self.boundingbox[2 * i + 1]) for i in range(len(self.boundingbox) // 2)
+        )
+        src_slices = (slice(n_ch), *src_ranges)
+
+        if not self.target_slices:
+            self.logger.debug("No position is found to paste strictly! Skip copy&paste")
+            return None
+
+        shifted_src_image = np.zeros_like(target_image)
+        shifted_src_mask = np.zeros_like(target_image)
+
+        offset = 0
+        if self.shift_intensity:
+            src_mean_intensity = np.mean(source_image[src_slices])
+            tar_mean_intensity = np.mean(target_image[self.target_slices[0]])
+            offset = tar_mean_intensity - src_mean_intensity
+
+        # softed_image = source_image * softed_mask
+
+        src_region = source_image[src_slices] + offset
+        shifted_src_image[self.target_slices[0]] = src_region
+        shifted_src_mask[self.target_slices[0]] = softed_mask[src_slices]
+        shifted_src_image *= shifted_src_mask
+        # shifted_src_image = np.clip(shifted_src_image, np.min(target_image), np.max(target_image))
+        sythetic_image = shifted_src_image + (1 - shifted_src_mask) * target_image
+        shifted_src_mask[self.target_slices[0]] = origin_mask[src_slices]
+        return sythetic_image, shifted_src_mask
+
+    def __call__(
+        self,
+        image: NdarrayTensor,
+        fg_mask: Optional[NdarrayTensor],
+        bg_mask: NdarrayTensor,
+        source_image: NdarrayTensor,
+        source_fg_mask: NdarrayTensor,
+        softed_fg_mask: Optional[NdarrayTensor] = None,
+        randomize: bool = True,
+    ) -> NdarrayTensor:
+        if randomize:
+            self.randomize(None)
+
+        if not self._do_transform:
+            return image, fg_mask
+
+        if self.strict_paste and np.count_nonzero(self.mask_select_fn(bg_mask)) < np.count_nonzero(source_fg_mask):
+            self.logger.debug("Target mask area is smaller than source foreground area. Skip copy&paste")
+            return image, fg_mask
+
+        if softed_fg_mask is None:
+            softed_fg_mask = self.soften(source_fg_mask)
+            if np.count_nonzero(softed_fg_mask) == 0:
+                self.logger.debug("Source foreground area is too small to be softened. Skip copy&paste")
+                return image, fg_mask
+
+        softed_fg_mask = softed_fg_mask[np.newaxis, ...]
+        if source_image.shape[0] > 1:
+            softed_fg_mask = np.repeat(softed_fg_mask, repeats=source_image.shape[0], axis=0)
+
+        processed_data = self.paste(
+            source_image=source_image,
+            origin_mask=source_fg_mask,
+            softed_mask=softed_fg_mask,
+            target_image=image,
+            target_bg_mask=bg_mask,
+            randomize=randomize,
+        )
+        if processed_data is None:
+            return image, fg_mask
+
+        sythetic_image, shifted_src_mask = processed_data
+
+        if fg_mask is None:
+            shifted_src_mask[shifted_src_mask > 0] = self.source_label_value
+            sythetic_mask = shifted_src_mask
+        else:
+            sythetic_mask = fg_mask.copy()
+            sythetic_mask[shifted_src_mask > 0] = self.source_label_value
+
+        return sythetic_image, sythetic_mask
diff --git a/monai_ex/transforms/utility/dictionary.py b/monai_ex/transforms/utility/dictionary.py
index 7665771..dda089b 100644
--- a/monai_ex/transforms/utility/dictionary.py
+++ b/monai_ex/transforms/utility/dictionary.py
@@ -1,21 +1,25 @@
 import logging
 from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union, List
-
+import copy
 import numpy as np
 import torch
+from torch.utils.data import Dataset
 
 from monai.config import KeysCollection, NdarrayTensor, NdarrayOrTensor
 from monai.transforms.compose import MapTransform, Randomizable
 from monai.utils import ensure_tuple_rep
 from monai_ex.utils import ensure_list
+from monai_ex.utils.exceptions import TransformException
 
-from monai.transforms import SplitChannel
+from monai.transforms.utility.array import SplitChannel
+from monai.transforms.utils import is_positive
 from monai_ex.transforms.utility.array import (
     CastToTypeEx,
     ToTensorEx,
     DataStatsEx,
     DataLabelling,
     RandLabelToMask,
+    RandSoftCopyPaste,
 )
 
 from monai_ex.transforms import (
@@ -33,9 +37,7 @@ class CastToTypeExd(MapTransform):
     def __init__(
         self,
         keys: KeysCollection,
-        dtype: Union[
-            Sequence[Union[np.dtype, torch.dtype, str]], np.dtype, torch.dtype, str
-        ] = np.float32,
+        dtype: Union[Sequence[Union[np.dtype, torch.dtype, str]], np.dtype, torch.dtype, str] = np.float32,
     ) -> None:
         """
         Args:
@@ -74,9 +76,7 @@ def __init__(self, keys: KeysCollection) -> None:
         super().__init__(keys)
         self.converter = ToTensorEx()
 
-    def __call__(
-        self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
-    ) -> Dict[Hashable, torch.Tensor]:
+    def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
         for key in self.keys:
             d[key] = self.converter(d[key])
@@ -112,19 +112,9 @@ def __init__(
         self.logger_handler = logger_handler
         self.printer = DataStatsEx(logger_handler=logger_handler)
 
-    def __call__(
-        self, data: Mapping[Hashable, NdarrayTensor]
-    ) -> Dict[Hashable, NdarrayTensor]:
+    def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
         d = dict(data)
-        for (
-            key,
-            prefix,
-            data_type,
-            data_shape,
-            value_range,
-            data_value,
-            additional_info,
-        ) in self.key_iterator(
+        for (key, prefix, data_type, data_shape, value_range, data_value, additional_info,) in self.key_iterator(
             d,
             self.prefix,
             self.data_type,
@@ -149,7 +139,7 @@ def __call__(
 class SplitChannelExd(MapTransform):
     """
     Extension of `monai.transforms.SplitChanneld`.
-    Extended: `output_names`: the names to construct keys to store split data if 
+    Extended: `output_names`: the names to construct keys to store split data if
         you don't want postfixes.
     `remove_origin`: delete original data of given keys
 
@@ -165,7 +155,7 @@ def __init__(
         channel_dim: int = 0,
         remove_origin: bool = False,
         allow_missing_keys: bool = False,
-        meta_key_postfix='meta_dict',
+        meta_key_postfix="meta_dict",
     ) -> None:
         """
         Args:
@@ -205,6 +195,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
                     d.pop(f"{key}_{self.meta_key_postfix}")
         return d
 
+
 class DataLabellingd(MapTransform):
     def __init__(
         self,
@@ -213,16 +204,13 @@ def __init__(
         super().__init__(keys)
         self.converter = DataLabelling()
 
-    def __call__(
-        self, img: Mapping[Hashable, torch.Tensor]
-    ) -> Dict[Hashable, torch.Tensor]:
+    def __call__(self, img: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(img)
         for idx, key in enumerate(self.keys):
             d[key] = self.converter(d[key])
         return d
 
-
 class ConcatModalityd(MapTransform):
     """Concat multi-modality data by given keys."""
 
@@ -268,9 +256,7 @@ def __init__(
         self.n_layer = n_layer
 
         if pos < 0 or neg < 0:
-            raise ValueError(
-                f"pos and neg must be nonnegative, got pos={pos} neg={neg}."
-            )
+            raise ValueError(f"pos and neg must be nonnegative, got pos={pos} neg={neg}.")
         if pos + neg == 0:
             raise ValueError("Incompatible values: pos=0 and neg=0.")
         self.pos_ratio = pos / (pos + neg)
@@ -300,9 +286,7 @@ def randomize(
         image: Optional[np.ndarray] = None,
     ) -> None:
         if fg_indices is None or bg_indices is None:
-            fg_indices_, bg_indices_ = map_binary_to_indices(
-                label, image, self.image_threshold
-            )
+            fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
         else:
             fg_indices_ = fg_indices
             bg_indices_ = bg_indices
@@ -317,38 +301,24 @@ def randomize(
             self.R,
         )
 
-    def __call__(
-        self, data: Mapping[Hashable, np.ndarray]
-    ) -> List[Dict[Hashable, np.ndarray]]:
+    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
         d = dict(data)
         label = d[self.label_key]
         image = d[self.image_key] if self.image_key else None
-        fg_indices = (
-            d.get(self.fg_indices_key, None)
-            if self.fg_indices_key is not None
-            else None
-        )
-        bg_indices = (
-            d.get(self.bg_indices_key, None)
-            if self.bg_indices_key is not None
-            else None
-        )
+        fg_indices = d.get(self.fg_indices_key, None) if self.fg_indices_key is not None else None
+        bg_indices = d.get(self.bg_indices_key, None) if self.bg_indices_key is not None else None
 
         self.randomize(label, fg_indices, bg_indices, image)
         assert isinstance(self.spatial_size, tuple)
         assert self.centers is not None
-        results: List[Dict[Hashable, np.ndarray]] = [
-            dict() for _ in range(self.num_samples)
-        ]
+        results: List[Dict[Hashable, np.ndarray]] = [dict() for _ in range(self.num_samples)]
         for key in data.keys():
             if key in self.keys:
                 img = d[key]
                 for i, center in enumerate(self.centers):
                     if self.crop_mode in ["single", "parallel"]:
                         size_ = self.get_new_spatial_size()
-                        slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(
-                            img
-                        )
+                        slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img)
                         seg_sum = slice_.squeeze().sum()
                         results[i][key] = np.moveaxis(slice_.squeeze(0), self.z_axis, 0)
@@ -356,9 +326,7 @@ def __call__(
                         cross_slices = np.zeros(shape=(3,) + self.spatial_size)
                         for k in range(3):
                             size_ = np.insert(self.spatial_size, k, 1)
-                            slice_ = SpatialCrop(
-                                roi_center=tuple(center), roi_size=size_
-                            )(img)
+                            slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img)
                             cross_slices[k] = slice_.squeeze()
                         results[i][key] = cross_slices
                     else:
@@ -389,7 +357,7 @@ def __init__(  # pytype: disable=annotation-type-mismatch
         select_labels: Union[Sequence[int], int],
         merge_channels: bool = False,
         cls_label_key: Optional[KeysCollection] = None,
-        select_msk_label: Optional[int] = None, #! for tmp debug
+        select_msk_label: Optional[int] = None,  #! for tmp debug
     ) -> None:
         super().__init__(keys)
         self.select_labels = select_labels
@@ -407,14 +375,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
         if self.cls_label_key is not None:
             label = d[self.cls_label_key]
-            assert len(label) == len(self.select_labels), 'length of cls_label_key must equal to length of mask select_labels'
+            assert len(label) == len(
+                self.select_labels
+            ), "length of cls_label_key must equal length of mask select_labels"
 
             if isinstance(label, (list, tuple)):
-                label = { i:L for i, L in enumerate(label, 1)}
+                label = {i: L for i, L in enumerate(label, 1)}
             elif isinstance(label, (int, float)):
-                label = {1:label}
-            assert isinstance(label, dict), 'Only support dict type label'
-
+                label = {1: label}
+            assert isinstance(label, dict), "Only support dict type label"
+
             d[self.cls_label_key] = label[self.select_label]
 
         for key in self.keys:
@@ -431,6 +401,7 @@ class GetItemd(MapTransform):
         keys (KeysCollection): keys of the corresponding items to be transformed.
         index (Union[Sequence[int], int]): i-th item you want to select.
     """
+
     def __init__(  # pytype: disable=annotation-type-mismatch
         self,
         keys: KeysCollection,
@@ -445,6 +416,119 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
                 d[key] = d[key][index]
         return d
 
+
+class RandSoftCopyPasted(Randomizable, MapTransform):
+    """Dictionary-based wrapper of :py:class:`monai_ex.transforms.RandSoftCopyPaste`.
+
+    Args:
+        keys (KeysCollection): keys of the corresponding items to be transformed.
+        mask_key (Optional[str]): key of the mask data.
+        source_dataset (Dataset): a dataset that provides the source data, returning dict data.
+        source_fg_key (str): key of the source foreground data.
+        source_fg_value (Optional[int]): source foreground value.
+        k_erode (int): erosion iteration num.
+        k_dilate (int): dilation iteration num.
+        alpha (float, optional): transparency ratio. Defaults to 0.8.
+        prob (float, optional): Probability to perform this aug. Defaults to 0.1.
+        mask_select_fn (Callable, optional): function to select expected foreground, default is to select values > 0.
+        strict_paste (bool, optional): whether to strictly paste source mask inside of target mask region. Defaults to False.
+        tolerance (int, optional): even in strict_paste mode, there is a tolerance to allow paste to the edge. Defaults to 100.
+        log_name (Optional[str], optional): logger name. Defaults to None.
+ """ + + def __init__( + self, + keys: KeysCollection, + mask_key: Optional[str], + source_dataset: Dataset, + source_fg_key: str, + source_fg_value: Optional[int], + k_erode: int, + k_dilate: int, + alpha: float = 0.8, + prob: float = 0.1, + mask_select_fn: Callable = is_positive, + strict_paste: bool = False, + tolerance: int = 100, + shift_source_intensity: bool = False, + log_name: Optional[str] = None, + ) -> None: + super().__init__(keys) + self.mask_key = mask_key + self.source_dataset = source_dataset + self.source_fg_value = source_fg_value + self.source_fg_key = source_fg_key + self.generator = RandSoftCopyPaste( + k_erode=k_erode, + k_dilate=k_dilate, + alpha=alpha, + prob=prob, + mask_select_fn=mask_select_fn, + source_label_value=source_fg_value, + strict_paste=strict_paste, + tolerance=tolerance, + shift_source_intensity=shift_source_intensity, + log_name=log_name, + ) + self.logger = logging.getLogger(log_name) + + def randomize(self) -> None: + return self.R.randint(len(self.source_dataset)) + + def compute_target_position(self, src_mask, softed_mask, target_image, target_mask): + self.generator.compute_target_position(src_mask, softed_mask, target_image, target_mask) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + idx = self.randomize() + + try: + source = self.source_dataset[idx] + except Exception as e: + raise TransformException("Source dataset crashed.\nErr msg: {e}") + + if self.source_fg_key not in source: + raise TransformException(f"Source dataset did not contain foregound mask key: {self.source_fg_key}") + + if self.generator.strict_paste and ( + np.count_nonzero(self.generator.mask_select_fn(d[self.mask_key])) + < np.count_nonzero(source[self.source_fg_key]) + ): + self.logger.debug("Target mask area is smaller than source foreground area. Skip copy&paste") + return d + + softed_mask = self.generator.soften(source[self.source_fg_key]) + if np.count_nonzero(softed_mask) == 0: + self.logger.debug("Source foreground area is too small to be soften. Skip copy&paste") + return d + + softed_mask = softed_mask[np.newaxis, ...] 
+
+        first_img_key = self.first_key(d)
+        if source[first_img_key].shape[0] > 1:
+            softed_mask = np.repeat(softed_mask, repeats=source[first_img_key].shape[0], axis=0)
+
+        self.compute_target_position(source[self.source_fg_key], softed_mask, d[first_img_key], d[self.mask_key])
+
+        for key in self.key_iterator(d):
+            image = d[key]
+            bg_mask = d[self.mask_key] if self.mask_key else None
+            source_image, source_fg = source[key], source[self.source_fg_key]
+
+            sythetic_image, sythetic_mask = self.generator(
+                image,
+                fg_mask=d.get(self.source_fg_key),
+                bg_mask=bg_mask,
+                source_image=source_image,
+                source_fg_mask=source_fg,
+                softed_fg_mask=softed_mask,
+                randomize=False,
+            )
+            d[key] = sythetic_image
+            d[self.source_fg_key] = sythetic_mask
+        return d
+
+
 ToTensorExD = ToTensorExDict = ToTensorExd
 CastToTypeExD = CastToTypeExDict = CastToTypeExd
 DataStatsExD = DataStatsExDict = DataStatsExd
@@ -454,3 +538,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
 RandCrop2dByPosNegLabelD = RandCrop2dByPosNegLabelDict = RandCrop2dByPosNegLabeld
 RandLabelToMaskD = RandLabelToMaskDict = RandLabelToMaskd
 GetItemD = GetItemDict = GetItemd
+RandSoftCopyPasteD = RandSoftCopyPasteDict = RandSoftCopyPasted
diff --git a/monai_ex/utils/misc.py b/monai_ex/utils/misc.py
index 9ae7b62..5ba3d9a 100644
--- a/monai_ex/utils/misc.py
+++ b/monai_ex/utils/misc.py
@@ -1,14 +1,36 @@
+import itertools
+from functools import partial
 from typing import Any, List
-import torch
-
+import numpy as np
+import torch
 from monai.utils.misc import issequenceiterable
-from functools import partial
 from utils_cw import catch_exception
+
 from .exceptions import GenericException
 
+trycatch = partial(catch_exception, handled_exception_type=GenericException, path_keywords=["strix", "monai_ex"])
+
-trycatch = partial(catch_exception, handled_exception_type=GenericException, path_keywords=['strix','monai_ex'])
+
+# Copied from strix
+def bbox_ND(img: np.ndarray, return_range: bool = False):
+    """Compute the bounding box of n-dimensional data.
+
+    Args:
+        img (np.ndarray): input nd data.
+        return_range (bool, optional): if True, return the bounding box size per dimension instead of coordinates. Defaults to False.
+
+    Returns:
+        tuple: (min, max) coordinates per dimension, or per-dimension sizes if ``return_range`` is True.
+    """
+    N = img.ndim
+    out = []
+    for ax in itertools.combinations(reversed(range(N)), N - 1):
+        nonzero = np.any(img, axis=ax)
+        out.extend(np.where(nonzero)[0][[0, -1]])
+    if return_range:
+        return tuple(out[2 * i + 1] - out[2 * i] for i in range(len(out) // 2))
+    return tuple(out)
 
 
 def ensure_same_dim(tensor1, tensor2):
@@ -30,7 +52,9 @@ def ensure_list(vals: Any):
     Returns a list of `vals`.
     """
     if not issequenceiterable(vals) or isinstance(vals, dict):
-        vals = [vals, ]
+        vals = [
+            vals,
+        ]
     return list(vals)
 
 
@@ -43,7 +67,9 @@ def ensure_list_rep(vals: Any, dim: int) -> List[Any]:
         ValueError: When ``tup`` is a sequence and ``tup`` length is not ``dim``.
     """
     if not issequenceiterable(vals):
-        return [vals,] * dim
+        return [
+            vals,
+        ] * dim
     elif len(vals) == dim:
         return list(vals)
 
@@ -56,7 +82,7 @@ def _register_generic(module_dict, module_name, module):
 
 
 class Registry(dict):
-    '''
+    """
     A helper class for managing registering modules, it extends a dictionary
     and provides a register functions.
@@ -76,7 +102,8 @@ def foo():
     Access of module is just like using a dictionary, eg:
         f = some_registry["foo_modeul"]
-    '''
+    """
+
     def __init__(self, *args, **kwargs):
         super(Registry, self).__init__(*args, **kwargs)
 
@@ -91,4 +118,4 @@ def register_fn(fn):
             _register_generic(self, module_name, fn)
             return fn
 
-        return register_fn
\ No newline at end of file
+        return register_fn
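
Usage sketch (not part of the patch): the snippet below shows one way the new dictionary transforms introduced above could be chained in a training pipeline. The RandSoftCopyPasteD and SelectSlicesByMaskD arguments mirror the tests in this patch; the file paths and the LoadImaged/EnsureChannelFirstd loading steps are illustrative assumptions rather than anything defined here.

# Illustrative sketch only: paths and the loading transforms below are assumptions;
# the monai_ex transform arguments mirror the tests added in this patch.
from monai.data import Dataset
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd
from monai_ex.transforms.utility.dictionary import RandSoftCopyPasteD
from monai_ex.transforms.croppad.dictionary import SelectSlicesByMaskD

# Load image/label pairs as channel-first arrays (hypothetical file names).
load = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
])

# Source cases provide the foreground objects that get soft copy&pasted into target cases.
source_dataset = Dataset(
    [{"image": "source_img.nii.gz", "label": "source_lbl.nii.gz"}],
    transform=load,
)

train_transform = Compose([
    load,
    # Soften the source foreground (label value 1) and paste it into the
    # background (label == 0) of the target case.
    RandSoftCopyPasteD(
        keys="image",
        mask_key="label",
        source_dataset=source_dataset,
        source_fg_key="label",
        source_fg_value=1,
        k_erode=2,
        k_dilate=5,
        alpha=0.8,
        prob=0.5,
        mask_select_fn=lambda x: x == 0,
    ),
    # Keep only the 2D slice at the mask center along the z axis.
    SelectSlicesByMaskD(keys=["image", "label"], mask_key="label", z_axis=2, center_mode="center"),
])

train_dataset = Dataset(
    [{"image": "target_img.nii.gz", "label": "target_lbl.nii.gz"}],
    transform=train_transform,
)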