From 8f411d0b91297a92586b6c7dbb05105068f5fdc2 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 16 Jul 2021 17:15:00 +0800
Subject: [PATCH 1/9] Add DatasetCalculator

Signed-off-by: Yiheng Wang
---
 docs/source/data.rst             |   4 +
 monai/data/__init__.py           |   1 +
 monai/data/dataset_calculator.py | 124 +++++++++++++++++++++++++++++++
 tests/min_tests.py               |   1 +
 tests/test_dataset_calculator.py |  50 +++++++++++++
 5 files changed, 180 insertions(+)
 create mode 100644 monai/data/dataset_calculator.py
 create mode 100644 tests/test_dataset_calculator.py

diff --git a/docs/source/data.rst b/docs/source/data.rst
index a5c3509fc9..2626f4ff71 100644
--- a/docs/source/data.rst
+++ b/docs/source/data.rst
@@ -182,6 +182,10 @@ DistributedWeightedRandomSampler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: monai.data.DistributedWeightedRandomSampler
 
+DatasetCalculator
+~~~~~~~~~~~~~~~~~
+.. autoclass:: monai.data.DatasetCalculator
+
 Decathlon Datalist
 ~~~~~~~~~~~~~~~~~~
 .. autofunction:: monai.data.load_decathlon_datalist
diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index af42627f5f..158465d141 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -23,6 +23,7 @@
     SmartCacheDataset,
     ZipDataset,
 )
+from .dataset_calculator import DatasetCalculator
 from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
 from .image_dataset import ImageDataset
diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py
new file mode 100644
index 0000000000..53ff0a7779
--- /dev/null
+++ b/monai/data/dataset_calculator.py
@@ -0,0 +1,124 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from typing import Dict, Sequence
+
+import numpy as np
+from joblib import Parallel, delayed
+
+from monai.transforms import LoadImaged
+
+
+class DatasetCalculator:
+    """
+    This class contains several functions that collect information such as the voxel spacings
+    and intensities of the input dataset; the target spacing and intensity statistics (min, max, mean, std)
+    can then be calculated by calling the corresponding functions.
+
+    This class refers to:
+    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
+
+    """
+
+    def __init__(
+        self,
+        datalist: Sequence[Dict],
+        image_key: str = "image",
+        label_key: str = "label",
+        meta_key_postfix: str = "meta_dict",
+        num_workers: int = -1,
+    ):
+        """
+        Args:
+            datalist: a list that contains the paths of all images and labels. The list
+                consists of dictionaries; each dictionary contains the image and label
+                paths of one sample. For datasets in the Decathlon format, the datalist can be
+                obtained by calling `monai.data.load_decathlon_datalist`.
+            image_key: the key name of images. Defaults to `image`.
+            label_key: the key name of labels. Defaults to `label`.
+            meta_key_postfix: for NIfTI images, use `{image_key}_{meta_key_postfix}` to
+                store the metadata of images.
+            num_workers: the maximum number of processes that can be used for data loading.
+
+        """
+
+        self.datalist = datalist
+        self.image_key = image_key
+        self.label_key = label_key
+        self.meta_key_postfix = meta_key_postfix
+        self.num_workers = num_workers
+        self.loader = LoadImaged(keys=[image_key, label_key], meta_key_postfix=meta_key_postfix)
+
+    def _run_parallel(self, function):
+        """
+        Run the function in parallel for all data in the datalist.
+
+        """
+
+        return Parallel(n_jobs=self.num_workers)(delayed(function)(data) for data in self.datalist)
+
+    def _load_spacing(self, path_dict: Dict):
+        """
+        Load the spacing from a data dictionary. Assumes that the original image file has `pixdim`
+        in its metadata.
+
+        """
+        data = self.loader(path_dict)
+        meta_key = "{}_{}".format(self.image_key, self.meta_key_postfix)
+        spacing = data[meta_key]["pixdim"][1:4].tolist()
+
+        return spacing
+
+    def _get_target_spacing(self, anisotropic_threshold: int = 3, percentile: float = 10.0):
+        """
+        Calculate the target spacing according to all spacings.
+        If the target spacing is very anisotropic,
+        decrease the spacing value of the maximum axis according to percentile.
+
+        """
+        spacing = self._run_parallel(self._load_spacing)
+        spacing = np.array(spacing)
+        target_spacing = np.median(spacing, axis=0)
+        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
+            largest_axis = np.argmax(target_spacing)
+            target_spacing[largest_axis] = np.percentile(spacing[:, largest_axis], percentile)
+
+        output = list(target_spacing)
+        output = [round(value, 2) for value in output]
+
+        return tuple(output)
+
+    def _load_intensity(self, path_dict: Dict):
+        """
+        Load the foreground intensities from a data dictionary.
+
+        """
+        data = self.loader(path_dict)
+        image = data[self.image_key]
+        foreground_idx = np.where(data[self.label_key] > 0)
+
+        return image[foreground_idx].tolist()
+
+    def _get_intensity_stats(self, lower: float = 0.5, upper: float = 99.5):
+        """
+        Calculate the min, max, mean and std of all foreground intensities. The minimum and maximum
+        values are computed at the provided lower and upper percentiles.
+
+        """
+        intensity = self._run_parallel(self._load_intensity)
+        intensity = np.array(list(itertools.chain.from_iterable(intensity)))
+        min_value, max_value = np.percentile(intensity, [lower, upper])
+        mean_value, std_value = np.mean(intensity), np.std(intensity)
+        output = [min_value, max_value, mean_value, std_value]
+        output = [round(value, 2) for value in output]
+
+        return tuple(output)
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 1cd54f35d0..4e08fcc832 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -111,6 +111,7 @@ def run_testsuit():
         "test_handler_metrics_saver",
         "test_handler_metrics_saver_dist",
         "test_handler_classification_saver_dist",
+        "test_dataset_calculator",
         "test_deepgrow_transforms",
         "test_deepgrow_interaction",
         "test_deepgrow_dataset",
diff --git a/tests/test_dataset_calculator.py b/tests/test_dataset_calculator.py
new file mode 100644
index 0000000000..3ca13fb7c3
--- /dev/null
+++ b/tests/test_dataset_calculator.py
@@ -0,0 +1,50 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import os
+import tempfile
+import unittest
+
+import nibabel as nib
+import numpy as np
+
+from monai.data import DatasetCalculator, create_test_image_3d
+from monai.utils import set_determinism
+
+
+class TestDatasetCalculator(unittest.TestCase):
+    def test_spacing_intensity(self):
+        set_determinism(seed=0)
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            for i in range(5):
+                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=-1)
+                n = nib.Nifti1Image(im, np.eye(4))
+                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
+                n = nib.Nifti1Image(seg, np.eye(4))
+                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
+            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
+            data_dicts = [
+                {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
+            ]
+
+            calculator = DatasetCalculator(data_dicts)
+            target_spacing = calculator._get_target_spacing(anisotropic_threshold=3, percentile=10.0)
+            self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
+            intensity_stats = calculator._get_intensity_stats(lower=0.5, upper=99.5)
+            self.assertEqual(intensity_stats, (0.56, 1.0, 0.89, 0.13))
+
+
+if __name__ == "__main__":
+    unittest.main()

From 1921f6fc1ddbe9f74c9939297230f6695af6d017 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 16 Jul 2021 17:33:41 +0800
Subject: [PATCH 2/9] update docstring

Signed-off-by: Yiheng Wang
---
 monai/data/dataset_calculator.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py
index 53ff0a7779..fc0d1252b1 100644
--- a/monai/data/dataset_calculator.py
+++ b/monai/data/dataset_calculator.py
@@ -20,11 +20,14 @@
 
 class DatasetCalculator:
     """
-    This class contains several functions that collect information such as the voxel spacings
-    and intensities of the input dataset; the target spacing and intensity statistics (min, max, mean, std)
-    can then be calculated by calling the corresponding functions.
-
-    This class refers to:
+    This class provides a way to calculate a reasonable output voxel spacing according to
+    the input dataset. The achieved values can be used to resample the input in 3D segmentation tasks
+    (like using as the `pixel` parameter in `monai.transforms.Spacingd`).
+    In addition, it also supports counting the mean, std, min and max intensities of the input,
+    and these statistics are helpful for image normalization
+    (like using in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).
+
+    The algorithm for calculation refers to:
     `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
 
     """
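For orientation, here is a minimal usage sketch of the API as it stands after patch 1 (the image/label paths are hypothetical placeholders; note that `_get_target_spacing` and `_get_intensity_stats` are still private methods at this point):

    from monai.data import DatasetCalculator

    # Hypothetical datalist; in practice it can come from load_decathlon_datalist.
    datalist = [
        {"image": "img0.nii.gz", "label": "seg0.nii.gz"},
        {"image": "img1.nii.gz", "label": "seg1.nii.gz"},
    ]

    calculator = DatasetCalculator(datalist, num_workers=2)
    # Median spacing over the dataset, with the anisotropic correction applied:
    spacing = calculator._get_target_spacing(anisotropic_threshold=3, percentile=10.0)
    # (min, max, mean, std) of the foreground intensities, where min and max are
    # taken at the 0.5 and 99.5 intensity percentiles:
    stats = calculator._get_intensity_stats(lower=0.5, upper=99.5)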
""" From 7ac7159a3118d9d6902584770f5ef337ed989e43 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 16 Jul 2021 23:13:55 +0800 Subject: [PATCH 3/9] use multiprocessing Signed-off-by: Yiheng Wang --- monai/data/dataset_calculator.py | 11 ++++++----- tests/test_dataset_calculator.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py index fc0d1252b1..4ab584339d 100644 --- a/monai/data/dataset_calculator.py +++ b/monai/data/dataset_calculator.py @@ -10,10 +10,10 @@ # limitations under the License. import itertools +import multiprocessing as mp from typing import Dict, Sequence import numpy as np -from joblib import Parallel, delayed from monai.transforms import LoadImaged @@ -38,7 +38,7 @@ def __init__( image_key: str = "image", label_key: str = "label", meta_key_postfix: str = "meta_dict", - num_workers: int = -1, + num_processes: int = 1, ): """ Args: @@ -58,7 +58,7 @@ def __init__( self.image_key = image_key self.label_key = label_key self.meta_key_postfix = meta_key_postfix - self.num_workers = num_workers + self.num_processes = num_processes self.loader = LoadImaged(keys=[image_key, label_key], meta_key_postfix=meta_key_postfix) def _run_parallel(self, function): @@ -66,8 +66,9 @@ def _run_parallel(self, function): Parallelly running the function for all data in the datalist. """ - - return Parallel(n_jobs=self.num_workers)(delayed(function)(data) for data in self.datalist) + with mp.Pool(processes=self.num_processes) as pool: + result = pool.map(function, self.datalist) + return result def _load_spacing(self, path_dict: Dict): """ diff --git a/tests/test_dataset_calculator.py b/tests/test_dataset_calculator.py index 3ca13fb7c3..68b3b4d29a 100644 --- a/tests/test_dataset_calculator.py +++ b/tests/test_dataset_calculator.py @@ -39,7 +39,7 @@ def test_spacing_intensity(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - calculator = DatasetCalculator(data_dicts) + calculator = DatasetCalculator(data_dicts, num_processes=2) target_spacing = calculator._get_target_spacing(anisotropic_threshold=3, percentile=10.0) self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) intensity_stats = calculator._get_intensity_stats(lower=0.5, upper=99.5) From 85fda190513486e4e056bbd09de5ad3c0a11bad2 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 28 Jul 2021 13:38:12 +0800 Subject: [PATCH 4/9] update to use dataset and other places Signed-off-by: Yiheng Wang --- monai/data/dataset_calculator.py | 176 ++++++++++++++++++++----------- tests/test_dataset_calculator.py | 20 ++-- 2 files changed, 126 insertions(+), 70 deletions(-) diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py index 4ab584339d..5aed007493 100644 --- a/monai/data/dataset_calculator.py +++ b/monai/data/dataset_calculator.py @@ -9,20 +9,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools -import multiprocessing as mp -from typing import Dict, Sequence +from itertools import chain +from typing import List, Optional import numpy as np +import torch -from monai.transforms import LoadImaged +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset class DatasetCalculator: """ This class provides a way to calculate a reasonable output voxel spacing according to the input dataset. 
From 85fda190513486e4e056bbd09de5ad3c0a11bad2 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Wed, 28 Jul 2021 13:38:12 +0800
Subject: [PATCH 4/9] update to use dataset and other places

Signed-off-by: Yiheng Wang
---
 monai/data/dataset_calculator.py | 176 ++++++++++++++++++++-----------
 tests/test_dataset_calculator.py |  20 ++--
 2 files changed, 126 insertions(+), 70 deletions(-)

diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py
index 4ab584339d..5aed007493 100644
--- a/monai/data/dataset_calculator.py
+++ b/monai/data/dataset_calculator.py
@@ -9,20 +9,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
-import multiprocessing as mp
-from typing import Dict, Sequence
+from itertools import chain
+from typing import List, Optional
 
 import numpy as np
+import torch
 
-from monai.transforms import LoadImaged
+from monai.data.dataloader import DataLoader
+from monai.data.dataset import Dataset
 
 
 class DatasetCalculator:
     """
     This class provides a way to calculate a reasonable output voxel spacing according to
     the input dataset. The achieved values can be used to resample the input in 3D segmentation tasks
-    (like using as the `pixel` parameter in `monai.transforms.Spacingd`).
+    (like using as the `pixdim` parameter in `monai.transforms.Spacingd`).
     In addition, it also supports counting the mean, std, min and max intensities of the input,
     and these statistics are helpful for image normalization
     (like using in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).
@@ -34,95 +35,142 @@
 
     def __init__(
         self,
-        datalist: Sequence[Dict],
-        image_key: str = "image",
-        label_key: str = "label",
+        dataset: Dataset,
+        image_key: Optional[str] = "image",
+        label_key: Optional[str] = "label",
         meta_key_postfix: str = "meta_dict",
-        num_processes: int = 1,
+        num_workers: int = 0,
+        **kwargs,
     ):
         """
         Args:
-            datalist: a list that contains the paths of all images and labels. The list
-                consists of dictionaries; each dictionary contains the image and label
-                paths of one sample. For datasets in the Decathlon format, the datalist can be
-                obtained by calling `monai.data.load_decathlon_datalist`.
-            image_key: the key name of images. Defaults to `image`.
-            label_key: the key name of labels. Defaults to `label`.
-            meta_key_postfix: for NIfTI images, use `{image_key}_{meta_key_postfix}` to
-                store the metadata of images.
-            num_workers: the maximum number of processes that can be used for data loading.
+            dataset: dataset from which to load the data.
+            image_key: key name of images (default: ``image``).
+            label_key: key name of labels (default: ``label``).
+            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from the dict;
+                the meta data is a dictionary object (default: ``meta_dict``).
+            num_workers: how many subprocesses to use for data loading.
+                ``0`` means that the data will be loaded in the main process (default: ``0``).
+            kwargs: other parameters (except `batch_size`) for the DataLoader (this class always uses ``batch_size=1``).
 
         """
 
-        self.datalist = datalist
+        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)
+
         self.image_key = image_key
         self.label_key = label_key
-        self.meta_key_postfix = meta_key_postfix
-        self.num_processes = num_processes
-        self.loader = LoadImaged(keys=[image_key, label_key], meta_key_postfix=meta_key_postfix)
+        if image_key:
+            self.meta_key = "{}_{}".format(image_key, meta_key_postfix)
+        self.all_meta_data: List = []
 
-    def _run_parallel(self, function):
-        """
-        Run the function in parallel for all data in the datalist.
-
-        """
-        with mp.Pool(processes=self.num_processes) as pool:
-            result = pool.map(function, self.datalist)
-            return result
-
-    def _load_spacing(self, path_dict: Dict):
-        """
-        Load the spacing from a data dictionary. Assumes that the original image file has `pixdim`
-        in its metadata.
-
-        """
-        data = self.loader(path_dict)
-        meta_key = "{}_{}".format(self.image_key, self.meta_key_postfix)
-        spacing = data[meta_key]["pixdim"][1:4].tolist()
-
-        return spacing
-
-    def _get_target_spacing(self, anisotropic_threshold: int = 3, percentile: float = 10.0):
-        """
-        Calculate the target spacing according to all spacings.
-        If the target spacing is very anisotropic,
-        decrease the spacing value of the maximum axis according to percentile.
-
-        """
-        spacing = self._run_parallel(self._load_spacing)
-        spacing = np.array(spacing)
-        target_spacing = np.median(spacing, axis=0)
-        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
-            largest_axis = np.argmax(target_spacing)
-            target_spacing[largest_axis] = np.percentile(spacing[:, largest_axis], percentile)
-
-        output = list(target_spacing)
-        output = [round(value, 2) for value in output]
-
-        return tuple(output)
-
-    def _load_intensity(self, path_dict: Dict):
-        """
-        Load the foreground intensities from a data dictionary.
-
-        """
-        data = self.loader(path_dict)
-        image = data[self.image_key]
-        foreground_idx = np.where(data[self.label_key] > 0)
-
-        return image[foreground_idx].tolist()
-
-    def _get_intensity_stats(self, lower: float = 0.5, upper: float = 99.5):
-        """
-        Calculate the min, max, mean and std of all foreground intensities. The minimum and maximum
-        values are computed at the provided lower and upper percentiles.
-
-        """
-        intensity = self._run_parallel(self._load_intensity)
-        intensity = np.array(list(itertools.chain.from_iterable(intensity)))
-        min_value, max_value = np.percentile(intensity, [lower, upper])
-        mean_value, std_value = np.mean(intensity), np.std(intensity)
-        output = [min_value, max_value, mean_value, std_value]
-        output = [round(value, 2) for value in output]
-
-        return tuple(output)
+    def collect_meta_data(self):
+        """
+        This function is used to collect the meta data for all images of the dataset.
""" - data = self.loader(path_dict) - meta_key = "{}_{}".format(self.image_key, self.meta_key_postfix) - spacing = data[meta_key]["pixdim"][1:4].tolist() + if not self.meta_key: + raise ValueError("To collect meta data for the dataset, `meta_key` should exist.") - return spacing + for data in self.data_loader: + self.all_meta_data.append(data[self.meta_key]) - def _get_target_spacing(self, anisotropic_threshold: int = 3, percentile: float = 10.0): + def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): """ Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, decrease the spacing value of the maximum axis according to percentile. + Args: + spacing_key: key of spacing in meta data (default: ``pixdim``). + anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``). + percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to + replace that axis. + """ - spacing = self._run_parallel(self._load_spacing) - spacing = np.array(spacing) - target_spacing = np.median(spacing, axis=0) + if len(self.all_meta_data) == 0: + self.collect_meta_data() + if spacing_key not in self.all_meta_data[0]: + raise ValueError("The provided spacing_key is not in self.all_meta_data.") + + all_spacings = torch.vstack([data[spacing_key][:, 1:4] for data in self.all_meta_data]).numpy() + + target_spacing = np.median(all_spacings, axis=0) if max(target_spacing) / min(target_spacing) >= anisotropic_threshold: largest_axis = np.argmax(target_spacing) - target_spacing[largest_axis] = np.percentile(spacing[:, largest_axis], percentile) + target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile) output = list(target_spacing) - output = [round(value, 2) for value in output] return tuple(output) - def _load_intensity(self, path_dict: Dict): + def calculate_statistics(self, foreground_threshold: int = 0): """ - Load intensity from a data's dictionary. + This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of + the input dataset. - """ - data = self.loader(path_dict) - image = data[self.image_key] - foreground_idx = np.where(data[self.label_key] > 0) - - return image[foreground_idx].tolist() + Args: + foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter + is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding + voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set + the threshold to ``-1`` (default: ``0``). - def _get_intensity_stats(self, lower: float = 0.5, upper: float = 99.5): """ - Calculate min, max, mean and std of all intensities. The minimal and maximum - values will be processed according to the provided percentiles. 
- + voxel_sum = torch.as_tensor(0.0) + voxel_square_sum = torch.as_tensor(0.0) + voxel_max, voxel_min = [], [] + voxel_ct = 0 + + for data in self.data_loader: + if self.image_key and self.label_key: + image, label = data[self.image_key], data[self.label_key] + else: + image, label = data + voxel_max.append(image.max().item()) + voxel_min.append(image.min().item()) + + image_foreground = image[torch.where(label > foreground_threshold)] + voxel_ct += len(image_foreground) + voxel_sum += image_foreground.sum() + voxel_square_sum += torch.square(image_foreground).sum() + + self.data_max, self.data_min = max(voxel_max), min(voxel_min) + self.data_mean = (voxel_sum / voxel_ct).item() + self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item() + + def calculate_percentiles( + self, + foreground_threshold: int = 0, + sampling_flag: bool = True, + interval: int = 10, + min_percentile: float = 0.5, + max_percentile: float = 99.5, + ): """ - intensity = self._run_parallel(self._load_intensity) - intensity = np.array(list(itertools.chain.from_iterable(intensity))) - min_value, max_value = np.percentile(intensity, [lower, upper]) - mean_value, std_value = np.mean(intensity), np.std(intensity) - output = [min_value, max_value, mean_value, std_value] - output = [round(value, 2) for value in output] + This function is used to calculate the percentiles of intensities (and median) of the input dataset. To get + the required values, all voxels need to be accumulated. To reduce the memory used, this function can be set + to accumulate only a part of the voxels. - return tuple(output) + Args: + foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter + is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding + voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set + the threshold to ``-1`` (default: ``0``). + sampling_flag: whether to sample only a part of the voxels (default: ``True``). + interval: the sampling interval for accumulating voxels (default: ``10``). + min_percentile: minimal percentile (default: ``0.5``). + max_percentile: maximal percentile (default: ``99.5``). 
+
+        """
+        all_intensities = []
+        for data in self.data_loader:
+
+            image, label = data[self.image_key], data[self.label_key]
+            intensities = image[torch.where(label > foreground_threshold)].tolist()
+            if sampling_flag:
+                intensities = intensities[::interval]
+            all_intensities.append(intensities)
+
+        all_intensities = list(chain(*all_intensities))
+        self.data_min_percentile, self.data_max_percentile = np.percentile(
+            all_intensities, [min_percentile, max_percentile]
+        )
+        self.data_median = np.median(all_intensities)
diff --git a/tests/test_dataset_calculator.py b/tests/test_dataset_calculator.py
index 68b3b4d29a..fe4e13fca4 100644
--- a/tests/test_dataset_calculator.py
+++ b/tests/test_dataset_calculator.py
@@ -17,7 +17,8 @@
 import nibabel as nib
 import numpy as np
 
-from monai.data import DatasetCalculator, create_test_image_3d
+from monai.data import Dataset, DatasetCalculator, create_test_image_3d
+from monai.transforms import LoadImaged
 from monai.utils import set_determinism
 
 
@@ -28,7 +29,7 @@ def test_spacing_intensity(self):
         with tempfile.TemporaryDirectory() as tempdir:
 
             for i in range(5):
-                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=-1)
+                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
                 n = nib.Nifti1Image(im, np.eye(4))
                 nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
                 n = nib.Nifti1Image(seg, np.eye(4))
@@ -40,11 +41,18 @@ def test_spacing_intensity(self):
                 {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
             ]
 
-            calculator = DatasetCalculator(data_dicts, num_processes=2)
-            target_spacing = calculator._get_target_spacing(anisotropic_threshold=3, percentile=10.0)
+            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+
+            calculator = DatasetCalculator(dataset, num_workers=4)
+
+            target_spacing = calculator.get_target_spacing()
             self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
-            intensity_stats = calculator._get_intensity_stats(lower=0.5, upper=99.5)
-            self.assertEqual(intensity_stats, (0.56, 1.0, 0.89, 0.13))
+            calculator.calculate_statistics()
+            np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)
+            np.testing.assert_allclose(calculator.data_std, 0.131731, rtol=1e-5, atol=1e-5)
+            calculator.calculate_percentiles(sampling_flag=True, interval=2)
+            self.assertEqual(calculator.data_max_percentile, 1.0)
+            np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)

From 0ab5b0c05a8871c00fcd016a1000be93ae0a4d7c Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Wed, 28 Jul 2021 13:48:41 +0800
Subject: [PATCH 5/9] update to support array return

Signed-off-by: Yiheng Wang
---
 monai/data/dataset_calculator.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_calculator.py
index 5aed007493..406df33dcd 100644
--- a/monai/data/dataset_calculator.py
+++ b/monai/data/dataset_calculator.py
@@ -124,6 +124,7 @@ def calculate_statistics(self, foreground_threshold: int = 0):
                 image, label = data[self.image_key], data[self.label_key]
             else:
                 image, label = data
+
             voxel_max.append(image.max().item())
             voxel_min.append(image.min().item())
 
@@ -162,8 +163,11 @@ def calculate_percentiles(
         """
         all_intensities = []
         for data in self.data_loader:
+            if self.image_key and self.label_key:
+                image, label = data[self.image_key], data[self.label_key]
+            else:
+                image, label = data
 
-            image, label = data[self.image_key], data[self.label_key]
             intensities = image[torch.where(label > foreground_threshold)].tolist()
             if sampling_flag:
                 intensities = intensities[::interval]
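Worth noting for reviewers: `calculate_statistics` avoids holding every foreground voxel in memory by accumulating a running sum and sum of squares, relying on the identity std = sqrt(E[x^2] - E[x]^2). A self-contained check of that identity (illustration only, not part of the patches):

    import torch

    x = torch.rand(1000)

    # Streaming accumulation, as in calculate_statistics:
    voxel_sum = x.sum()
    voxel_square_sum = torch.square(x).sum()
    mean = voxel_sum / x.numel()
    std = torch.sqrt(voxel_square_sum / x.numel() - mean ** 2)

    # Matches the direct (population) standard deviation:
    assert torch.allclose(std, x.std(unbiased=False), atol=1e-5)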
From 4b0a8755487583edaf72546612e06e807c287c06 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Thu, 29 Jul 2021 23:04:59 +0800
Subject: [PATCH 6/9] update with new testcases and change name

Signed-off-by: Yiheng Wang
---
 docs/source/data.rst                          |  6 +--
 monai/data/__init__.py                        |  2 +-
 ...taset_calculator.py => dataset_summary.py} |  4 +-
 ..._calculator.py => test_dataset_summary.py} | 40 +++++++++++++++++--
 4 files changed, 44 insertions(+), 8 deletions(-)
 rename monai/data/{dataset_calculator.py => dataset_summary.py} (97%)
 rename tests/{test_dataset_calculator.py => test_dataset_summary.py} (57%)

diff --git a/docs/source/data.rst b/docs/source/data.rst
index 2626f4ff71..022f7877d1 100644
--- a/docs/source/data.rst
+++ b/docs/source/data.rst
@@ -182,9 +182,9 @@ DistributedWeightedRandomSampler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: monai.data.DistributedWeightedRandomSampler
 
-DatasetCalculator
-~~~~~~~~~~~~~~~~~
-.. autoclass:: monai.data.DatasetCalculator
+DatasetSummary
+~~~~~~~~~~~~~~
+.. autoclass:: monai.data.DatasetSummary
 
 Decathlon Datalist
 ~~~~~~~~~~~~~~~~~~
diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 158465d141..fca170335b 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -23,7 +23,7 @@
     SmartCacheDataset,
     ZipDataset,
 )
-from .dataset_calculator import DatasetCalculator
+from .dataset_summary import DatasetSummary
 from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
 from .image_dataset import ImageDataset
diff --git a/monai/data/dataset_calculator.py b/monai/data/dataset_summary.py
similarity index 97%
rename from monai/data/dataset_calculator.py
rename to monai/data/dataset_summary.py
index 406df33dcd..b61e5422f6 100644
--- a/monai/data/dataset_calculator.py
+++ b/monai/data/dataset_summary.py
@@ -19,7 +19,7 @@
 from monai.data.dataset import Dataset
 
 
-class DatasetCalculator:
+class DatasetSummary:
     """
     This class provides a way to calculate a reasonable output voxel spacing according to
     the input dataset. The achieved values can be used to resample the input in 3D segmentation tasks
@@ -78,6 +78,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
         Calculate the target spacing according to all spacings.
         If the target spacing is very anisotropic,
         decrease the spacing value of the maximum axis according to percentile.
+        So far, this function only supports NIfTI images, which store spacings in their headers with the key
+        "pixdim". After loading with `monai.data.DataLoader`, "pixdim" is a `torch.Tensor` of size `(batch_size, 8)`.
 
         Args:
             spacing_key: key of spacing in meta data (default: ``pixdim``).
diff --git a/tests/test_dataset_calculator.py b/tests/test_dataset_summary.py
similarity index 57%
rename from tests/test_dataset_calculator.py
rename to tests/test_dataset_summary.py
index fe4e13fca4..7b1d8460c0 100644
--- a/tests/test_dataset_calculator.py
+++ b/tests/test_dataset_summary.py
@@ -17,12 +17,12 @@
 import nibabel as nib
 import numpy as np
 
-from monai.data import Dataset, DatasetCalculator, create_test_image_3d
+from monai.data import Dataset, DatasetSummary, create_test_image_3d
 from monai.transforms import LoadImaged
 from monai.utils import set_determinism
 
 
-class TestDatasetCalculator(unittest.TestCase):
+class TestDatasetSummary(unittest.TestCase):
     def test_spacing_intensity(self):
         set_determinism(seed=0)
         with tempfile.TemporaryDirectory() as tempdir:
@@ -42,7 +42,7 @@ def test_spacing_intensity(self):
 
             dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
 
-            calculator = DatasetCalculator(dataset, num_workers=4)
+            calculator = DatasetSummary(dataset, num_workers=4)
 
             target_spacing = calculator.get_target_spacing()
             self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
@@ -53,6 +53,40 @@ def test_spacing_intensity(self):
             self.assertEqual(calculator.data_max_percentile, 1.0)
             np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)
 
+    def test_anisotropic_spacing(self):
+        set_determinism(seed=0)
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            pixdims = [
+                [1.0, 1.0, 5.0],
+                [1.0, 1.0, 4.0],
+                [1.0, 1.0, 4.5],
+                [1.0, 1.0, 2.0],
+                [1.0, 1.0, 1.0],
+            ]
+            for i in range(5):
+                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
+                n = nib.Nifti1Image(im, np.eye(4))
+                n.header["pixdim"][1:4] = pixdims[i]
+                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
+                n = nib.Nifti1Image(seg, np.eye(4))
+                n.header["pixdim"][1:4] = pixdims[i]
+                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
+            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
+            data_dicts = [
+                {"image": image_name, "label": label_name}
+                for image_name, label_name in zip(train_images, train_labels)
+            ]
+
+            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+
+            calculator = DatasetSummary(dataset, num_workers=4)
+
+            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
+            self.assertEqual(target_spacing, (1.0, 1.0, 1.8))
+
 
 if __name__ == "__main__":
     unittest.main()

From 20bd23021464a14fedeb30e1fb6af3402cb5246f Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Thu, 29 Jul 2021 23:07:49 +0800
Subject: [PATCH 7/9] update min test

Signed-off-by: Yiheng Wang
---
 tests/min_tests.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/min_tests.py b/tests/min_tests.py
index 4e08fcc832..1f53569cd9 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -111,7 +111,7 @@ def run_testsuit():
         "test_handler_metrics_saver",
         "test_handler_metrics_saver_dist",
         "test_handler_classification_saver_dist",
-        "test_dataset_calculator",
+        "test_dataset_summary",
         "test_deepgrow_transforms",
         "test_deepgrow_interaction",
         "test_deepgrow_dataset",
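With the rename in place, end-to-end usage looks like the sketch below (it mirrors the unit tests above; the NIfTI paths are placeholders):

    from monai.data import Dataset, DatasetSummary
    from monai.transforms import LoadImaged

    data_dicts = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]  # placeholder paths
    dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

    summary = DatasetSummary(dataset, num_workers=0)
    # e.g. feed into monai.transforms.Spacingd(keys=["image", "label"], pixdim=spacing):
    spacing = summary.get_target_spacing()
    summary.calculate_statistics()   # sets summary.data_mean / data_std / data_max / data_min
    summary.calculate_percentiles()  # sets summary.data_min_percentile / data_max_percentile / data_median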
From e334e2576d7d8eb4aa9dd975aee7afc195300e69 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Thu, 29 Jul 2021 23:16:21 +0800
Subject: [PATCH 8/9] update unittest

Signed-off-by: Yiheng Wang
---
 tests/test_dataset_summary.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 7b1d8460c0..5307bc7e66 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -54,7 +54,6 @@ def test_spacing_intensity(self):
         np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)
 
     def test_anisotropic_spacing(self):
-        set_determinism(seed=0)
         with tempfile.TemporaryDirectory() as tempdir:
 
             pixdims = [
@@ -73,19 +72,18 @@ def test_anisotropic_spacing(self):
                 n.header["pixdim"][1:4] = pixdims[i]
                 nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
 
-            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
-            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
-            data_dicts = [
-                {"image": image_name, "label": label_name}
-                for image_name, label_name in zip(train_images, train_labels)
-            ]
+        train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
+        train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
+        data_dicts = [
+            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
+        ]
 
-            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+        dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
 
-            calculator = DatasetSummary(dataset, num_workers=4)
+        calculator = DatasetSummary(dataset, num_workers=4)
 
-            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
-            self.assertEqual(target_spacing, (1.0, 1.0, 1.8))
+        target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
+        np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))

From 8f0369e9f3cf3a2f96016bd07c62aa99d5415d95 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 30 Jul 2021 11:47:54 +0800
Subject: [PATCH 9/9] fix vstack error

Signed-off-by: Yiheng Wang
---
 monai/data/dataset_summary.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
index b61e5422f6..a8598eb6c8 100644
--- a/monai/data/dataset_summary.py
+++ b/monai/data/dataset_summary.py
@@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
         if spacing_key not in self.all_meta_data[0]:
             raise ValueError("The provided spacing_key is not in self.all_meta_data.")
 
-        all_spacings = torch.vstack([data[spacing_key][:, 1:4] for data in self.all_meta_data]).numpy()
+        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()
 
         target_spacing = np.median(all_spacings, axis=0)
         if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
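A closing note on the final fix: for a list of `(1, 3)` spacing tensors, concatenating along dim 0 gives the same result as `torch.vstack`, while `torch.cat` is also available on older PyTorch releases (`torch.vstack` was only added in PyTorch 1.8, which is presumably the source of the error this patch addresses). Quick illustration:

    import torch

    spacings = [torch.tensor([[1.0, 1.0, 5.0]]), torch.tensor([[1.0, 1.0, 4.0]])]
    all_spacings = torch.cat(spacings, dim=0)  # shape (2, 3), identical to vstack's output
    print(all_spacings.shape)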