Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class NiftiSaver:
Typically, the data can be segmentation predictions, call `save` for single data
or call `save_batch` to save a batch of data together. If no meta data provided,
use index from 0 as the filename prefix.
NB: image should include channel dimension: [B],C,H,W,[D].
"""

def __init__(
Expand All @@ -40,6 +42,7 @@ def __init__(
align_corners: bool = False,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
squeeze_end_dims: bool = True,
) -> None:
"""
Args:
Expand All @@ -60,6 +63,10 @@ def __init__(
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.
squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false,
image will always be saved as (H,W,D,C).
"""
self.output_dir = output_dir
self.output_postfix = output_postfix
Expand All @@ -71,6 +78,7 @@ def __init__(
self.dtype = dtype
self.output_dtype = output_dtype
self._data_index = 0
self.squeeze_end_dims = squeeze_end_dims

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Expand Down Expand Up @@ -111,6 +119,12 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
data = np.expand_dims(data, -1)
# change data to "channel last" format and write to nifti format file
data = np.moveaxis(np.asarray(data), 0, -1)

# if desired, remove trailing singleton dimensions
if self.squeeze_end_dims:
while data.shape[-1] == 1:
data = np.squeeze(data, -1)

write_nifti(
data,
file_name=filename,
Expand Down
8 changes: 8 additions & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ class SaveImage(Transform):
It can work for both numpy array and PyTorch Tensor in both pre-transform chain
and post transform chain.

NB: image should include channel dimension: [B],C,H,W,[D].

Args:
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
Expand Down Expand Up @@ -200,6 +202,10 @@ class SaveImage(Transform):
it's used for NIfTI format only.
save_batch: whether the import image is a batch data, default to `False`.
usually pre-transforms run for channel first data, while post-transforms run for batch data.
squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false,
image will always be saved as (H,W,D,C).

"""

Expand All @@ -215,6 +221,7 @@ def __init__(
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
save_batch: bool = False,
squeeze_end_dims: bool = True,
) -> None:
self.saver: Union[NiftiSaver, PNGSaver]
if output_ext in (".nii.gz", ".nii"):
Expand All @@ -227,6 +234,7 @@ def __init__(
padding_mode=padding_mode,
dtype=dtype,
output_dtype=output_dtype,
squeeze_end_dims=squeeze_end_dims,
)
elif output_ext == ".png":
self.saver = PNGSaver(
Expand Down
8 changes: 8 additions & 0 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class SaveImaged(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`.

NB: image should include channel dimension: [B],C,H,W,[D].

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
Expand Down Expand Up @@ -166,6 +168,10 @@ class SaveImaged(MapTransform):
save_batch: whether the import image is a batch data, default to `False`.
usually pre-transforms run for channel first data, while post-transforms run for batch data.
allow_missing_keys: don't raise exception if key is missing.
squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false,
image will always be saved as (H,W,D,C).

"""

Expand All @@ -184,6 +190,7 @@ def __init__(
output_dtype: DtypeLike = np.float32,
save_batch: bool = False,
allow_missing_keys: bool = False,
squeeze_end_dims: bool = True,
) -> None:
super().__init__(keys, allow_missing_keys)
self.meta_key_postfix = meta_key_postfix
Expand All @@ -198,6 +205,7 @@ def __init__(
dtype=dtype,
output_dtype=output_dtype,
save_batch=save_batch,
squeeze_end_dims=squeeze_end_dims,
)

def __call__(self, data):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch

from monai.data import NiftiSaver
from monai.transforms import LoadImage


class TestNiftiSaver(unittest.TestCase):
Expand Down Expand Up @@ -72,6 +73,29 @@ def test_saved_3d_resize_content(self):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_squeeze_end_dims(self):
with tempfile.TemporaryDirectory() as tempdir:

for squeeze_end_dims in [False, True]:

saver = NiftiSaver(
output_dir=tempdir,
output_postfix="",
output_ext=".nii.gz",
dtype=np.float32,
squeeze_end_dims=squeeze_end_dims,
)

fname = "testfile_squeeze"
meta_data = {"filename_or_obj": fname}

# 2d image w channel
saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data)

im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz"))
self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4)
self.assertTrue(meta["dim"][0] == im.ndim)


if __name__ == "__main__":
unittest.main()