diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 00d8cb9053..dd10176de9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -437,6 +437,12 @@ Utility :members: :special-members: __call__ +`EnsureChannelFirst` +"""""""""""""""""""" +.. autoclass:: EnsureChannelFirst + :members: + :special-members: __call__ + `RepeatChannel` """"""""""""""" .. autoclass:: RepeatChannel @@ -890,6 +896,12 @@ Utility (Dict) :members: :special-members: __call__ +`EnsureChannelFirstd` +""""""""""""""""""""" +.. autoclass:: EnsureChannelFirstd + :members: + :special-members: __call__ + `RepeatChanneld` """""""""""""""" .. autoclass:: RepeatChanneld diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e458833979..dfbdaf5b41 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -109,6 +109,17 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): ) +def _stack_images(image_list: List, meta_dict: Dict): + if len(image_list) > 1: + if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): + raise RuntimeError("can not read a list of images which already have channel dimension.") + meta_dict["original_channel_dim"] = 0 + img_array = np.stack(image_list, axis=0) + else: + img_array = image_list[0] + return img_array + + class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -200,11 +211,12 @@ def get_data(self, img): header["original_affine"] = self._get_affine(i) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -265,6 +277,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: img: a ITK image object loaded from a image file. """ + # the img data should have no channel dim or the last dim is channel shape = list(itk.size(img)) shape.reverse() return np.asarray(shape) @@ -371,11 +384,12 @@ def get_data(self, img): i = nib.as_closest_canonical(i) header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -408,6 +422,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: """ ndim = img.header["dim"][0] spatial_rank = min(ndim, 3) + # the img data should have no channel dim or the last dim is channel return np.asarray(img.header["dim"][1 : spatial_rank + 1]) def _get_array_data(self, img) -> np.ndarray: @@ -504,12 +519,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): + # can not detect the channel dim of numpy array, use all the dims as spatial_shape header["spatial_shape"] = i.shape img_array.append(i) _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta class PILReader(ImageReader): @@ -582,11 +597,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = self._get_meta_dict(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(np.asarray(i)) + data = np.asarray(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -608,4 +624,5 @@ def _get_spatial_shape(self, img) -> np.ndarray: Args: img: a PIL Image object loaded from a image file. """ + # the img data should have no channel dim or the last dim is channel return np.asarray((img.width, img.height)) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cd5b195bd3..a8d647b657 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -250,6 +250,7 @@ CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask, @@ -296,6 +297,9 @@ DeleteItemsd, DeleteItemsD, DeleteItemsDict, + EnsureChannelFirstd, + EnsureChannelFirstD, + EnsureChannelFirstDict, FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 24d2feb781..62daf9309c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -15,7 +15,7 @@ import logging import time -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -39,6 +39,7 @@ "AsChannelFirst", "AsChannelLast", "AddChannel", + "EnsureChannelFirst", "RepeatChannel", "RemoveRepeatedChannel", "SplitChannel", @@ -149,6 +150,32 @@ def __call__(self, img: NdarrayTensor): return img[None] +class EnsureChannelFirst(Transform): + """ + Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. + It extracts the `original_channel_dim` info from provided meta_data dictionary. + Typical values of `original_channel_dim` can be: "no_channel", 0, -1. + Convert the data to `channel_first` based on the `original_channel_dim` information. + + """ + + def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None): + """ + Apply the transform to `img`. + """ + if not isinstance(meta_dict, dict): + raise ValueError("meta_dict must be a dictionay data.") + + channel_dim = meta_dict.get("original_channel_dim", None) + + if channel_dim is None: + raise ValueError("meta_dict must contain `original_channel_dim` information.") + elif channel_dim == "no_channel": + return AddChannel()(img) + else: + return AsChannelFirst(channel_dim=channel_dim)(img) + + class RepeatChannel(Transform): """ Repeat channel data to construct expected input shape for models. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9d923d0fd..4a0808fdbb 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -31,6 +31,7 @@ CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask, @@ -60,6 +61,7 @@ "AsChannelFirstd", "AsChannelLastd", "AddChanneld", + "EnsureChannelFirstd", "RepeatChanneld", "RemoveRepeatedChanneld", "SplitChanneld", @@ -89,6 +91,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", + "EnsureChannelFirstD", + "EnsureChannelFirstDict", "RandLambdaD", "RandLambdaDict", "RepeatChannelD", @@ -217,6 +221,32 @@ def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, Nda return d +class EnsureChannelFirstd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. + """ + + def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`. + So need the key to extract metadata for channel dim information, default is `meta_dict`. + For example, for data with key `image`, metadata by default is in `image_meta_dict`. + + """ + super().__init__(keys) + self.adjuster = EnsureChannelFirst() + self.meta_key_postfix = meta_key_postfix + + def __call__(self, data) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"]) + return d + + class RepeatChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. @@ -894,6 +924,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd AddChannelD = AddChannelDict = AddChanneld +EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld diff --git a/tests/min_tests.py b/tests/min_tests.py index 999a1aeaa0..83c1ceea9f 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -109,6 +109,8 @@ def run_testsuit(): "test_deepgrow_dataset", "test_save_image", "test_save_imaged", + "test_ensure_channel_first", + "test_ensure_channel_firstd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py new file mode 100644 index 0000000000..ff656f2e24 --- /dev/null +++ b/tests/test_ensure_channel_first.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import itk +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.data import ITKReader +from monai.transforms import EnsureChannelFirst, LoadImage + +TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_6 = [ + {"reader": ITKReader(), "image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_7 = [ + {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, + "tests/testing_data/CT_DICOM", + None, +] + + +class TestEnsureChannelFirst(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], len(filenames)) + + @parameterized.expand([TEST_CASE_7]) + def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], 1) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result, header = LoadImage(image_only=False)(filename) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py new file mode 100644 index 0000000000..a5298f4453 --- /dev/null +++ b/tests/test_ensure_channel_firstd.py @@ -0,0 +1,62 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.transforms import EnsureChannelFirstd, LoadImaged + +TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"keys": "img"}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"keys": "img"}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + + +class TestEnsureChannelFirstd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadImaged(**input_param)({"img": filenames}) + result = EnsureChannelFirstd(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result = LoadImaged(keys="img")({"img": filename}) + result = EnsureChannelFirstd(keys="img")(result) + self.assertEqual(result["img"].shape[0], 3) + + +if __name__ == "__main__": + unittest.main()