diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
index 43178d2536..dd8a94143b 100644
--- a/monai/data/dataset_summary.py
+++ b/monai/data/dataset_summary.py
@@ -15,8 +15,11 @@
 import numpy as np
 import torch
 
+from monai.config import KeysCollection
 from monai.data.dataloader import DataLoader
 from monai.data.dataset import Dataset
+from monai.transforms import concatenate
+from monai.utils import convert_data_type
 
 
 class DatasetSummary:
@@ -38,6 +41,7 @@ def __init__(
         dataset: Dataset,
         image_key: Optional[str] = "image",
         label_key: Optional[str] = "label",
+        meta_key: Optional[KeysCollection] = None,
         meta_key_postfix: str = "meta_dict",
         num_workers: int = 0,
         **kwargs,
@@ -47,11 +51,16 @@ def __init__(
             dataset: dataset from which to load the data.
             image_key: key name of images (default: ``image``).
             label_key: key name of labels (default: ``label``).
+            meta_key: explicitly indicate the key of the corresponding meta data dictionary.
+                for example, for data with key `image`, the meta data by default is in `image_meta_dict`.
+                the meta data is a dictionary object which contains: filename, affine, original_shape, etc.
+                if None, will try to construct the key as `{image_key}_{meta_key_postfix}`.
             meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict,
                 the meta data is a dictionary object (default: ``meta_dict``).
             num_workers: how many subprocesses to use for data loading.
                 ``0`` means that the data will be loaded in the main process (default: ``0``).
-            kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``).
+            kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader;
+                this class forces ``batch_size=1``.
 
         """
 
@@ -59,18 +68,17 @@ def __init__(
         self.image_key = image_key
         self.label_key = label_key
-        if image_key:
-            self.meta_key = f"{image_key}_{meta_key_postfix}"
+        self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}"
         self.all_meta_data: List = []
 
     def collect_meta_data(self):
         """
         This function is used to collect the meta data for all images of the dataset.
 
         """
-        if not self.meta_key:
-            raise ValueError("To collect meta data for the dataset, `meta_key` should exist.")
 
         for data in self.data_loader:
+            if self.meta_key not in data:
+                raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.")
             self.all_meta_data.append(data[self.meta_key])
 
     def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
@@ -78,8 +86,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
         Calculate the target spacing according to all spacings.
         If the target spacing is very anisotropic,
         decrease the spacing value of the maximum axis according to percentile.
-        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading
-        with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
+        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim".
+        After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
 
         Args:
             spacing_key: key of spacing in meta data (default: ``pixdim``).
@@ -92,8 +100,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
             self.collect_meta_data()
         if spacing_key not in self.all_meta_data[0]:
             raise ValueError("The provided spacing_key is not in self.all_meta_data.")
-
-        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()
+        all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
+        all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)
         target_spacing = np.median(all_spacings, axis=0)
 
         if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
@@ -126,6 +134,8 @@ def calculate_statistics(self, foreground_threshold: int = 0):
                 image, label = data[self.image_key], data[self.label_key]
             else:
                 image, label = data
+            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
+            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)
 
             voxel_max.append(image.max().item())
             voxel_min.append(image.min().item())
@@ -169,6 +179,8 @@ def calculate_percentiles(
                 image, label = data[self.image_key], data[self.label_key]
             else:
                 image, label = data
+            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
+            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)
 
             intensities = image[torch.where(label > foreground_threshold)].tolist()
             if sampling_flag:
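For context, a minimal sketch of how the new `meta_key` argument is intended to be used, mirroring the updated test below. The file names and the `test1`/`test2` keys are made up for illustration; any NIfTI image/label pair would do:

```python
import nibabel as nib
import numpy as np

from monai.data import Dataset, DatasetSummary
from monai.transforms import LoadImaged

# write a tiny synthetic image/label pair so the sketch is self-contained
nib.save(nib.Nifti1Image(np.random.rand(32, 32, 32).astype(np.float32), np.eye(4)), "img0.nii.gz")
nib.save(nib.Nifti1Image(np.random.randint(0, 2, (32, 32, 32)).astype(np.int8), np.eye(4)), "seg0.nii.gz")

# store each meta dict under a custom key instead of the default
# "image_meta_dict" / "label_meta_dict"
dataset = Dataset(
    data=[{"image": "img0.nii.gz", "label": "seg0.nii.gz"}],
    transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"]),
)

# without `meta_key`, DatasetSummary would construct
# f"{image_key}_{meta_key_postfix}" == "image_meta_dict" and not find it
calculator = DatasetSummary(dataset, meta_key="test1")
print(calculator.get_target_spacing())  # (1.0, 1.0, 1.0) for an identity affine
```
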
diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 667dc3f190..5569c51a0c 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -22,6 +22,15 @@
 from monai.utils import set_determinism
 
 
+def test_collate(batch):
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, np.ndarray):
+        return np.stack(batch, 0)
+    elif isinstance(elem, dict):
+        return elem_type({key: test_collate([d[key] for d in batch]) for key in elem})
+
+
 class TestDatasetSummary(unittest.TestCase):
     def test_spacing_intensity(self):
         set_determinism(seed=0)
@@ -40,9 +49,12 @@ def test_spacing_intensity(self):
             {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
         ]
 
-        dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+        dataset = Dataset(
+            data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"])
+        )
 
-        calculator = DatasetSummary(dataset, num_workers=4)
+        # test **kwargs of `DatasetSummary` for `DataLoader`
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate)
 
         target_spacing = calculator.get_target_spacing()
         self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
@@ -74,7 +86,7 @@ def test_anisotropic_spacing(self):
 
         dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
 
-        calculator = DatasetSummary(dataset, num_workers=4)
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix="meta_dict")
 
         target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
         np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))
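And a quick illustration of the motivation for the `concatenate`/`convert_data_type` change above: with a numpy-based collate function such as `test_collate`, the collected meta data arrives as `np.ndarray` rather than `torch.Tensor`, so the previous `torch.cat(...).numpy()` would raise a `TypeError`. A small sketch with arbitrary spacing values:

```python
import numpy as np
import torch

from monai.transforms import concatenate
from monai.utils import convert_data_type

# spacing batches as a numpy-based collate would produce them (values arbitrary)
numpy_batches = [np.array([[1.0, 1.0, 1.0]]), np.array([[1.0, 1.0, 1.8]])]
# torch.cat(numpy_batches) would raise TypeError; `concatenate` dispatches on input type
all_spacings = concatenate(to_cat=numpy_batches, axis=0)
all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)
print(np.median(all_spacings, axis=0))  # [1.  1.  1.4]

# the same two lines accept batches from the default torch-based collate unchanged
torch_batches = [torch.ones(1, 3), torch.ones(1, 3)]
all_spacings = concatenate(to_cat=torch_batches, axis=0)
all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)
print(all_spacings.shape)  # (2, 3)
```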