From abac45445563c7ca8e6d123767a5c13547e6a101 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 22 Feb 2021 19:58:27 +0800 Subject: [PATCH 1/7] [DLMED] add SaveImage and SaveImaged transforms Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 +++ monai/data/nifti_saver.py | 3 +- monai/data/nifti_writer.py | 3 +- monai/data/png_writer.py | 6 +- monai/handlers/segmentation_saver.py | 49 +++++------- monai/transforms/__init__.py | 11 ++- monai/transforms/io/array.py | 101 +++++++++++++++++++++++- monai/transforms/io/dictionary.py | 91 ++++++++++++++++++++- tests/min_tests.py | 2 + tests/test_save_image.py | 112 ++++++++++++++++++++++++++ tests/test_save_imaged.py | 114 +++++++++++++++++++++++++++ 11 files changed, 465 insertions(+), 39 deletions(-) create mode 100644 tests/test_save_image.py create mode 100644 tests/test_save_imaged.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 90d960a6b9..85c64c3971 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -237,6 +237,12 @@ IO :members: :special-members: __call__ +`SaveImage` +""""""""""" +.. autoclass:: SaveImage + :members: + :special-members: __call__ + Post-processing ^^^^^^^^^^^^^^^ @@ -702,6 +708,12 @@ IO (Dict) :members: :special-members: __call__ +`SaveImaged` +"""""""""""" +.. autoclass:: SaveImaged + :members: + :special-members: __call__ + Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index f4781f82fd..1036f620d1 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -56,8 +56,7 @@ def __init__( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. + If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ self.output_dir = output_dir diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 6837ebeb90..cdcd28f2fe 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -85,8 +85,7 @@ def write_nifti( align_corners: Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. + If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ if not isinstance(data, np.ndarray): diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index d7baa6ea79..8cee0dd1c5 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -13,7 +13,7 @@ import numpy as np -from monai.transforms import Resize +from monai.transforms.spatial.array import Resize from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import Image, _ = optional_import("PIL", name="Image") @@ -76,6 +76,10 @@ def write_png( else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") + # PNG data must be int number + if data.dtype not in (np.uint8, np.uint16): + data = data.astype(np.uint8) + img = Image.fromarray(data) img.save(file_name, "PNG") return diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index c712ce9a9e..14ce023f71 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -14,7 +14,7 @@ import numpy as np -from monai.data import NiftiSaver, PNGSaver +from monai.transforms import SaveImage from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") @@ -47,9 +47,11 @@ def __init__( """ Args: output_dir: output image directory. - output_postfix: a string appended to all output file names. - output_ext: output file extension name. - resample: whether to resample before saving the data array. + output_postfix: a string appended to all output file names, default to `seg`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - NIfTI files {``"bilinear"``, ``"nearest"``} @@ -71,8 +73,8 @@ def __init__( [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. It's used for PNG format only. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``, it's used for Nifti format only. + If None, use the data type of input data. + It's used for Nifti format only. output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only. batch_transform: a callable that is used to transform the ignite.engine.batch into expected format to extract the meta_data dictionary. @@ -83,27 +85,18 @@ def __init__( name: identifier of logging.logger to use, defaulting to `engine.logger`. """ - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in (".nii.gz", ".nii"): - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - ) + self._saver = SaveImage( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=mode, + padding_mode=padding_mode, + scale=scale, + dtype=dtype, + output_dtype=output_dtype, + save_batch=True, + ) self.batch_transform = batch_transform self.output_transform = output_transform @@ -130,5 +123,5 @@ def __call__(self, engine: Engine) -> None: """ meta_data = self.batch_transform(engine.state.batch) engine_output = self.output_transform(engine.state.output) - self.saver.save_batch(engine_output, meta_data) + self._saver(engine_output, meta_data) self.logger.info("saved all the model outputs into files.") diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9eaedd6b15..45c96fb224 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,8 +138,15 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .io.array import LoadImage -from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict +from .io.array import LoadImage, SaveImage +from .io.dictionary import ( + LoadImaged, + LoadImageD, + LoadImageDict, + SaveImaged, + SaveImageD, + SaveImageDict, +) from .post.array import ( Activations, AsDiscrete, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 3b359cc460..151791e610 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,18 +13,27 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import List, Optional, Sequence, Union +from typing import List, Dict, Optional, Sequence, Union import numpy as np +import torch +from monai.data.nifti_saver import NiftiSaver +from monai.data.png_saver import PNGSaver from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms.compose import Transform -from monai.utils import ensure_tuple, optional_import +from monai.utils import ( + ensure_tuple, + optional_import, + GridSampleMode, + GridSamplePadMode, + InterpolateMode, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") -__all__ = ["LoadImage"] +__all__ = ["LoadImage", "SaveImage"] class LoadImage(Transform): @@ -127,3 +136,89 @@ def __call__( return img_array meta_data["filename_or_obj"] = ensure_tuple(filename)[0] return img_array, meta_data + + +class SaveImage(Transform): + """ + Save transformed data into files, support NIfTI and PNG formats. + It can work for both numpy array and PyTorch Tensor in both pre-transform chain + and post transform chain. + + Args: + output_dir: output image directory. + output_postfix: a string appended to all output file names, default to `trans`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. + mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. + + - NIfTI files {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + + - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files + This option is ignored. + + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + it's used for PNG format only. + dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. + if None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + it's used for NIfTI format only. + output_dtype: data type for saving data. Defaults to ``np.float32``. + it's used for NIfTI format only. + save_batch: whether the import image is a batch data, default to `False`. + usually pre-transforms run for channel first data, while post-transforms run for batch data. + + """ + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "trans", + output_ext: str = ".nii.gz", + resample: bool = True, + mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + scale: Optional[int] = None, + dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, + save_batch: bool = False, + ) -> None: + self.saver: Union[NiftiSaver, PNGSaver] + if output_ext in (".nii.gz", ".nii"): + self.saver = NiftiSaver( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=GridSampleMode(mode), + padding_mode=padding_mode, + dtype=dtype, + output_dtype=output_dtype, + ) + elif output_ext == ".png": + self.saver = PNGSaver( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=InterpolateMode(mode), + scale=scale, + ) + self.save_batch = save_batch + + def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): + if self.save_batch: + self.saver.save_batch(img, meta_data) + else: + self.saver.save(img, meta_data) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 62ac4c8562..bd20e3c75f 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -22,12 +22,16 @@ from monai.config import KeysCollection from monai.data.image_reader import ImageReader from monai.transforms.compose import MapTransform -from monai.transforms.io.array import LoadImage +from monai.transforms.io.array import LoadImage, SaveImage +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode __all__ = [ "LoadImaged", "LoadImageD", "LoadImageDict", + "SaveImaged", + "SaveImageD", + "SaveImageDict", ] @@ -106,4 +110,89 @@ def __call__(self, data, reader: Optional[ImageReader] = None): return d +class SaveImaged(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`, + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`. + So need the key to extract metadata to save images, default is `meta_dict`. + The meta data is a dictionary object, if no corresponding metadata, set to `None`. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + output_dir: output image directory. + output_postfix: a string appended to all output file names, default to `trans`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. + mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. + + - NIfTI files {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + + - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - PNG files + This option is ignored. + + scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling + [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. + it's used for PNG format only. + dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. + if None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + it's used for NIfTI format only. + output_dtype: data type for saving data. Defaults to ``np.float32``. + it's used for NIfTI format only. + save_batch: whether the import image is a batch data, default to `False`. + usually pre-transforms run for channel first data, while post-transforms run for batch data. + + """ + def __init__( + self, + keys: KeysCollection, + meta_key_postfix: str = "meta_dict", + output_dir: str = "./", + output_postfix: str = "trans", + output_ext: str = ".nii.gz", + resample: bool = True, + mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + scale: Optional[int] = None, + dtype: Optional[np.dtype] = np.float64, + output_dtype: Optional[np.dtype] = np.float32, + save_batch: bool = False, + ) -> None: + super().__init__(keys) + self.meta_key_postfix = meta_key_postfix + self._saver = SaveImage( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + resample=resample, + mode=mode, + padding_mode=padding_mode, + scale=scale, + dtype=dtype, + output_dtype=output_dtype, + save_batch=save_batch, + ) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None + self._saver(img=d[key], meta_data=meta_data) + return d + + LoadImageD = LoadImageDict = LoadImaged +SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/min_tests.py b/tests/min_tests.py index 665ead6cc6..e0a274b789 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -103,6 +103,8 @@ def run_testsuit(): "test_handler_metrics_saver", "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", + "test_save_image", + "test_save_imaged", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_save_image.py b/tests/test_save_image.py new file mode 100644 index 0000000000..141960e09b --- /dev/null +++ b/tests/test_save_image.py @@ -0,0 +1,112 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SaveImage + +TEST_CASE_0 = [ + torch.randint(0, 255, (8, 1, 2, 3, 4)), + {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + ".nii.gz", + False, + True, +] + +TEST_CASE_1 = [ + torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, + ".png", + False, + True, +] + +TEST_CASE_2 = [ + np.random.randint(0, 255, (8, 1, 2, 3, 4)), + {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + ".nii.gz", + False, + True, +] + +TEST_CASE_3 = [ + torch.randint(0, 255, (8, 1, 2, 2)), + { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + }, + ".nii.gz", + True, + True, +] + +TEST_CASE_4 = [ + torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + { + "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + }, + ".png", + True, + True, +] + +TEST_CASE_5 = [ + torch.randint(0, 255, (1, 2, 3, 4)), + {"filename_or_obj": "testfile0.nii.gz"}, + ".nii.gz", + False, + False, +] + +TEST_CASE_6 = [ + torch.randint(0, 255, (1, 2, 3, 4)), + None, + ".nii.gz", + False, + False, +] + + +class TestSaveImage(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_saved_content(self, test_data, meta_data, output_ext, resample, save_batch): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + save_batch=save_batch, + ) + trans(test_data, meta_data) + + if save_batch: + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + else: + if meta_data is not None: + filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) + else: + filepath = os.path.join("0", "0" + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py new file mode 100644 index 0000000000..a6ebfe0d8d --- /dev/null +++ b/tests/test_save_imaged.py @@ -0,0 +1,114 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import SaveImaged + +TEST_CASE_0 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + }, + ".nii.gz", + False, + True, +] + +TEST_CASE_1 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, + }, + ".png", + False, + True, +] + +TEST_CASE_2 = [ + { + "img": np.random.randint(0, 255, (8, 1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, + }, + ".nii.gz", + False, + True, +] + +TEST_CASE_3 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 2)), + "img_meta_dict": { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + }, + }, + ".nii.gz", + True, + True, +] + +TEST_CASE_4 = [ + { + "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), + "img_meta_dict": { + "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + }, + }, + ".png", + True, + True, +] + +TEST_CASE_5 = [ + { + "img": torch.randint(0, 255, (1, 2, 3, 4)), + "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, + }, + ".nii.gz", + False, + False, +] + + +class TestSaveImaged(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_saved_content(self, test_data, output_ext, resample, save_batch): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImaged( + keys="img", + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + save_batch=save_batch, + ) + trans(test_data) + + if save_batch: + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + else: + filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main() From 6795d65db7fdb0f7a48275796cdcdc5e0586b31d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 22 Feb 2021 20:11:02 +0800 Subject: [PATCH 2/7] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/handlers/segmentation_saver.py | 10 +++++----- monai/transforms/io/array.py | 4 ++-- monai/transforms/io/dictionary.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 96c1d073d2..a46918b893 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -48,11 +48,11 @@ def __init__( """ Args: output_dir: output image directory. - output_postfix: a string appended to all output file names, default to `seg`. - output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. + output_postfix: a string appended to all output file names, default to `seg`. + output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. + resample: whether to resample before saving the data array. + if saving PNG format image, based on the `spatial_shape` from metadata. + if saving NIfTI format image, based on the `original_affine` from metadata. mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - NIfTI files {``"bilinear"``, ``"nearest"``} diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4926a37ad1..9d07cb95a8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -192,8 +192,8 @@ def __init__( mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, - dtype: Optional[np.dtype] = np.float64, - output_dtype: Optional[np.dtype] = np.float32, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, save_batch: bool = False, ) -> None: self.saver: Union[NiftiSaver, PNGSaver] diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 13ffacea9c..33d056bafc 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -167,8 +167,8 @@ def __init__( mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, - dtype: Optional[np.dtype] = np.float64, - output_dtype: Optional[np.dtype] = np.float32, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, save_batch: bool = False, ) -> None: super().__init__(keys) From 09e20c4553f3841940a6f849d385ba35c9db1486 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 22 Feb 2021 20:12:21 +0800 Subject: [PATCH 3/7] [DLMED] fix flake8 issues Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 9 +-------- monai/transforms/io/array.py | 16 ++++++---------- monai/transforms/io/dictionary.py | 1 + 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 0047b460b4..6f7c2a4f61 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -139,14 +139,7 @@ ThresholdIntensityDict, ) from .io.array import LoadImage, SaveImage -from .io.dictionary import ( - LoadImaged, - LoadImageD, - LoadImageDict, - SaveImaged, - SaveImageD, - SaveImageDict, -) +from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( Activations, AsDiscrete, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 9d07cb95a8..10aa8303b5 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -13,24 +13,19 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import List, Dict, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import numpy as np import torch from monai.config import DtypeLike -from monai.utils import ImageMetaKey as Key +from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver -from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms.compose import Transform -from monai.utils import ( - ensure_tuple, - optional_import, - GridSampleMode, - GridSamplePadMode, - InterpolateMode, -) +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils import ImageMetaKey as Key +from monai.utils import InterpolateMode, ensure_tuple, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -183,6 +178,7 @@ class SaveImage(Transform): usually pre-transforms run for channel first data, while post-transforms run for batch data. """ + def __init__( self, output_dir: str = "./", diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 33d056bafc..bcbcaee0ce 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -156,6 +156,7 @@ class SaveImaged(MapTransform): usually pre-transforms run for channel first data, while post-transforms run for batch data. """ + def __init__( self, keys: KeysCollection, From 0476661676fc107ba863a65e152da35a2ae4adc3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 22 Feb 2021 20:42:44 +0800 Subject: [PATCH 4/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/png_writer.py | 2 +- monai/transforms/io/array.py | 3 +++ monai/transforms/io/dictionary.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index c533266116..752f45f612 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -77,7 +77,7 @@ def write_png( raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): + if not isinstance(data.dtype, (np.uint8, np.uint16)): data = data.astype(np.uint8) img = Image.fromarray(data) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 10aa8303b5..9c14f7a689 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -213,6 +213,9 @@ def __init__( mode=InterpolateMode(mode), scale=scale, ) + else: + raise ValueError(f"unsupported output extension: {output_ext}.") + self.save_batch = save_batch def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index bcbcaee0ce..d3220aa682 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -112,7 +112,8 @@ def __call__(self, data, reader: Optional[ImageReader] = None): class SaveImaged(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`, + Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`. + Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` From 1c5199c12d906183495ca78c32ce6552d843451b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 23 Feb 2021 00:45:35 +0800 Subject: [PATCH 5/7] [DLMED] fix CI tests Signed-off-by: Nic Ma --- monai/data/png_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 752f45f612..54f5c317e9 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -77,7 +77,7 @@ def write_png( raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if not isinstance(data.dtype, (np.uint8, np.uint16)): + if data.dtype != np.uint8 and data.dtype != np.uint16: data = data.astype(np.uint8) img = Image.fromarray(data) From 6561fab6495f4c4dfdd89f353c42842062aa7a58 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 23 Feb 2021 07:47:08 +0800 Subject: [PATCH 6/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/png_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 54f5c317e9..c533266116 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -77,7 +77,7 @@ def write_png( raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype != np.uint8 and data.dtype != np.uint16: + if data.dtype not in (np.uint8, np.uint16): data = data.astype(np.uint8) img = Image.fromarray(data) From abcd851479159c6c6ed1099d9c1c5113946c7fc4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 23 Feb 2021 08:10:57 +0800 Subject: [PATCH 7/7] [DLMED] add ignore Signed-off-by: Nic Ma --- monai/data/png_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index c533266116..9ce01ed97f 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -77,7 +77,7 @@ def write_png( raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): + if data.dtype not in (np.uint8, np.uint16): # type: ignore data = data.astype(np.uint8) img = Image.fromarray(data)