Skip to content
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ IO
:members:
:special-members: __call__

`SaveImage`
"""""""""""
.. autoclass:: SaveImage
:members:
:special-members: __call__

Post-processing
^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -702,6 +708,12 @@ IO (Dict)
:members:
:special-members: __call__

`SaveImaged`
""""""""""""
.. autoclass:: SaveImaged
:members:
:special-members: __call__

Post-processing (Dict)
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
3 changes: 1 addition & 2 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def __init__(
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
self.output_dir = output_dir
Expand Down
3 changes: 1 addition & 2 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def write_nifti(
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
if not isinstance(data, np.ndarray):
Expand Down
6 changes: 5 additions & 1 deletion monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import numpy as np

from monai.transforms import Resize
from monai.transforms.spatial.array import Resize
from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import

Image, _ = optional_import("PIL", name="Image")
Expand Down Expand Up @@ -76,6 +76,10 @@ def write_png(
else:
raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]")

# PNG data must be int number
if data.dtype not in (np.uint8, np.uint16): # type: ignore
data = data.astype(np.uint8)

img = Image.fromarray(data)
img.save(file_name, "PNG")
return
47 changes: 20 additions & 27 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np

from monai.config import DtypeLike
from monai.data import NiftiSaver, PNGSaver
from monai.transforms import SaveImage
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
Expand Down Expand Up @@ -48,9 +48,11 @@ def __init__(
"""
Args:
output_dir: output image directory.
output_postfix: a string appended to all output file names.
output_ext: output file extension name.
output_postfix: a string appended to all output file names, default to `seg`.
output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
resample: whether to resample before saving the data array.
if saving PNG format image, based on the `spatial_shape` from metadata.
if saving NIfTI format image, based on the `original_affine` from metadata.
mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

