diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 228b73b6c2..5b550f7885 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -237,6 +237,12 @@ IO :members: :special-members: __call__ +`SaveImage` +""""""""""" +.. autoclass:: SaveImage + :members: + :special-members: __call__ + Post-processing ^^^^^^^^^^^^^^^ @@ -702,6 +708,12 @@ IO (Dict) :members: :special-members: __call__ +`SaveImaged` +"""""""""""" +.. autoclass:: SaveImaged + :members: + :special-members: __call__ + Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index e699a0ce9b..01e701b1a6 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -58,8 +58,7 @@ def __init__( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. + If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ self.output_dir = output_dir diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 29dc62cdec..f530482b14 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -86,8 +86,7 @@ def write_nifti( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. + If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ if not isinstance(data, np.ndarray): diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index e6b9f1e8cf..9ce01ed97f 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -13,7 +13,7 @@ import numpy as np -from monai.transforms import Resize +from monai.transforms.spatial.array import Resize from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import Image, _ = optional_import("PIL", name="Image") @@ -76,6 +76,10 @@ def write_png( else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") + # PNG data must be int number + if data.dtype not in (np.uint8, np.uint16): # type: ignore + data = data.astype(np.uint8) + img = Image.fromarray(data) img.save(file_name, "PNG") return diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 8321a49851..a46918b893 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -15,7 +15,7 @@ import numpy as np from monai.config import DtypeLike -from monai.data import NiftiSaver, PNGSaver +from monai.transforms import SaveImage from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") @@ -48,9 +48,11 @@ def __init__( """ Args: output_dir: output image directory. - output_postfix: a string appended to all output file names. - output_ext: output file extension name. + output_postfix: a string appended to all output file names, default to `seg`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - NIfTI files {``"bilinear"``, ``"nearest"``} @@ -72,8 +74,8 @@ def __init__( [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. It's used for PNG format only. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``, it's used for Nifti format only. + If None, use the data type of input data. + It's used for Nifti format only. output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only. batch_transform: a callable that is used to transform the ignite.engine.batch into expected format to extract the meta_data dictionary. @@ -84,27 +86,18 @@ def __init__( name: identifier of logging.logger to use, defaulting to `engine.logger`. """ - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in (".nii.gz", ".nii"): - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - ) + self._saver = SaveImage( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=mode, + padding_mode=padding_mode, + scale=scale, + dtype=dtype, + output_dtype=output_dtype, + save_batch=True, + ) self.batch_transform = batch_transform self.output_transform = output_transform @@ -131,5 +124,5 @@ def __call__(self, engine: Engine) -> None: """ meta_data = self.batch_transform(engine.state.batch) engine_output = self.output_transform(engine.state.output) - self.saver.save_batch(engine_output, meta_data) + self._saver(engine_output, meta_data) self.logger.info("saved all the model outputs into files.") diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 4dc7744755..6f7c2a4f61 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,8 +138,8 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .io.array import LoadImage -from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict +from .io.array import LoadImage, SaveImage +from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( Activations, AsDiscrete, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f57b2dd27a..9c14f7a689 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,20 +13,24 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import numpy as np +import torch from monai.config import DtypeLike from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from monai.data.nifti_saver import NiftiSaver +from monai.data.png_saver import PNGSaver from monai.transforms.compose import Transform +from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import ensure_tuple, optional_import +from monai.utils import InterpolateMode, ensure_tuple, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") -__all__ = ["LoadImage"] +__all__ = ["LoadImage", "SaveImage"] class LoadImage(Transform): @@ -129,3 +133,93 @@ def __call__( return img_array meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] return img_array, meta_data + + +class SaveImage(Transform): + """ + Save transformed data into files, support NIfTI and PNG formats. + It can work for both numpy array and PyTorch Tensor in both pre-transform chain + and post transform chain. + + Args: + output_dir: output image directory. + output_postfix: a string appended to all output file names, default to `trans`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. + mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. + + - NIfTI files {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + + - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files + This option is ignored. + + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + it's used for PNG format only. + dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. + if None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + it's used for NIfTI format only. + output_dtype: data type for saving data. Defaults to ``np.float32``. + it's used for NIfTI format only. + save_batch: whether the import image is a batch data, default to `False`. + usually pre-transforms run for channel first data, while post-transforms run for batch data. + + """ + + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "trans", + output_ext: str = ".nii.gz", + resample: bool = True, + mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + scale: Optional[int] = None, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, + save_batch: bool = False, + ) -> None: + self.saver: Union[NiftiSaver, PNGSaver] + if output_ext in (".nii.gz", ".nii"): + self.saver = NiftiSaver( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=GridSampleMode(mode), + padding_mode=padding_mode, + dtype=dtype, + output_dtype=output_dtype, + ) + elif output_ext == ".png": + self.saver = PNGSaver( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=InterpolateMode(mode), + scale=scale, + ) + else: + raise ValueError(f"unsupported output extension: {output_ext}.") + + self.save_batch = save_batch + + def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): + if self.save_batch: + self.saver.save_batch(img, meta_data) + else: + self.saver.save(img, meta_data) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 40737374cf..d3220aa682 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -22,12 +22,16 @@ from monai.config import DtypeLike, KeysCollection from monai.data.image_reader import ImageReader from monai.transforms.compose import MapTransform -from monai.transforms.io.array import LoadImage +from monai.transforms.io.array import LoadImage, SaveImage +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode __all__ = [ "LoadImaged", "LoadImageD", "LoadImageDict", + "SaveImaged", + "SaveImageD", + "SaveImageDict", ] @@ -106,4 +110,91 @@ def __call__(self, data, reader: Optional[ImageReader] = None): return d +class SaveImaged(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`. + So need the key to extract metadata to save images, default is `meta_dict`. + The meta data is a dictionary object, if no corresponding metadata, set to `None`. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + output_dir: output image directory. + output_postfix: a string appended to all output file names, default to `trans`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. + mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. + + - NIfTI files {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + + - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files + This option is ignored. + + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + it's used for PNG format only. + dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. + if None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + it's used for NIfTI format only. + output_dtype: data type for saving data. Defaults to ``np.float32``. + it's used for NIfTI format only. + save_batch: whether the import image is a batch data, default to `False`. + usually pre-transforms run for channel first data, while post-transforms run for batch data. + + """ + + def __init__( + self, + keys: KeysCollection, + meta_key_postfix: str = "meta_dict", + output_dir: str = "./", + output_postfix: str = "trans", + output_ext: str = ".nii.gz", + resample: bool = True, + mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + scale: Optional[int] = None, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, + save_batch: bool = False, + ) -> None: + super().__init__(keys) + self.meta_key_postfix = meta_key_postfix + self._saver = SaveImage( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=mode, + padding_mode=padding_mode, + scale=scale, + dtype=dtype, + output_dtype=output_dtype, + save_batch=save_batch, + ) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None + self._saver(img=d[key], meta_data=meta_data) + return d + + LoadImageD = LoadImageDict = LoadImaged +SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/min_tests.py b/tests/min_tests.py index a1a1894ed1..999a1aeaa0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -107,6 +107,8 @@ def run_testsuit(): "test_deepgrow_transforms", "test_deepgrow_interaction", "test_deepgrow_dataset", + "test_save_image", + "test_save_imaged", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_save_image.py b/tests/test_save_image.py new file mode 100644 index 0000000000..141960e09b --- /dev/null +++ b/tests/test_save_image.py @@ -0,0 +1,112 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SaveImage + +TEST_CASE_0 = [ + torch.randint(0, 255, (8, 1, 2, 3, 4)), + {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + ".nii.gz", + False, + True, +] + +TEST_CASE_1 = [ + torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, + ".png", + False, + True, +] + +TEST_CASE_2 = [ + np.random.randint(0, 255, (8, 1, 2, 3, 4)), + {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + ".nii.gz", + False, + True, +] + +TEST_CASE_3 = [ + torch.randint(0, 255, (8, 1, 2, 2)), + { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + }, + ".nii.gz", + True, + True, +] + +TEST_CASE_4 = [ + torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + { + "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + }, + ".png", + True, + True, +] + +TEST_CASE_5 = [ + torch.randint(0, 255, (1, 2, 3, 4)), + {"filename_or_obj": "testfile0.nii.gz"}, + ".nii.gz", + False, + False, +] + +TEST_CASE_6 = [ + torch.randint(0, 255, (1, 2, 3, 4)), + None, + ".nii.gz", + False, + False, +] + + +class TestSaveImage(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_saved_content(self, test_data, meta_data, output_ext, resample, save_batch): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + save_batch=save_batch, + ) + trans(test_data, meta_data) + + if save_batch: + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + else: + if meta_data is not None: + filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) + else: + filepath = os.path.join("0", "0" + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py new file mode 100644 index 0000000000..a6ebfe0d8d --- /dev/null +++ b/tests/test_save_imaged.py @@ -0,0 +1,114 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SaveImaged + +TEST_CASE_0 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + }, + ".nii.gz", + False, + True, +] + +TEST_CASE_1 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, + }, + ".png", + False, + True, +] + +TEST_CASE_2 = [ + { + "img": np.random.randint(0, 255, (8, 1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + }, + ".nii.gz", + False, + True, +] + +TEST_CASE_3 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 2)), + "img_meta_dict": { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + }, + }, + ".nii.gz", + True, + True, +] + +TEST_CASE_4 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + "img_meta_dict": { + "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + }, + }, + ".png", + True, + True, +] + +TEST_CASE_5 = [ + { + "img": torch.randint(0, 255, (1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, + }, + ".nii.gz", + False, + False, +] + + +class TestSaveImaged(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_saved_content(self, test_data, output_ext, resample, save_batch): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImaged( + keys="img", + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + save_batch=save_batch, + ) + trans(test_data) + + if save_batch: + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + else: + filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main()