diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 45cfbde6ea..acaeba0bc3 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -22,7 +22,7 @@ def create_dataset( datalist, output_dir: str, - dimension, + dimension: int, pixdim, image_key: str = "image", label_key: str = "label", @@ -138,7 +138,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): if len(vol_image.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label else None + vol_image.shape, vol_label.shape if vol_label is not None else None ) ) vol_image = vol_image[0] @@ -216,7 +216,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): if len(vol_image.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label else None + vol_image.shape, vol_label.shape if vol_label is not None else None ) ) vol_image = vol_image[0] diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index e871c328a6..147d8e7099 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -10,47 +10,96 @@ # limitations under the License. import os +import shutil import tempfile import unittest import nibabel as nib import numpy as np +from parameterized import parameterized from monai.apps.deepgrow.dataset import create_dataset +from monai.utils import set_determinism + +TEST_CASE_1 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 3}, 9, 1] + +TEST_CASE_2 = [{"dimension": 2, "pixdim": (1, 1), "limit": 1}, {"length": 3}, 3, 1] + +TEST_CASE_3 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1}, 3, 1] + +TEST_CASE_4 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1] + +TEST_CASE_5 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1] + +TEST_CASE_6 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1] + +TEST_CASE_7 = [ + {"dimension": 2, "pixdim": (1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 40, + None, +] + +TEST_CASE_8 = [ + {"dimension": 3, "pixdim": (1, 1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 1, + None, +] class TestCreateDataset(unittest.TestCase): - def _create_data(self, tempdir): + def setUp(self): + set_determinism(1) + self.tempdir = tempfile.mkdtemp() + + def _create_data(self, length=1, image_channel=1, with_label=True): affine = np.eye(4) - image = np.random.randint(0, 2, size=(128, 128, 40)) - image_file = os.path.join(tempdir, "image1.nii.gz") - nib.save(nib.Nifti1Image(image, affine), image_file) - - label = np.zeros((128, 128, 40)) - label[0][1][0] = 1 - label[0][1][1] = 1 - label[0][0][2] = 1 - label[0][1][2] = 1 - label_file = os.path.join(tempdir, "label1.nii.gz") - nib.save(nib.Nifti1Image(label, affine), label_file) - - return [{"image": image_file, "label": label_file}] - - def test_create_dataset_2d(self): - with tempfile.TemporaryDirectory() as tempdir: - datalist = self._create_data(tempdir) - output_dir = os.path.join(tempdir, "2d") - deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=2, pixdim=(1, 1)) - self.assertEqual(len(deepgrow_datalist), 3) - self.assertEqual(deepgrow_datalist[0]["region"], 1) - - def test_create_dataset_3d(self): - with tempfile.TemporaryDirectory() as tempdir: - datalist = self._create_data(tempdir) - output_dir = os.path.join(tempdir, "3d") - deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=3, pixdim=(1, 1, 1)) - self.assertEqual(len(deepgrow_datalist), 1) - self.assertEqual(deepgrow_datalist[0]["region"], 1) + datalist = [] + for i in range(length): + if image_channel == 1: + image = np.random.randint(0, 2, size=(128, 128, 40)) + else: + image = np.random.randint(0, 2, size=(128, 128, 40, image_channel)) + image_file = os.path.join(self.tempdir, f"image{i}.nii.gz") + nib.save(nib.Nifti1Image(image, affine), image_file) + + if with_label: + # 3 slices has label + label = np.zeros((128, 128, 40)) + label[0][1][0] = 1 + label[0][1][1] = 1 + label[0][0][2] = 1 + label[0][1][2] = 1 + label_file = os.path.join(self.tempdir, f"label{i}.nii.gz") + nib.save(nib.Nifti1Image(label, affine), label_file) + datalist.append({"image": image_file, "label": label_file}) + else: + datalist.append({"image": image_file}) + + return datalist + + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + ) + def test_create_dataset(self, args, data_args, expected_length, expected_region): + datalist = self._create_data(**data_args) + deepgrow_datalist = create_dataset(datalist=datalist, output_dir=self.tempdir, **args) + self.assertEqual(len(deepgrow_datalist), expected_length) + if expected_region is not None: + self.assertEqual(deepgrow_datalist[0]["region"], expected_region) + + def test_invalid_dim(self): + with self.assertRaises(ValueError): + create_dataset(datalist=self._create_data(), output_dir=self.tempdir, dimension=4, pixdim=(1, 1, 1, 1)) + + def test_empty_datalist(self): + with self.assertRaises(ValueError): + create_dataset(datalist=[], output_dir=self.tempdir, dimension=3, pixdim=(1, 1, 1)) + + def tearDown(self): + shutil.rmtree(self.tempdir) + set_determinism(None) if __name__ == "__main__":