Skip to content
Merged
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ Utility
:members:
:special-members: __call__

`EnsureChannelFirst`
""""""""""""""""""""
.. autoclass:: EnsureChannelFirst
:members:
:special-members: __call__

`RepeatChannel`
"""""""""""""""
.. autoclass:: RepeatChannel
Expand Down Expand Up @@ -890,6 +896,12 @@ Utility (Dict)
:members:
:special-members: __call__

`EnsureChannelFirstd`
"""""""""""""""""""""
.. autoclass:: EnsureChannelFirstd
:members:
:special-members: __call__

`RepeatChanneld`
""""""""""""""""
.. autoclass:: RepeatChanneld
Expand Down
39 changes: 28 additions & 11 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
)


def _stack_images(image_list: List, meta_dict: Dict):
if len(image_list) > 1:
if meta_dict.get("original_channel_dim", None) not in ("no_channel", None):
raise RuntimeError("can not read a list of images which already have channel dimension.")
meta_dict["original_channel_dim"] = 0
img_array = np.stack(image_list, axis=0)
else:
img_array = image_list[0]
return img_array


class ITKReader(ImageReader):
"""
Load medical images based on ITK library.
Expand Down Expand Up @@ -200,11 +211,12 @@ def get_data(self, img):
header["original_affine"] = self._get_affine(i)
header["affine"] = header["original_affine"].copy()
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))
data = self._get_array_data(i)
img_array.append(data)
header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta

def _get_meta_dict(self, img) -> Dict:
"""
Expand Down Expand Up @@ -265,6 +277,7 @@ def _get_spatial_shape(self, img) -> np.ndarray:
img: a ITK image object loaded from a image file.

"""
# the img data should have no channel dim or the last dim is channel
shape = list(itk.size(img))
shape.reverse()
return np.asarray(shape)
Expand Down Expand Up @@ -371,11 +384,12 @@ def get_data(self, img):
i = nib.as_closest_canonical(i)
header["affine"] = self._get_affine(i)
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))
data = self._get_array_data(i)
img_array.append(data)
header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta

def _get_meta_dict(self, img) -> Dict:
"""
Expand Down Expand Up @@ -408,6 +422,7 @@ def _get_spatial_shape(self, img) -> np.ndarray:
"""
ndim = img.header["dim"][0]
spatial_rank = min(ndim, 3)
# the img data should have no channel dim or the last dim is channel
return np.asarray(img.header["dim"][1 : spatial_rank + 1])

def _get_array_data(self, img) -> np.ndarray:
Expand Down Expand Up @@ -504,12 +519,12 @@ def get_data(self, img):
for i in ensure_tuple(img):
header = {}
if isinstance(i, np.ndarray):
# can not detect the channel dim of numpy array, use all the dims as spatial_shape
header["spatial_shape"] = i.shape
img_array.append(i)
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta


class PILReader(ImageReader):
Expand Down Expand Up @@ -582,11 +597,12 @@ def get_data(self, img):
for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(np.asarray(i))
data = np.asarray(i)
img_array.append(data)
header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta

def _get_meta_dict(self, img) -> Dict:
"""
Expand All @@ -608,4 +624,5 @@ def _get_spatial_shape(self, img) -> np.ndarray:
Args:
img: a PIL Image object loaded from a image file.
"""
# the img data should have no channel dim or the last dim is channel
return np.asarray((img.width, img.height))
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@
CastToType,
ConvertToMultiChannelBasedOnBratsClasses,
DataStats,
EnsureChannelFirst,
FgBgToIndices,
Identity,
LabelToMask,
Expand Down Expand Up @@ -296,6 +297,9 @@
DeleteItemsd,
DeleteItemsD,
DeleteItemsDict,
EnsureChannelFirstd,
EnsureChannelFirstD,
EnsureChannelFirstDict,
FgBgToIndicesd,
FgBgToIndicesD,
FgBgToIndicesDict,
Expand Down
29 changes: 28 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
import time
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -39,6 +39,7 @@
"AsChannelFirst",
"AsChannelLast",
"AddChannel",
"EnsureChannelFirst",
"RepeatChannel",
"RemoveRepeatedChannel",
"SplitChannel",
Expand Down Expand Up @@ -149,6 +150,32 @@ def __call__(self, img: NdarrayTensor):
return img[None]


