From 861ca823723f5c4fe404c2e1caed53b73ad89fb6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 17:39:09 +0800 Subject: [PATCH 1/9] [DLMED] add AutoAdjustChannel transform Signed-off-by: Nic Ma --- monai/data/image_reader.py | 39 ++++++++++---- monai/transforms/__init__.py | 1 + monai/transforms/utility/array.py | 24 ++++++++- tests/test_auto_adjust_channel.py | 86 +++++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 12 deletions(-) create mode 100644 tests/test_auto_adjust_channel.py diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e458833979..b63d42416a 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -109,6 +109,17 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): ) +def _stack_images(image_list: List, meta_dict: Dict): + if len(image_list) > 1: + if meta_dict.get("original_channel_dim", None) is not None: + raise RuntimeError("can not read a list of images which already have channel dimension.") + meta_dict["original_channel_dim"] = 0 + img_array = np.stack(image_list, axis=0) + else: + img_array = image_list[0] + return img_array + + class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -200,11 +211,12 @@ def get_data(self, img): header["original_affine"] = self._get_affine(i) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -265,6 +277,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: img: a ITK image object loaded from a image file. """ + # the img data should have no channel dim or the last dim is channel shape = list(itk.size(img)) shape.reverse() return np.asarray(shape) @@ -371,11 +384,12 @@ def get_data(self, img): i = nib.as_closest_canonical(i) header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -408,6 +422,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: """ ndim = img.header["dim"][0] spatial_rank = min(ndim, 3) + # the img data should have no channel dim or the last dim is channel return np.asarray(img.header["dim"][1 : spatial_rank + 1]) def _get_array_data(self, img) -> np.ndarray: @@ -504,12 +519,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): + # can not detect the channel dim of numpy array, use all the dims as spatial_shape header["spatial_shape"] = i.shape img_array.append(i) _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta class PILReader(ImageReader): @@ -582,11 +597,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = self._get_meta_dict(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(np.asarray(i)) + data = np.asarray(i) + img_array.append(data) + header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -608,4 +624,5 @@ def _get_spatial_shape(self, img) -> np.ndarray: Args: img: a PIL Image object loaded from a image file. """ + # the img data should have no channel dim or the last dim is channel return np.asarray((img.width, img.height)) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cd5b195bd3..1e7277725e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -247,6 +247,7 @@ AddExtremePointsChannel, AsChannelFirst, AsChannelLast, + AutoAdjustChannel, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 24d2feb781..d3c70f66b1 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -15,7 +15,7 @@ import logging import time -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union +from typing import Dict, TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -39,6 +39,7 @@ "AsChannelFirst", "AsChannelLast", "AddChannel", + "AutoAdjustChannel", "RepeatChannel", "RemoveRepeatedChannel", "SplitChannel", @@ -149,6 +150,27 @@ def __call__(self, img: NdarrayTensor): return img[None] +class AutoAdjustChannel(Transform): + """ + Automatically adjust the channel dimension of input data. + It extract the `original_channel_dim` info from provided meta_data dictionary. + Convert the data to channel_first based on the `original_channel_dim` information. + + """ + def __call__(self, img: NdarrayTensor, meta_dict: Dict): + """ + Apply the transform to `img`. + """ + if "original_channel_dim" not in meta_dict: + raise ValueError("meta_dict must contain `original_channel_dim` information.") + channel_dim = meta_dict["original_channel_dim"] + + if channel_dim is None: + return AddChannel()(img) + else: + return AsChannelFirst(channel_dim=channel_dim)(img) + + class RepeatChannel(Transform): """ Repeat channel data to construct expected input shape for models. diff --git a/tests/test_auto_adjust_channel.py b/tests/test_auto_adjust_channel.py new file mode 100644 index 0000000000..fc7ab9baa6 --- /dev/null +++ b/tests/test_auto_adjust_channel.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import itk +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.data import ITKReader +from monai.transforms import AutoAdjustChannel, LoadImage + +TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_6 = [ + {"reader": ITKReader(), "image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_7 = [ + {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, + "tests/testing_data/CT_DICOM", + None, +] + + +class TestAutoAdjustChannel(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result, header = LoadImage(**input_param)(filenames) + result = AutoAdjustChannel()(result, header) + self.assertEqual(result.shape[0], len(filenames)) + + @parameterized.expand([TEST_CASE_7]) + def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): + result, header = LoadImage(**input_param)(filenames) + result = AutoAdjustChannel()(result, header) + self.assertEqual(result.shape[0], 1) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result, header = LoadImage(image_only=False)(filename) + result = AutoAdjustChannel()(result, header) + self.assertEqual(result.shape[0], 3) + + +if __name__ == "__main__": + unittest.main() From a98d124a093af91ffbc53a0811e78dc91a2544e4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 18:25:39 +0800 Subject: [PATCH 2/9] [DLMED] add dict version transform Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 +++++ monai/data/image_reader.py | 8 ++-- monai/transforms/__init__.py | 3 ++ monai/transforms/utility/array.py | 8 ++-- monai/transforms/utility/dictionary.py | 31 +++++++++++++ tests/min_tests.py | 2 + tests/test_auto_adjust_channeld.py | 62 ++++++++++++++++++++++++++ 7 files changed, 119 insertions(+), 7 deletions(-) create mode 100644 tests/test_auto_adjust_channeld.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 00d8cb9053..4dfbd44e0c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -437,6 +437,12 @@ Utility :members: :special-members: __call__ +`AutoAdjustChannel` +""""""""""""""""""" +.. autoclass:: AutoAdjustChannel + :members: + :special-members: __call__ + `RepeatChannel` """"""""""""""" .. autoclass:: RepeatChannel @@ -890,6 +896,12 @@ Utility (Dict) :members: :special-members: __call__ +`AutoAdjustChanneld` +"""""""""""""""""""" +.. autoclass:: AutoAdjustChanneld + :members: + :special-members: __call__ + `RepeatChanneld` """""""""""""""" .. autoclass:: RepeatChanneld diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b63d42416a..dfbdaf5b41 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -111,7 +111,7 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): def _stack_images(image_list: List, meta_dict: Dict): if len(image_list) > 1: - if meta_dict.get("original_channel_dim", None) is not None: + if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): raise RuntimeError("can not read a list of images which already have channel dimension.") meta_dict["original_channel_dim"] = 0 img_array = np.stack(image_list, axis=0) @@ -213,7 +213,7 @@ def get_data(self, img): header["spatial_shape"] = self._get_spatial_shape(i) data = self._get_array_data(i) img_array.append(data) - header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta @@ -386,7 +386,7 @@ def get_data(self, img): header["spatial_shape"] = self._get_spatial_shape(i) data = self._get_array_data(i) img_array.append(data) - header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta @@ -599,7 +599,7 @@ def get_data(self, img): header["spatial_shape"] = self._get_spatial_shape(i) data = np.asarray(i) img_array.append(data) - header["original_channel_dim"] = None if len(data.shape) == len(header["spatial_shape"]) else -1 + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 1e7277725e..cf8e4c49d1 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -279,6 +279,9 @@ AsChannelLastd, AsChannelLastD, AsChannelLastDict, + AutoAdjustChanneld, + AutoAdjustChannelD, + AutoAdjustChannelDict, CastToTyped, CastToTypeD, CastToTypeDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d3c70f66b1..271a7e3696 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -154,6 +154,7 @@ class AutoAdjustChannel(Transform): """ Automatically adjust the channel dimension of input data. It extract the `original_channel_dim` info from provided meta_data dictionary. + Typical values of `original_channel_dim` can be: "no_channel", 0, -1. Convert the data to channel_first based on the `original_channel_dim` information. """ @@ -161,11 +162,12 @@ def __call__(self, img: NdarrayTensor, meta_dict: Dict): """ Apply the transform to `img`. """ - if "original_channel_dim" not in meta_dict: - raise ValueError("meta_dict must contain `original_channel_dim` information.") - channel_dim = meta_dict["original_channel_dim"] + + channel_dim = meta_dict.get("original_channel_dim", None) if channel_dim is None: + raise ValueError("meta_dict must contain `original_channel_dim` information.") + elif channel_dim == "no_channel": return AddChannel()(img) else: return AsChannelFirst(channel_dim=channel_dim)(img) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9d923d0fd..6e33852700 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -28,6 +28,7 @@ AddChannel, AsChannelFirst, AsChannelLast, + AutoAdjustChannel, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, @@ -60,6 +61,7 @@ "AsChannelFirstd", "AsChannelLastd", "AddChanneld", + "AutoAdjustChanneld", "RepeatChanneld", "RemoveRepeatedChanneld", "SplitChanneld", @@ -89,6 +91,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", + "AutoAdjustChannelD", + "AutoAdjustChannelDict", "RandLambdaD", "RandLambdaDict", "RepeatChannelD", @@ -217,6 +221,32 @@ def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, Nda return d +class AutoAdjustChanneld(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.AutoAdjustChannel`. + """ + + def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`. + So need the key to extract metadata for channel dim information, default is `meta_dict`. + For example, for data with key `image`, metadata by default is in `image_meta_dict`. + + """ + super().__init__(keys) + self.adjuster = AutoAdjustChannel() + self.meta_key_postfix = meta_key_postfix + + def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + d = dict(data) + for key in self.keys: + d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"]) + return d + + class RepeatChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. @@ -894,6 +924,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd AddChannelD = AddChannelDict = AddChanneld +AutoAdjustChannelD = AutoAdjustChannelDict = AutoAdjustChanneld RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld diff --git a/tests/min_tests.py b/tests/min_tests.py index 999a1aeaa0..7e615ec261 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -109,6 +109,8 @@ def run_testsuit(): "test_deepgrow_dataset", "test_save_image", "test_save_imaged", + "test_auto_adjust_channel", + "test_auto_adjust_channeld", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_auto_adjust_channeld.py b/tests/test_auto_adjust_channeld.py new file mode 100644 index 0000000000..b07845b928 --- /dev/null +++ b/tests/test_auto_adjust_channeld.py @@ -0,0 +1,62 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.transforms import AutoAdjustChanneld, LoadImaged + +TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"keys": "img"}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"keys": "img"}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + + +class TestAutoAdjustChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadImaged(**input_param)({"img": filenames}) + result = AutoAdjustChanneld(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result = LoadImaged(keys="img")({"img": filename}) + result = AutoAdjustChanneld(keys="img")(result) + self.assertEqual(result["img"].shape[0], 3) + + +if __name__ == "__main__": + unittest.main() From ffaf8cf7bfe0b7b47b8d1e81b78bdedf6df897dc Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 2 Mar 2021 10:29:52 +0000 Subject: [PATCH 3/9] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/utility/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 271a7e3696..ef8524f9fc 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -15,7 +15,7 @@ import logging import time -from typing import Dict, TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -158,6 +158,7 @@ class AutoAdjustChannel(Transform): Convert the data to channel_first based on the `original_channel_dim` information. """ + def __call__(self, img: NdarrayTensor, meta_dict: Dict): """ Apply the transform to `img`. From eefee474bd3b1ddc188958c4c06bbae358305a6a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 18:36:14 +0800 Subject: [PATCH 4/9] [DLMED] fix doc-build issue Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ef8524f9fc..4f4e298370 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -154,7 +154,7 @@ class AutoAdjustChannel(Transform): """ Automatically adjust the channel dimension of input data. It extract the `original_channel_dim` info from provided meta_data dictionary. - Typical values of `original_channel_dim` can be: "no_channel", 0, -1. + Typical values of `original_channel_dim` can be: "no_channel", 0, -1. Convert the data to channel_first based on the `original_channel_dim` information. """ From a10942577f5350abd6c1deb89667f39b4ce295ff Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 18:59:29 +0800 Subject: [PATCH 5/9] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 4 +++- monai/transforms/utility/dictionary.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 4f4e298370..dc823bb9c4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -159,10 +159,12 @@ class AutoAdjustChannel(Transform): """ - def __call__(self, img: NdarrayTensor, meta_dict: Dict): + def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None): """ Apply the transform to `img`. """ + if not isinstance(meta_dict, dict): + raise ValueError("meta_dict must be a dictionay data.") channel_dim = meta_dict.get("original_channel_dim", None) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 6e33852700..9b748a89da 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -240,7 +240,7 @@ def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> self.adjuster = AutoAdjustChannel() self.meta_key_postfix = meta_key_postfix - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"]) From 4c4ce9de1587d46bee4a00c407c0786e98882e42 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 19:19:19 +0800 Subject: [PATCH 6/9] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 9b748a89da..3ef285f9d8 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -240,7 +240,7 @@ def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> self.adjuster = AutoAdjustChannel() self.meta_key_postfix = meta_key_postfix - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"]) From 03adf465d6fcd8475a8a3f682f09750082ad7afc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 22:52:26 +0800 Subject: [PATCH 7/9] [DLMED] update according to Wenqi's comments Signed-off-by: Nic Ma --- docs/source/transforms.rst | 8 ++++---- monai/transforms/__init__.py | 8 ++++---- monai/transforms/utility/array.py | 4 ++-- monai/transforms/utility/dictionary.py | 16 ++++++++-------- tests/min_tests.py | 4 ++-- ...t_channel.py => test_ensure_channel_first.py} | 10 +++++----- ...channeld.py => test_ensure_channel_firstd.py} | 8 ++++---- 7 files changed, 29 insertions(+), 29 deletions(-) rename tests/{test_auto_adjust_channel.py => test_ensure_channel_first.py} (91%) rename tests/{test_auto_adjust_channeld.py => test_ensure_channel_firstd.py} (90%) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4dfbd44e0c..0c677f4fd0 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -437,9 +437,9 @@ Utility :members: :special-members: __call__ -`AutoAdjustChannel` +`EnsureChannelFirst` """"""""""""""""""" -.. autoclass:: AutoAdjustChannel +.. autoclass:: EnsureChannelFirst :members: :special-members: __call__ @@ -896,9 +896,9 @@ Utility (Dict) :members: :special-members: __call__ -`AutoAdjustChanneld` +`EnsureChannelFirstd` """""""""""""""""""" -.. autoclass:: AutoAdjustChanneld +.. autoclass:: EnsureChannelFirstd :members: :special-members: __call__ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cf8e4c49d1..2ca6122c58 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -247,7 +247,7 @@ AddExtremePointsChannel, AsChannelFirst, AsChannelLast, - AutoAdjustChannel, + EnsureChannelFirst, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, @@ -279,9 +279,9 @@ AsChannelLastd, AsChannelLastD, AsChannelLastDict, - AutoAdjustChanneld, - AutoAdjustChannelD, - AutoAdjustChannelDict, + EnsureChannelFirstd, + EnsureChannelFirstD, + EnsureChannelFirstDict, CastToTyped, CastToTypeD, CastToTypeDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dc823bb9c4..bdf8a27974 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -39,7 +39,7 @@ "AsChannelFirst", "AsChannelLast", "AddChannel", - "AutoAdjustChannel", + "EnsureChannelFirst", "RepeatChannel", "RemoveRepeatedChannel", "SplitChannel", @@ -150,7 +150,7 @@ def __call__(self, img: NdarrayTensor): return img[None] -class AutoAdjustChannel(Transform): +class EnsureChannelFirst(Transform): """ Automatically adjust the channel dimension of input data. It extract the `original_channel_dim` info from provided meta_data dictionary. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 3ef285f9d8..b40f27599a 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -28,7 +28,7 @@ AddChannel, AsChannelFirst, AsChannelLast, - AutoAdjustChannel, + EnsureChannelFirst, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, @@ -61,7 +61,7 @@ "AsChannelFirstd", "AsChannelLastd", "AddChanneld", - "AutoAdjustChanneld", + "EnsureChannelFirstd", "RepeatChanneld", "RemoveRepeatedChanneld", "SplitChanneld", @@ -91,8 +91,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", - "AutoAdjustChannelD", - "AutoAdjustChannelDict", + "EnsureChannelFirstD", + "EnsureChannelFirstDict", "RandLambdaD", "RandLambdaDict", "RepeatChannelD", @@ -221,9 +221,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, Nda return d -class AutoAdjustChanneld(MapTransform): +class EnsureChannelFirstd(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.AutoAdjustChannel`. + Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. """ def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> None: @@ -237,7 +237,7 @@ def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> """ super().__init__(keys) - self.adjuster = AutoAdjustChannel() + self.adjuster = EnsureChannelFirst() self.meta_key_postfix = meta_key_postfix def __call__(self, data) -> Dict[Hashable, np.ndarray]: @@ -924,7 +924,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd AddChannelD = AddChannelDict = AddChanneld -AutoAdjustChannelD = AutoAdjustChannelDict = AutoAdjustChanneld +EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld diff --git a/tests/min_tests.py b/tests/min_tests.py index 7e615ec261..83c1ceea9f 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -109,8 +109,8 @@ def run_testsuit(): "test_deepgrow_dataset", "test_save_image", "test_save_imaged", - "test_auto_adjust_channel", - "test_auto_adjust_channeld", + "test_ensure_channel_first", + "test_ensure_channel_firstd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_auto_adjust_channel.py b/tests/test_ensure_channel_first.py similarity index 91% rename from tests/test_auto_adjust_channel.py rename to tests/test_ensure_channel_first.py index fc7ab9baa6..ff656f2e24 100644 --- a/tests/test_auto_adjust_channel.py +++ b/tests/test_ensure_channel_first.py @@ -20,7 +20,7 @@ from PIL import Image from monai.data import ITKReader -from monai.transforms import AutoAdjustChannel, LoadImage +from monai.transforms import EnsureChannelFirst, LoadImage TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] @@ -49,7 +49,7 @@ ] -class TestAutoAdjustChannel(unittest.TestCase): +class TestEnsureChannelFirst(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_load_nifti(self, input_param, filenames, original_channel_dim): if original_channel_dim is None: @@ -62,13 +62,13 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result, header = LoadImage(**input_param)(filenames) - result = AutoAdjustChannel()(result, header) + result = EnsureChannelFirst()(result, header) self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): result, header = LoadImage(**input_param)(filenames) - result = AutoAdjustChannel()(result, header) + result = EnsureChannelFirst()(result, header) self.assertEqual(result.shape[0], 1) def test_load_png(self): @@ -78,7 +78,7 @@ def test_load_png(self): filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) result, header = LoadImage(image_only=False)(filename) - result = AutoAdjustChannel()(result, header) + result = EnsureChannelFirst()(result, header) self.assertEqual(result.shape[0], 3) diff --git a/tests/test_auto_adjust_channeld.py b/tests/test_ensure_channel_firstd.py similarity index 90% rename from tests/test_auto_adjust_channeld.py rename to tests/test_ensure_channel_firstd.py index b07845b928..a5298f4453 100644 --- a/tests/test_auto_adjust_channeld.py +++ b/tests/test_ensure_channel_firstd.py @@ -18,7 +18,7 @@ from parameterized import parameterized from PIL import Image -from monai.transforms import AutoAdjustChanneld, LoadImaged +from monai.transforms import EnsureChannelFirstd, LoadImaged TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] @@ -31,7 +31,7 @@ ] -class TestAutoAdjustChanneld(unittest.TestCase): +class TestEnsureChannelFirstd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_load_nifti(self, input_param, filenames, original_channel_dim): if original_channel_dim is None: @@ -44,7 +44,7 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result = LoadImaged(**input_param)({"img": filenames}) - result = AutoAdjustChanneld(**input_param)(result) + result = EnsureChannelFirstd(**input_param)(result) self.assertEqual(result["img"].shape[0], len(filenames)) def test_load_png(self): @@ -54,7 +54,7 @@ def test_load_png(self): filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) result = LoadImaged(keys="img")({"img": filename}) - result = AutoAdjustChanneld(keys="img")(result) + result = EnsureChannelFirstd(keys="img")(result) self.assertEqual(result["img"].shape[0], 3) From 7dd4b71b91673469cc9d635c7820002a747b9484 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Mar 2021 22:57:41 +0800 Subject: [PATCH 8/9] [DLMED] update doc-strings Signed-off-by: Nic Ma --- docs/source/transforms.rst | 4 ++-- monai/transforms/utility/array.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 0c677f4fd0..dd10176de9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -438,7 +438,7 @@ Utility :special-members: __call__ `EnsureChannelFirst` -""""""""""""""""""" +"""""""""""""""""""" .. autoclass:: EnsureChannelFirst :members: :special-members: __call__ @@ -897,7 +897,7 @@ Utility (Dict) :special-members: __call__ `EnsureChannelFirstd` -"""""""""""""""""""" +""""""""""""""""""""" .. autoclass:: EnsureChannelFirstd :members: :special-members: __call__ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bdf8a27974..62daf9309c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -152,10 +152,10 @@ def __call__(self, img: NdarrayTensor): class EnsureChannelFirst(Transform): """ - Automatically adjust the channel dimension of input data. - It extract the `original_channel_dim` info from provided meta_data dictionary. + Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. + It extracts the `original_channel_dim` info from provided meta_data dictionary. Typical values of `original_channel_dim` can be: "no_channel", 0, -1. - Convert the data to channel_first based on the `original_channel_dim` information. + Convert the data to `channel_first` based on the `original_channel_dim` information. """ From 9dd29797b928acabd1093b28647b103cb9c2a60d Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 2 Mar 2021 15:01:49 +0000 Subject: [PATCH 9/9] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/__init__.py | 8 ++++---- monai/transforms/utility/dictionary.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2ca6122c58..a8d647b657 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -247,10 +247,10 @@ AddExtremePointsChannel, AsChannelFirst, AsChannelLast, - EnsureChannelFirst, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask, @@ -279,9 +279,6 @@ AsChannelLastd, AsChannelLastD, AsChannelLastDict, - EnsureChannelFirstd, - EnsureChannelFirstD, - EnsureChannelFirstDict, CastToTyped, CastToTypeD, CastToTypeDict, @@ -300,6 +297,9 @@ DeleteItemsd, DeleteItemsD, DeleteItemsDict, + EnsureChannelFirstd, + EnsureChannelFirstD, + EnsureChannelFirstDict, FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b40f27599a..4a0808fdbb 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -28,10 +28,10 @@ AddChannel, AsChannelFirst, AsChannelLast, - EnsureChannelFirst, CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask,