Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,23 @@ class ITKReader(ImageReader):

series_name: the name of the DICOM series if there are multiple ones.
used when loading DICOM series.
reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array.
If ``False``, the spatial indexing follows the numpy convention;
otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``.
This option does not affect the metadata.
kwargs: additional args for `itk.imread` API. more details about available args:
https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py

"""

def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **kwargs):
def __init__(
self, channel_dim: Optional[int] = None, series_name: str = "", reverse_indexing: bool = False, **kwargs
):
super().__init__()
self.kwargs = kwargs
self.channel_dim = channel_dim
self.series_name = series_name
self.reverse_indexing = reverse_indexing

def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
"""
Expand Down Expand Up @@ -308,19 +315,21 @@ def _get_array_data(self, img):

Following PyTorch conventions, the returned array data has contiguous channels,
e.g. for an RGB image, all red channel image pixels are contiguous in memory.
The first axis of the returned array is the channel axis.
The last axis of the returned array is the channel axis.

See also:

- https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in

Args:
img: an ITK image object loaded from an image file.

"""
channels = img.GetNumberOfComponentsPerPixel()
np_data = itk.array_view_from_image(img).T
if channels == 1:
return np_data
if channels != np_data.shape[0]:
warnings.warn("itk_img.GetNumberOfComponentsPerPixel != numpy data channels")
return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti`
np_img = itk.array_view_from_image(img, keep_axes=False)
if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images
return np_img if self.reverse_indexing else np_img.T
# handling multi-channel images
return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1)


@require_pkg(pkg_name="nibabel")
Expand Down
29 changes: 21 additions & 8 deletions tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,21 @@ def get_data(self, _obj):
{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(16, 16, 4),
]

TEST_CASE_11 = [
{"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(16, 16, 4),
]

TEST_CASE_12 = [
{"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(4, 16, 16),
]


Expand Down Expand Up @@ -138,8 +147,8 @@ def test_itk_reader(self, input_param, filenames, expected_shape):
np.testing.assert_allclose(header["original_affine"], np_diag)
self.assertTupleEqual(result.shape, expected_shape)

@parameterized.expand([TEST_CASE_10, TEST_CASE_11])
def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape):
@parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12])
def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape):
result, header = LoadImage(**input_param)(filenames)
self.assertTrue("affine" in header)
self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}")
Expand All @@ -154,21 +163,23 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape):
]
),
)
self.assertTupleEqual(result.shape, expected_shape)
self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape)
self.assertTupleEqual(result.shape, expected_np_shape)

def test_itk_reader_multichannel(self):
test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8")
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_image.png")
itk_np_view = itk.image_view_from_array(test_image, is_vector=True)
itk.imwrite(itk_np_view, filename)
result, header = LoadImage(reader=ITKReader())(Path(filename))
for flag in (False, True):
result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename))

self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256))
np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T)
np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T)
np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T)
self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256))
test_image = test_image.transpose(1, 0, 2)
np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0])
np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1])
np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2])

def test_load_nifti_multichannel(self):
test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32)
Expand All @@ -185,6 +196,8 @@ def test_load_nifti_multichannel(self):
self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31))
self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2))

np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3)

def test_load_png(self):
spatial_size = (256, 224)
test_image = np.random.randint(0, 256, size=spatial_size)
Expand Down