diff --git a/monai/data/readers/niftireader.py b/monai/data/readers/niftireader.py index 6dafd9e52f..1898411694 100644 --- a/monai/data/readers/niftireader.py +++ b/monai/data/readers/niftireader.py @@ -32,25 +32,32 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty Returns: The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti header in dict format otherwise + + Note: + header['original_affine'] stores the original affine loaded from `filename_or_obj`. + header['affine'] stores the affine after the optional `as_closest_canonical` transform. """ img = nib.load(filename_or_obj) + header = dict(img.header) + header['filename_or_obj'] = filename_or_obj + header['original_affine'] = img.affine + header['affine'] = img.affine + header['as_closest_canonical'] = as_closest_canonical + if as_closest_canonical: img = nib.as_closest_canonical(img) + header['affine'] = img.affine if dtype is not None: dat = img.get_fdata(dtype=dtype) else: dat = np.asanyarray(img.dataobj) - header = dict(img.header) - header['filename_or_obj'] = filename_or_obj - if image_only: return dat - else: - return dat, header + return dat, header @export("monai.data.readers") @@ -77,7 +84,7 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None, i self.image_files = image_files self.seg_files = seg_files - self.transform = transform + self.transform = transform self.seg_transform = seg_transform self.image_only = image_only diff --git a/monai/data/writers/niftiwriter.py b/monai/data/writers/niftiwriter.py index 6a97b42c46..44e454d91a 100644 --- a/monai/data/writers/niftiwriter.py +++ b/monai/data/writers/niftiwriter.py @@ -13,26 +13,29 @@ import nibabel as nib -def write_nifti(data, affine, file_name, revert_canonical, dtype="float32"): +def write_nifti(data, affine, file_name, target_affine=None, dtype="float32"): """Write numpy data into nifti files to disk. Args: data (numpy.ndarray): input data to write to file. affine (numpy.ndarray): affine information for the data. file_name (string): expected file name that saved on disk. - revert_canonical (bool): whether to revert canonical. + target_affine (numpy.ndarray, optional): + before saving the (data, affine), transform the data into the orientation defined by `target_affine`. dtype (np.dtype, optional): convert the image to save to this data type. - """ assert isinstance(data, np.ndarray), 'input data must be numpy array.' if affine is None: affine = np.eye(4) - if revert_canonical: - codes = nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(np.linalg.inv(affine))) - reverted_results = nib.orientations.apply_orientation(np.squeeze(data), codes) - results_img = nib.Nifti1Image(reverted_results.astype(dtype), affine) + if target_affine is None: + results_img = nib.Nifti1Image(data.astype(dtype), affine) else: - results_img = nib.Nifti1Image(np.squeeze(data).astype(dtype), np.squeeze(affine)) + start_ornt = nib.orientations.io_orientation(affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + + reverted_results = nib.orientations.apply_orientation(data, ornt_transform) + results_img = nib.Nifti1Image(reverted_results.astype(dtype), target_affine) nib.save(results_img, file_name) diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py new file mode 100644 index 0000000000..40ecb1fba7 --- /dev/null +++ b/tests/test_nifti_rw.py @@ -0,0 +1,65 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.data.readers.niftireader import load_nifti +from monai.data.writers.niftiwriter import write_nifti + +from .utils import make_nifti_image + +TEST_IMAGE = np.zeros((1, 2, 3)) +TEST_AFFINE = np.array([[-5.3, 0., 0., 102.01], [0., 0.52, 2.17, -7.50], [-0., 1.98, -0.26, -23.12], [0., 0., 0., 1.]]) + +TEST_CASE_1 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=True, image_only=False)] +TEST_CASE_2 = [TEST_IMAGE, TEST_AFFINE, (1, 3, 2), dict(as_closest_canonical=True, image_only=True)] +TEST_CASE_3 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=True)] +TEST_CASE_4 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=False)] + + +class TestNiftiLoadRead(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_orientation(self, array, affine, expected_shape, reader_param): + test_image = make_nifti_image(array, affine) + + # read test cases + load_result = load_nifti(test_image, **reader_param) + if isinstance(load_result, tuple): + data_array, header = load_result + else: + data_array = load_result + header = None + if os.path.exists(test_image): + os.remove(test_image) + + # write test cases + if header is not None: + write_nifti(data_array, header['affine'], test_image, header['original_affine']) + else: + write_nifti(data_array, affine, test_image) + saved = nib.load(test_image) + saved_affine = saved.affine + saved_shape = saved.get_fdata().shape + if os.path.exists(test_image): + os.remove(test_image) + + self.assertTrue(np.allclose(saved_affine, affine)) + self.assertTrue(np.allclose(saved_shape, expected_shape)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index c2b6de1580..e356b8b7b3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,6 +11,8 @@ import os import unittest +import tempfile +import nibabel as nib import numpy as np import torch @@ -26,6 +28,18 @@ def skip_if_quick(obj): return unittest.skipIf(is_quick, "Skipping slow tests")(obj) +def make_nifti_image(array, affine): + """ + Create a temporary nifti image on the disk and return the image name. + User is responsible for deleting the temporary file when done with it. + """ + test_image = nib.Nifti1Image(array, affine) + + _, image_name = tempfile.mkstemp(suffix='.nii.gz') + nib.save(test_image, image_name) + return image_name + + class NumpyImageTestCase2D(unittest.TestCase): im_shape = (128, 128) input_channels = 1