Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions monai_ex/tests/test_CenterMask2DSliceCropD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from monai_ex.transforms.croppad.dictionary import CenterMask2DSliceCropD
from monai.data import Dataset
from monai_ex.transforms import GenerateSyntheticDataD, Compose


@pytest.mark.parametrize("crop_size,crop_mode,expected", [((50,50), "single", (1,50,50)), ((50,50), "parallel", (3,50,50))])
def test_fullimiage2dslicecropd(crop_size, crop_mode, expected):
dim = 3
spatial_size = (100,) * dim

generator = GenerateSyntheticDataD(
["image", "label"],
*spatial_size,
num_objs=1,
rad_max=5,
rad_min=4,
noise_max=0.5,
num_seg_classes=1,
channel_dim=0,
)

source_dataset = Dataset(
[{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(2)],
transform=Compose([
generator,
CenterMask2DSliceCropD(
keys="image",
mask_key="label",
roi_size=crop_size,
crop_mode=crop_mode,
center_mode="center",
z_axis=2,
n_slices=3
)
])
)

output_item = source_dataset[0]

assert output_item['image'].shape == expected
53 changes: 53 additions & 0 deletions monai_ex/tests/test_RandSoftCopyPaste.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from pathlib import Path
import nibabel as nib

import numpy as np
from monai_ex.transforms.utility.array import RandSoftCopyPaste
from monai.data.synthetic import create_test_image_3d
from monai_ex.transforms.io.array import GenerateSyntheticData


@pytest.mark.parametrize("dim", [2, 3])
@pytest.mark.parametrize("prob", [0, 1])
def test_randsoftcopypaste(dim, prob):
spatial_size = (100,) * dim
generator = GenerateSyntheticData(
*spatial_size,
num_objs=1,
rad_max=10,
rad_min=9,
noise_max=0.2,
num_seg_classes=1,
channel_dim=0,
)

src_image, src_mask = generator(None)
tar_image, tar_mask = generator(None)
volume_size = np.count_nonzero(src_mask) + np.count_nonzero(tar_mask)

print("dummy data, mask shape:", src_image.shape, src_image.shape)
print("mask label: ", np.unique(src_mask))
sythetic_img, sythetic_msk = RandSoftCopyPaste(
2, 4, prob=prob, mask_select_fn=lambda x: x==0, source_label_value=1
)(tar_image, tar_mask, src_image, src_mask)
if prob == 0:
assert np.all(sythetic_img == tar_image)
assert np.all(sythetic_msk == tar_mask)
else:
assert sythetic_img.shape == (1, *spatial_size)
assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size

# save_fpath = Path.home() / f"sythetic_img_{dim}.nii.gz"
# nib.save(nib.Nifti1Image(sythetic_img.squeeze(), np.eye(4)), save_fpath)

sythetic_img, sythetic_msk = RandSoftCopyPaste(
2, 4, prob=prob, source_label_value=1
)(tar_image, None, src_image, src_mask)
if prob == 0:
assert np.all(sythetic_img == tar_image)
assert sythetic_msk is None
else:
assert sythetic_img.shape == (1, *spatial_size)
assert volume_size/2 <= np.count_nonzero(sythetic_msk) <= volume_size
151 changes: 151 additions & 0 deletions monai_ex/tests/test_RandSoftCopyPasteD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import pytest

from pathlib import Path
import nibabel as nib

import numpy as np
from monai_ex.transforms.utility.dictionary import RandSoftCopyPasteD
from monai.data.synthetic import create_test_image_3d
from monai.data import Dataset
from monai_ex.transforms import MapTransform, GenerateSyntheticData, Compose, adaptor



class GenerateSyntheticDataD(MapTransform):
def __init__(
self,
keys,
label_key,
height: int,
width: int,
depth: int = None,
num_objs: int = 12,
rad_max: int = 30,
rad_min: int = 5,
noise_max: float = 0.0,
num_seg_classes: int = 5,
channel_dim: int = None,
random_state: np.random.RandomState = None,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)

self.label_key = label_key
self.loader = GenerateSyntheticData(
height,
width,
depth,
num_objs,
rad_max,
rad_min,
noise_max,
num_seg_classes,
channel_dim,
random_state,
)

def __call__(self, filename: dict):
test_data = self.loader(None)

data = {}
for key in self.keys:
data[key] = test_data[0]
data[self.label_key] = test_data[1]
return data


@pytest.mark.parametrize("dim", [2, 3])
def test_randsoftcopypaste(dim):
data_num = 2
spatial_size = (100,) * dim
generator = GenerateSyntheticDataD(
"image",
"label",
*spatial_size,
num_objs=1,
rad_max=5,
rad_min=4,
noise_max=0.5,
num_seg_classes=1,
channel_dim=0,
)

dummy_fpath = [{"image": "d.nii", "label": "l.nii"} for i in range(data_num)]

output = generator(dummy_fpath[0])
volume_size = np.count_nonzero(output["label"])

source_dataset = Dataset(
[{"image": 'dummy.nii', "label": 'dummy_label.nii'} for i in range(data_num)],
transform=generator
)

dataset = Dataset(
dummy_fpath, transform=Compose([
generator,
RandSoftCopyPasteD(
keys="image", mask_key="label",
source_dataset=source_dataset, # will generate image & mask
source_fg_key="label",
source_fg_value=1,
k_erode=2,
k_dilate=5,
alpha=0.8,
prob=1,
mask_select_fn=lambda x: x == 0,
)
])
)

for i, item in enumerate(dataset):
image, label = item["image"], item["label"]

# save_fpath = Path.home() / f"sythetic_{dim}Dimg_{i}.nii.gz"
# nib.save(nib.Nifti1Image(image.squeeze(), np.eye(4)), save_fpath)
# save_fpath = Path.home() / f"sythetic_{dim}Dlabel_{i}.nii.gz"
# nib.save(nib.Nifti1Image(label.squeeze(), np.eye(4)), save_fpath)

assert volume_size < np.count_nonzero(label) <= 2 * volume_size



@pytest.mark.parametrize("dim", [2, 3])
def test_randsoftcopypaste_multiimage(dim):
data_num = 2
spatial_size = (100,) * dim
generator = GenerateSyntheticDataD(
["image1", "image2"],
"label",
*spatial_size,
num_objs=1,
rad_max=5,
rad_min=4,
noise_max=0,
num_seg_classes=1,
channel_dim=0,
)

dummy_fpath = [{"image1": "d.nii", "image2": "d.nii", "label": "l.nii"} for i in range(data_num)]
source_dataset = Dataset(
[{"image1": '1', "image2": "2", "label": '1'} for i in range(data_num)],
transform=generator
)

outputs = generator({"image1": '1', "image2": '2', "label": '1'})

generator = RandSoftCopyPasteD(
keys=["image1", "image2"], mask_key="label",
source_dataset=source_dataset,
source_fg_key="label",
source_fg_value=1,
k_erode=2,
k_dilate=5,
alpha=0.8,
prob=1,
mask_select_fn=lambda x: x == 0,
)

generated_item = generator(outputs)
assert generated_item["image1"].shape == (1, *spatial_size)
assert generated_item["image2"].shape == (1, *spatial_size)
assert np.all(generated_item["image1"] == generated_item["image2"])
52 changes: 52 additions & 0 deletions monai_ex/tests/test_SelectSlicesByMask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

import numpy as np
from monai_ex.transforms.croppad.array import SelectSlicesByMask
from monai_ex.transforms.croppad.dictionary import SelectSlicesByMaskD
from monai_ex.transforms import GenerateSyntheticData, GenerateSyntheticDataD

def test_selectslicesbymask():
dim = 3
spatial_size = (100,) * dim

generator = GenerateSyntheticData(
*spatial_size,
num_objs=1,
rad_max=5,
rad_min=4,
noise_max=0,
num_seg_classes=1,
channel_dim=0,
)

image, label = generator(None)
cropper = SelectSlicesByMask(z_axis=2, center_mode='center', mask_data=label)
img_slice = cropper(image)

assert img_slice.shape == (1, 100, 100)
assert np.count_nonzero(img_slice) > 0


def test_selectslicesbymaskdict():
dim = 3
spatial_size = (100,) * dim

generator = GenerateSyntheticDataD(
["image", "label"],
*spatial_size,
num_objs=1,
rad_max=5,
rad_min=4,
noise_max=0,
num_seg_classes=1,
channel_dim=0,
)

outputs = generator({"image": "1", "label": "1"})
cropper = SelectSlicesByMaskD(keys=["image", "label"], mask_key="label", z_axis=2, center_mode='center')
img_slice = cropper(outputs)

assert img_slice["image"].shape == (1, 100, 100)
assert img_slice["label"].shape == (1, 100, 100)
assert np.count_nonzero(img_slice["image"]) > 0
assert np.count_nonzero(img_slice["label"]) > 0
23 changes: 23 additions & 0 deletions monai_ex/tests/test_bbox_nd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from monai_ex.utils.misc import bbox_ND
import numpy as np

dummy_data_3d = np.zeros([10, 10, 10])
dummy_data_3d[4:7, 4:8, 4:9] = 1

dummy_data_2d = np.zeros([10, 10])
dummy_data_2d[4:7, 4:8] = 1

@pytest.mark.parametrize('data', [dummy_data_2d, dummy_data_3d])
def test_bbox_nd(data):
bounding = bbox_ND(data, False)
if len(data) == 3:
assert bounding == (4, 6, 4, 7, 4, 8)
elif len(data) == 2:
assert bounding == (4, 6, 4, 7)

bbox_range = bbox_ND(data, True)
if len(data) == 3:
assert bbox_range == (2, 3, 4)
elif len(data) == 2:
assert bbox_range == (2, 3)
2 changes: 0 additions & 2 deletions monai_ex/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def __call__(self, input_):
return apply_transform(self.selected_trans, input_)


ReturnType = TypeVar("ReturnType")

def _apply_transform(
transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False
) -> ReturnType:
Expand Down
Loading