Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions monai/data/readers/niftireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,32 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
Returns:
The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti
header in dict format otherwise

Note:
header['original_affine'] stores the original affine loaded from `filename_or_obj`.
header['affine'] stores the affine after the optional `as_closest_canonical` transform.
"""

img = nib.load(filename_or_obj)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj
header['original_affine'] = img.affine
header['affine'] = img.affine
header['as_closest_canonical'] = as_closest_canonical

if as_closest_canonical:
img = nib.as_closest_canonical(img)
header['affine'] = img.affine

if dtype is not None:
dat = img.get_fdata(dtype=dtype)
else:
dat = np.asanyarray(img.dataobj)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj

if image_only:
return dat
else:
return dat, header
return dat, header


@export("monai.data.readers")
Expand All @@ -77,7 +84,7 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None, i

self.image_files = image_files
self.seg_files = seg_files
self.transform = transform
self.transform = transform
self.seg_transform = seg_transform
self.image_only = image_only

Expand Down
19 changes: 11 additions & 8 deletions monai/data/writers/niftiwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,29 @@
import nibabel as nib


def write_nifti(data, affine, file_name, revert_canonical, dtype="float32"):
def write_nifti(data, affine, file_name, target_affine=None, dtype="float32"):
"""Write numpy data into nifti files to disk.

Args:
data (numpy.ndarray): input data to write to file.
affine (numpy.ndarray): affine information for the data.
file_name (string): expected file name that saved on disk.
revert_canonical (bool): whether to revert canonical.
target_affine (numpy.ndarray, optional):
before saving the (data, affine), transform the data into the orientation defined by `target_affine`.
dtype (np.dtype, optional): convert the image to save to this data type.

"""
assert isinstance(data, np.ndarray), 'input data must be numpy array.'
if affine is None:
affine = np.eye(4)

if revert_canonical:
codes = nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(np.linalg.inv(affine)))
reverted_results = nib.orientations.apply_orientation(np.squeeze(data), codes)
results_img = nib.Nifti1Image(reverted_results.astype(dtype), affine)
if target_affine is None:
results_img = nib.Nifti1Image(data.astype(dtype), affine)
else:
results_img = nib.Nifti1Image(np.squeeze(data).astype(dtype), np.squeeze(affine))
start_ornt = nib.orientations.io_orientation(affine)
target_ornt = nib.orientations.io_orientation(target_affine)
ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)

reverted_results = nib.orientations.apply_orientation(data, ornt_transform)
results_img = nib.Nifti1Image(reverted_results.astype(dtype), target_affine)

nib.save(results_img, file_name)
65 changes: 65 additions & 0 deletions tests/test_nifti_rw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import nibabel as nib
import numpy as np
from parameterized import parameterized

from monai.data.readers.niftireader import load_nifti
from monai.data.writers.niftiwriter import write_nifti

from .utils import make_nifti_image

TEST_IMAGE = np.zeros((1, 2, 3))
TEST_AFFINE = np.array([[-5.3, 0., 0., 102.01], [0., 0.52, 2.17, -7.50], [-0., 1.98, -0.26, -23.12], [0., 0., 0., 1.]])

TEST_CASE_1 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=True, image_only=False)]
TEST_CASE_2 = [TEST_IMAGE, TEST_AFFINE, (1, 3, 2), dict(as_closest_canonical=True, image_only=True)]
TEST_CASE_3 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=True)]
TEST_CASE_4 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=False)]


class TestNiftiLoadRead(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_orientation(self, array, affine, expected_shape, reader_param):
test_image = make_nifti_image(array, affine)

# read test cases
load_result = load_nifti(test_image, **reader_param)
if isinstance(load_result, tuple):
data_array, header = load_result
else:
data_array = load_result
header = None
if os.path.exists(test_image):
os.remove(test_image)

# write test cases
if header is not None:
write_nifti(data_array, header['affine'], test_image, header['original_affine'])
else:
write_nifti(data_array, affine, test_image)
saved = nib.load(test_image)
saved_affine = saved.affine
saved_shape = saved.get_fdata().shape
if os.path.exists(test_image):
os.remove(test_image)

self.assertTrue(np.allclose(saved_affine, affine))
self.assertTrue(np.allclose(saved_shape, expected_shape))


if __name__ == '__main__':
unittest.main()
14 changes: 14 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import os
import unittest
import tempfile
import nibabel as nib

import numpy as np
import torch
Expand All @@ -26,6 +28,18 @@ def skip_if_quick(obj):
return unittest.skipIf(is_quick, "Skipping slow tests")(obj)


def make_nifti_image(array, affine):
"""
Create a temporary nifti image on the disk and return the image name.
User is responsible for deleting the temporary file when done with it.
"""
test_image = nib.Nifti1Image(array, affine)

_, image_name = tempfile.mkstemp(suffix='.nii.gz')
nib.save(test_image, image_name)
return image_name


class NumpyImageTestCase2D(unittest.TestCase):
im_shape = (128, 128)
input_channels = 1
Expand Down