From 254aac2b0572a1b8a7234e634226d1670a80d623 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 12:41:29 +0000 Subject: [PATCH] nifti saver squeeze dims Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/nifti_saver.py | 14 ++++++++++++++ monai/transforms/io/array.py | 8 ++++++++ monai/transforms/io/dictionary.py | 8 ++++++++ tests/test_nifti_saver.py | 24 ++++++++++++++++++++++++ 4 files changed, 54 insertions(+) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 01e701b1a6..016b06fda5 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -27,6 +27,8 @@ class NiftiSaver: Typically, the data can be segmentation predictions, call `save` for single data or call `save_batch` to save a batch of data together. If no meta data provided, use index from 0 as the filename prefix. + + NB: image should include channel dimension: [B],C,H,W,[D]. """ def __init__( @@ -40,6 +42,7 @@ def __init__( align_corners: bool = False, dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, + squeeze_end_dims: bool = True, ) -> None: """ Args: @@ -60,6 +63,10 @@ def __init__( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. + squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and + then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + image will always be saved as (H,W,D,C). """ self.output_dir = output_dir self.output_postfix = output_postfix @@ -71,6 +78,7 @@ def __init__( self.dtype = dtype self.output_dtype = output_dtype self._data_index = 0 + self.squeeze_end_dims = squeeze_end_dims def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """ @@ -111,6 +119,12 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] data = np.expand_dims(data, -1) # change data to "channel last" format and write to nifti format file data = np.moveaxis(np.asarray(data), 0, -1) + + # if desired, remove trailing singleton dimensions + if self.squeeze_end_dims: + while data.shape[-1] == 1: + data = np.squeeze(data, -1) + write_nifti( data, file_name=filename, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 9c4f631699..4ede04cf69 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -165,6 +165,8 @@ class SaveImage(Transform): It can work for both numpy array and PyTorch Tensor in both pre-transform chain and post transform chain. + NB: image should include channel dimension: [B],C,H,W,[D]. + Args: output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. @@ -200,6 +202,10 @@ class SaveImage(Transform): it's used for NIfTI format only. save_batch: whether the import image is a batch data, default to `False`. usually pre-transforms run for channel first data, while post-transforms run for batch data. + squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and + then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + image will always be saved as (H,W,D,C). """ @@ -215,6 +221,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, save_batch: bool = False, + squeeze_end_dims: bool = True, ) -> None: self.saver: Union[NiftiSaver, PNGSaver] if output_ext in (".nii.gz", ".nii"): @@ -227,6 +234,7 @@ def __init__( padding_mode=padding_mode, dtype=dtype, output_dtype=output_dtype, + squeeze_end_dims=squeeze_end_dims, ) elif output_ext == ".png": self.saver = PNGSaver( diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index d9b6b5e6ab..6b168503c6 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -122,6 +122,8 @@ class SaveImaged(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`. + NB: image should include channel dimension: [B],C,H,W,[D]. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` @@ -163,6 +165,10 @@ class SaveImaged(MapTransform): it's used for NIfTI format only. save_batch: whether the import image is a batch data, default to `False`. usually pre-transforms run for channel first data, while post-transforms run for batch data. + squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and + then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + image will always be saved as (H,W,D,C). """ @@ -180,6 +186,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, save_batch: bool = False, + squeeze_end_dims: bool = True, ) -> None: super().__init__(keys) self.meta_key_postfix = meta_key_postfix @@ -194,6 +201,7 @@ def __init__( dtype=dtype, output_dtype=output_dtype, save_batch=save_batch, + squeeze_end_dims=squeeze_end_dims, ) def __call__(self, data): diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index 2e2bfd4254..f48374a61c 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -17,6 +17,7 @@ import torch from monai.data import NiftiSaver +from monai.transforms import LoadImage class TestNiftiSaver(unittest.TestCase): @@ -72,6 +73,29 @@ def test_saved_3d_resize_content(self): filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + def test_squeeze_end_dims(self): + with tempfile.TemporaryDirectory() as tempdir: + + for squeeze_end_dims in [False, True]: + + saver = NiftiSaver( + output_dir=tempdir, + output_postfix="", + output_ext=".nii.gz", + dtype=np.float32, + squeeze_end_dims=squeeze_end_dims, + ) + + fname = "testfile_squeeze" + meta_data = {"filename_or_obj": fname} + + # 2d image w channel + saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) + + im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) + self.assertTrue(meta["dim"][0] == im.ndim) + if __name__ == "__main__": unittest.main()