diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 34df819d32..0ff719023c 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Union import numpy as np import torch @@ -93,12 +93,7 @@ def __init__( self.squeeze_end_dims = squeeze_end_dims self.data_root_dir = data_root_dir - def save( - self, - data: Union[torch.Tensor, np.ndarray], - meta_data: Optional[Dict] = None, - patch_index: Optional[int] = None, - ) -> None: + def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ Save data into a Nifti file. The meta_data could optionally have the following keys: @@ -107,6 +102,7 @@ def save( - ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix. - ``'affine'`` -- for data output affine, defaulting to an identity matrix. - ``'spatial_shape'`` -- for data output shape. + - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. When meta_data is specified, the saver will try to resample batch data from the space defined by "affine" to the space defined by "original_affine". @@ -117,7 +113,6 @@ def save( data: target data content that to be saved as a NIfTI format file. Assuming the data shape starts with a channel dimension and followed by spatial dimensions. meta_data: the meta data information corresponding to the data. - patch_index: if the data is a patch of big image, need to append the patch index to filename. See Also :py:meth:`monai.data.nifti_writer.write_nifti` @@ -127,6 +122,7 @@ def save( original_affine = meta_data.get("original_affine", None) if meta_data else None affine = meta_data.get("affine", None) if meta_data else None spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() @@ -158,12 +154,7 @@ def save( output_dtype=self.output_dtype, ) - def save_batch( - self, - batch_data: Union[torch.Tensor, np.ndarray], - meta_data: Optional[Dict] = None, - patch_indice: Optional[Sequence[int]] = None, - ) -> None: + def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ Save a batch of data into Nifti format files. @@ -180,11 +171,7 @@ def save_batch( Args: batch_data: target batch data content that save into NIfTI format. meta_data: every key-value in the meta_data is corresponding to a batch of data. - patch_indice: if the data is a patch of big image, need to append the patch index to filename. + """ for i, data in enumerate(batch_data): # save a batch of files - self.save( - data=data, - meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None, - patch_index=patch_indice[i] if patch_indice is not None else None, - ) + self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 17087fcaca..880f6b204f 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Union import numpy as np import torch @@ -71,18 +71,14 @@ def __init__( self._data_index = 0 - def save( - self, - data: Union[torch.Tensor, np.ndarray], - meta_data: Optional[Dict] = None, - patch_index: Optional[int] = None, - ) -> None: + def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ Save data into a png file. The meta_data could optionally have the following keys: - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object. - ``'spatial_shape'`` -- for data output shape. + - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename. If meta_data is None, use the default index (starting from 0) as the filename. @@ -92,7 +88,6 @@ def save( Shape of the spatial dimensions (C,H,W). C should be 1, 3 or 4 meta_data: the meta data information corresponding to the data. - patch_index: if the data is a patch of big image, need to append the patch index to filename. Raises: ValueError: When ``data`` channels is not one of [1, 3, 4]. @@ -104,6 +99,7 @@ def save( filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() @@ -126,22 +122,13 @@ def save( scale=self.scale, ) - def save_batch( - self, - batch_data: Union[torch.Tensor, np.ndarray], - meta_data: Optional[Dict] = None, - patch_indice: Optional[Sequence[int]] = None, - ) -> None: + def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """Save a batch of data into png format files. Args: batch_data: target batch data content that save into png format. meta_data: every key-value in the meta_data is corresponding to a batch of data. - patch_indice: if the data is a patch of big image, need to append the patch index to filename. + """ for i, data in enumerate(batch_data): # save a batch of files - self.save( - data=data, - meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None, - patch_index=patch_indice[i] if patch_indice is not None else None, - ) + self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 7df10b9dad..9ee7ca67f9 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -16,9 +16,7 @@ from monai.config import DtypeLike from monai.transforms import SaveImage -from monai.utils import GridSampleMode, GridSamplePadMode -from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, exact_version, optional_import +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -145,6 +143,5 @@ def __call__(self, engine: Engine) -> None: """ meta_data = self.batch_transform(engine.state.batch) engine_output = self.output_transform(engine.state.output) - patch_indice = engine.state.batch.get(Key.PATCH_INDEX, None) - self._saver(engine_output, meta_data, patch_indice) + self._saver(engine_output, meta_data) self.logger.info("saved all the model outputs into files.") diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 428e35335c..1d4fcfdb1f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -483,7 +483,7 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set the minimum size to limit the randomly generated ROI. Suppose all the expected fields - specified by `keys` have same shape. + specified by `keys` have same shape, and add `patch_index` to the corresponding meta data. It will return a list of dictionaries for all the cropped images. Args: @@ -495,6 +495,9 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform): random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. Raises: @@ -509,6 +512,7 @@ def __init__( num_samples: int, random_center: bool = True, random_size: bool = True, + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self, prob=1.0, do_transform=True) @@ -517,6 +521,7 @@ def __init__( raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys) + self.meta_key_postfix = meta_key_postfix def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -530,9 +535,15 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: ret = [] + d = dict(data) for i in range(self.num_samples): - cropped = self.cropper(data) - cropped[Key.PATCH_INDEX] = i # type: ignore + cropped = self.cropper(d) + # add `patch_index` to the meta data + for key in self.key_iterator(d): + meta_data_key = f"{key}_{self.meta_key_postfix}" + if meta_data_key not in cropped: + cropped[meta_data_key] = {} # type: ignore + cropped[meta_data_key][Key.PATCH_INDEX] = i ret.append(cropped) return ret @@ -687,6 +698,8 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. + Suppose all the expected fields specified by `keys` have same shape, + and add `patch_index` to the corresponding meta data. And will return a list of dictionaries for all the cropped images. Args: @@ -712,6 +725,9 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key` and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndicesd` transform first and cache the results. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. Raises: @@ -732,6 +748,7 @@ def __init__( image_threshold: float = 0.0, fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ) -> None: RandomizableTransform.__init__(self) @@ -748,6 +765,7 @@ def __init__( self.image_threshold = image_threshold self.fg_indices_key = fg_indices_key self.bg_indices_key = bg_indices_key + self.meta_key_postfix = meta_key_postfix self.centers: Optional[List[List[np.ndarray]]] = None def randomize( @@ -789,8 +807,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n # fill in the extra keys with unmodified data for key in set(data.keys()).difference(set(self.keys)): results[i][key] = data[key] - # add patch index in the meta data - results[i][Key.PATCH_INDEX] = i # type: ignore + # add `patch_index` to the meta data + for key in self.key_iterator(d): + meta_data_key = f"{key}_{self.meta_key_postfix}" + if meta_data_key not in results[i]: + results[i][meta_data_key] = {} # type: ignore + results[i][meta_data_key][Key.PATCH_INDEX] = i return results diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index b138b97cb2..7a7fcb8cda 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -269,19 +269,14 @@ def __init__( self.save_batch = save_batch - def __call__( - self, - img: Union[torch.Tensor, np.ndarray], - meta_data: Optional[Dict] = None, - patch_index=None, # type is Union[Sequence[int], int, None], can't be compatible with save and save_batch - ): + def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: img: target data content that save into file. meta_data: key-value pairs of meta_data corresponding to the data. - patch_index: if the data is a patch of big image, need to append the patch index to filename. + """ if self.save_batch: - self.saver.save_batch(img, meta_data, patch_index) + self.saver.save_batch(img, meta_data) else: - self.saver.save(img, meta_data, patch_index) + self.saver.save(img, meta_data) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 5b8f0a41d3..58d6431c74 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -23,9 +23,7 @@ from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform -from monai.utils import GridSampleMode, GridSamplePadMode -from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode __all__ = [ "LoadImaged", @@ -229,7 +227,7 @@ def __call__(self, data): d = dict(data) for key in self.key_iterator(d): meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None - self._saver(img=d[key], meta_data=meta_data, patch_index=d.get(Key.PATCH_INDEX, None)) + self._saver(img=d[key], meta_data=meta_data) return d diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 2744d729a1..d52ba900ac 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -92,7 +92,7 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape self.assertTupleEqual(result[0]["extral"].shape, expected_shape) self.assertTupleEqual(result[0]["label"].shape, expected_shape) for i, item in enumerate(result): - self.assertEqual(item["patch_index"], i) + self.assertEqual(item["image_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 5b745add18..09688f44b7 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -71,7 +71,8 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): self.assertTupleEqual(item["img"].shape, expected) self.assertTupleEqual(item["seg"].shape, expected) for i, item in enumerate(result): - self.assertEqual(item["patch_index"], i) + self.assertEqual(item["img_meta_dict"]["patch_index"], i) + self.assertEqual(item["seg_meta_dict"]["patch_index"], i) np.testing.assert_allclose(item["img"], expected_last["img"]) np.testing.assert_allclose(item["seg"], expected_last["seg"]) diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index b5293473c2..d6536b3d51 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -117,7 +117,7 @@ def test_saved_content(self, test_data, output_ext, resample, save_batch): filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) else: - patch_index = test_data.get("patch_index", None) + patch_index = test_data["img_meta_dict"].get("patch_index", None) patch_index = f"_{patch_index}" if patch_index is not None else "" filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext) self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))