diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 15717d090b..9a0e26607b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -155,16 +155,23 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. + reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. + If ``False``, the spatial indexing follows the numpy convention; + otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + This option does not affect the metadata. kwargs: additional args for `itk.imread` API. more details about available args: https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py """ - def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **kwargs): + def __init__( + self, channel_dim: Optional[int] = None, series_name: str = "", reverse_indexing: bool = False, **kwargs + ): super().__init__() self.kwargs = kwargs self.channel_dim = channel_dim self.series_name = series_name + self.reverse_indexing = reverse_indexing def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ @@ -308,19 +315,21 @@ def _get_array_data(self, img): Following PyTorch conventions, the returned array data has contiguous channels, e.g. for an RGB image, all red channel image pixels are contiguous in memory. - The first axis of the returned array is the channel axis. + The last axis of the returned array is the channel axis. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in Args: img: an ITK image object loaded from an image file. """ - channels = img.GetNumberOfComponentsPerPixel() - np_data = itk.array_view_from_image(img).T - if channels == 1: - return np_data - if channels != np_data.shape[0]: - warnings.warn("itk_img.GetNumberOfComponentsPerPixel != numpy data channels") - return np.moveaxis(np_data, 0, -1) # channel last is compatible with `write_nifti` + np_img = itk.array_view_from_image(img, keep_axes=False) + if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images + return np_img if self.reverse_indexing else np_img.T + # handling multi-channel images + return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) @require_pkg(pkg_name="nibabel") diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 849fd9b4e1..5f33e99a65 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -90,12 +90,21 @@ def get_data(self, _obj): {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), ] TEST_CASE_11 = [ {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), + (16, 16, 4), +] + +TEST_CASE_12 = [ + {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, + "tests/testing_data/CT_DICOM", + (16, 16, 4), + (4, 16, 16), ] @@ -138,8 +147,8 @@ def test_itk_reader(self, input_param, filenames, expected_shape): np.testing.assert_allclose(header["original_affine"], np_diag) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) - def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): + @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12]) + def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape): result, header = LoadImage(**input_param)(filenames) self.assertTrue("affine" in header) self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}") @@ -154,8 +163,8 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape): ] ), ) - self.assertTupleEqual(result.shape, expected_shape) self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) + self.assertTupleEqual(result.shape, expected_np_shape) def test_itk_reader_multichannel(self): test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype("uint8") @@ -163,12 +172,14 @@ def test_itk_reader_multichannel(self): filename = os.path.join(tempdir, "test_image.png") itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - result, header = LoadImage(reader=ITKReader())(Path(filename)) + for flag in (False, True): + result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) - self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) - np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0].T) - np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T) - np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T) + self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + test_image = test_image.transpose(1, 0, 2) + np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0]) + np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1]) + np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2]) def test_load_nifti_multichannel(self): test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32) @@ -185,6 +196,8 @@ def test_load_nifti_multichannel(self): self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) + np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3) + def test_load_png(self): spatial_size = (256, 224) test_image = np.random.randint(0, 256, size=spatial_size)