- NIfTI files {``"bilinear"``, ``"nearest"``}
Expand All @@ -72,8 +74,8 @@ def __init__(
[0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
It's used for PNG format only.
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``, it's used for Nifti format only.
If None, use the data type of input data.
It's used for Nifti format only.
output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only.
batch_transform: a callable that is used to transform the
ignite.engine.batch into expected format to extract the meta_data dictionary.
Expand All @@ -84,27 +86,18 @@ def __init__(
name: identifier of logging.logger to use, defaulting to `engine.logger`.

"""
self.saver: Union[NiftiSaver, PNGSaver]
if output_ext in (".nii.gz", ".nii"):
self.saver = NiftiSaver(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=GridSampleMode(mode),
padding_mode=padding_mode,
dtype=dtype,
output_dtype=output_dtype,
)
elif output_ext == ".png":
self.saver = PNGSaver(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=InterpolateMode(mode),
scale=scale,
)
self._saver = SaveImage(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=mode,
padding_mode=padding_mode,
scale=scale,
dtype=dtype,
output_dtype=output_dtype,
save_batch=True,
)
self.batch_transform = batch_transform
self.output_transform = output_transform

Expand All @@ -131,5 +124,5 @@ def __call__(self, engine: Engine) -> None:
"""
meta_data = self.batch_transform(engine.state.batch)
engine_output = self.output_transform(engine.state.output)
self.saver.save_batch(engine_output, meta_data)
self._saver(engine_output, meta_data)
self.logger.info("saved all the model outputs into files.")
4 changes: 2 additions & 2 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@
ThresholdIntensityD,
ThresholdIntensityDict,
)
from .io.array import LoadImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .post.array import (
Activations,
AsDiscrete,
Expand Down
100 changes: 97 additions & 3 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@
https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
"""

from typing import List, Optional, Sequence, Union
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.data.nifti_saver import NiftiSaver
from monai.data.png_saver import PNGSaver
from monai.transforms.compose import Transform
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, optional_import
from monai.utils import InterpolateMode, ensure_tuple, optional_import

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")

__all__ = ["LoadImage"]
__all__ = ["LoadImage", "SaveImage"]


class LoadImage(Transform):
Expand Down Expand Up @@ -129,3 +133,93 @@ def __call__(
return img_array
meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0]
return img_array, meta_data


class SaveImage(Transform):
"""
Save transformed data into files, support NIfTI and PNG formats.
It can work for both numpy array and PyTorch Tensor in both pre-transform chain
and post transform chain.

Args:
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
resample: whether to resample before saving the data array.
if saving PNG format image, based on the `spatial_shape` from metadata.
if saving NIfTI format image, based on the `original_affine` from metadata.
mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

- NIfTI files {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
The interpolation mode.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

- NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- PNG files
This option is ignored.

scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
[0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
it's used for PNG format only.
dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.
if None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
it's used for NIfTI format only.
output_dtype: data type for saving data. Defaults to ``np.float32``.
it's used for NIfTI format only.
save_batch: whether the import image is a batch data, default to `False`.
usually pre-transforms run for channel first data, while post-transforms run for batch data.

"""

def __init__(
self,
output_dir: str = "./",
output_postfix: str = "trans",
output_ext: str = ".nii.gz",
resample: bool = True,
mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
save_batch: bool = False,
) -> None:
self.saver: Union[NiftiSaver, PNGSaver]
if output_ext in (".nii.gz", ".nii"):
self.saver = NiftiSaver(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=GridSampleMode(mode),
padding_mode=padding_mode,
dtype=dtype,
output_dtype=output_dtype,
)
elif output_ext == ".png":
self.saver = PNGSaver(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=InterpolateMode(mode),
scale=scale,
)
else:
raise ValueError(f"unsupported output extension: {output_ext}.")

self.save_batch = save_batch

def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
if self.save_batch:
self.saver.save_batch(img, meta_data)
else:
self.saver.save(img, meta_data)
93 changes: 92 additions & 1 deletion monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
from monai.config import DtypeLike, KeysCollection
from monai.data.image_reader import ImageReader
from monai.transforms.compose import MapTransform
from monai.transforms.io.array import LoadImage
from monai.transforms.io.array import LoadImage, SaveImage
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode

__all__ = [
"LoadImaged",
"LoadImageD",
"LoadImageDict",
"SaveImaged",
"SaveImageD",
"SaveImageDict",
]


Expand Down Expand Up @@ -106,4 +110,91 @@ def __call__(self, data, reader: Optional[ImageReader] = None):
return d


class SaveImaged(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.
So need the key to extract metadata to save images, default is `meta_dict`.
The meta data is a dictionary object, if no corresponding metadata, set to `None`.
For example, for data with key `image`, the metadata by default is in `image_meta_dict`.
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
resample: whether to resample before saving the data array.
if saving PNG format image, based on the `spatial_shape` from metadata.
if saving NIfTI format image, based on the `original_affine` from metadata.
mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

- NIfTI files {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
The interpolation mode.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

- NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- PNG files
This option is ignored.

scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
[0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
it's used for PNG format only.
dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.
if None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
it's used for NIfTI format only.
output_dtype: data type for saving data. Defaults to ``np.float32``.
it's used for NIfTI format only.
save_batch: whether the import image is a batch data, default to `False`.
usually pre-transforms run for channel first data, while post-transforms run for batch data.

"""

def __init__(
self,
keys: KeysCollection,
meta_key_postfix: str = "meta_dict",
output_dir: str = "./",
output_postfix: str = "trans",
output_ext: str = ".nii.gz",
resample: bool = True,
mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
save_batch: bool = False,
) -> None:
super().__init__(keys)
self.meta_key_postfix = meta_key_postfix
self._saver = SaveImage(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=mode,
padding_mode=padding_mode,
scale=scale,
dtype=dtype,
output_dtype=output_dtype,
save_batch=save_batch,
)

def __call__(self, data):
d = dict(data)
for key in self.keys:
meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None
self._saver(img=d[key], meta_data=meta_data)
return d


LoadImageD = LoadImageDict = LoadImaged
SaveImageD = SaveImageDict = SaveImaged
2 changes: 2 additions & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def run_testsuit():
"test_deepgrow_transforms",
"test_deepgrow_interaction",
"test_deepgrow_dataset",
"test_save_image",
"test_save_imaged",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
Loading