class EnsureChannelFirst(Transform):
"""
Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape.
It extracts the `original_channel_dim` info from provided meta_data dictionary.
Typical values of `original_channel_dim` can be: "no_channel", 0, -1.
Convert the data to `channel_first` based on the `original_channel_dim` information.

"""

def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None):
"""
Apply the transform to `img`.
"""
if not isinstance(meta_dict, dict):
raise ValueError("meta_dict must be a dictionay data.")

channel_dim = meta_dict.get("original_channel_dim", None)

if channel_dim is None:
raise ValueError("meta_dict must contain `original_channel_dim` information.")
elif channel_dim == "no_channel":
return AddChannel()(img)
else:
return AsChannelFirst(channel_dim=channel_dim)(img)


class RepeatChannel(Transform):
"""
Repeat channel data to construct expected input shape for models.
Expand Down
31 changes: 31 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
CastToType,
ConvertToMultiChannelBasedOnBratsClasses,
DataStats,
EnsureChannelFirst,
FgBgToIndices,
Identity,
LabelToMask,
Expand Down Expand Up @@ -60,6 +61,7 @@
"AsChannelFirstd",
"AsChannelLastd",
"AddChanneld",
"EnsureChannelFirstd",
"RepeatChanneld",
"RemoveRepeatedChanneld",
"SplitChanneld",
Expand Down Expand Up @@ -89,6 +91,8 @@
"AsChannelLastDict",
"AddChannelD",
"AddChannelDict",
"EnsureChannelFirstD",
"EnsureChannelFirstDict",
"RandLambdaD",
"RandLambdaDict",
"RepeatChannelD",
Expand Down Expand Up @@ -217,6 +221,32 @@ def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, Nda
return d


class EnsureChannelFirstd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`.
"""

def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.
So need the key to extract metadata for channel dim information, default is `meta_dict`.
For example, for data with key `image`, metadata by default is in `image_meta_dict`.

"""
super().__init__(keys)
self.adjuster = EnsureChannelFirst()
self.meta_key_postfix = meta_key_postfix

def __call__(self, data) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.keys:
d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"])
return d


class RepeatChanneld(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`.
Expand Down Expand Up @@ -894,6 +924,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
AddChannelD = AddChannelDict = AddChanneld
EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd
RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld
RepeatChannelD = RepeatChannelDict = RepeatChanneld
SplitChannelD = SplitChannelDict = SplitChanneld
Expand Down
2 changes: 2 additions & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def run_testsuit():
"test_deepgrow_dataset",
"test_save_image",
"test_save_imaged",
"test_ensure_channel_first",
"test_ensure_channel_firstd",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
86 changes: 86 additions & 0 deletions tests/test_ensure_channel_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import itk
import nibabel as nib
import numpy as np
from parameterized import parameterized
from PIL import Image

from monai.data import ITKReader
from monai.transforms import EnsureChannelFirst, LoadImage

TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None]

TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1]

TEST_CASE_3 = [
{"image_only": False},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
None,
]

TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None]

TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1]

TEST_CASE_6 = [
{"reader": ITKReader(), "image_only": False},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
None,
]

TEST_CASE_7 = [
{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)},
"tests/testing_data/CT_DICOM",
None,
]


class TestEnsureChannelFirst(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_load_nifti(self, input_param, filenames, original_channel_dim):
if original_channel_dim is None:
test_image = np.random.rand(128, 128, 128)
elif original_channel_dim == -1:
test_image = np.random.rand(128, 128, 128, 1)

with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
result, header = LoadImage(**input_param)(filenames)
result = EnsureChannelFirst()(result, header)
self.assertEqual(result.shape[0], len(filenames))

@parameterized.expand([TEST_CASE_7])
def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim):
result, header = LoadImage(**input_param)(filenames)
result = EnsureChannelFirst()(result, header)
self.assertEqual(result.shape[0], 1)

def test_load_png(self):
spatial_size = (256, 256, 3)
test_image = np.random.randint(0, 256, size=spatial_size)
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_image.png")
Image.fromarray(test_image.astype("uint8")).save(filename)
result, header = LoadImage(image_only=False)(filename)
result = EnsureChannelFirst()(result, header)
self.assertEqual(result.shape[0], 3)


if __name__ == "__main__":
unittest.main()
Loading