diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 41b0872698..b0ba1e39d9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,3 +518,4 @@ weighted_patch_samples, zero_margins, ) +from .utils_pytorch_numpy_unification import moveaxis diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d56bca0d8d..580c6c8b3c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,6 +31,7 @@ map_binary_to_indices, map_classes_to_indices, ) +from monai.transforms.utils_pytorch_numpy_unification import moveaxis from monai.utils import ( convert_to_numpy, convert_to_tensor, @@ -82,17 +83,18 @@ class Identity(Transform): """ - Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data. + Do nothing to the data. As the output value is same as input, it can be used as a testing tool to verify the transform chain, Compose or transform adaptor, etc. - """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asanyarray(img) + return img class AsChannelFirst(Transform): @@ -111,16 +113,18 @@ class AsChannelFirst(Transform): channel_dim: which dimension of input image is the channel, default is the last dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = -1) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, 0) + return moveaxis(img, self.channel_dim, 0) class AsChannelLast(Transform): @@ -138,16 +142,18 @@ class AsChannelLast(Transform): channel_dim: which dimension of input image is the channel, default is the first dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = 0) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, -1) + return moveaxis(img, self.channel_dim, -1) class AddChannel(Transform): @@ -164,7 +170,9 @@ class AddChannel(Transform): transforms. """ - def __call__(self, img: NdarrayTensor): + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -179,6 +187,8 @@ class EnsureChannelFirst(Transform): Convert the data to `channel_first` based on the `original_channel_dim` information. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, strict_check: bool = True): """ Args: @@ -186,7 +196,7 @@ def __init__(self, strict_check: bool = True): """ self.strict_check = strict_check - def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None): + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -220,16 +230,19 @@ class RepeatChannel(Transform): repeats: the number of repetitions for each element. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, repeats: int) -> None: if repeats <= 0: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - return np.repeat(img, self.repeats, 0) + repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeeat_fn(img, self.repeats, 0) # type: ignore class RemoveRepeatedChannel(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a53e4f3235..41c2a1b9b9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -169,6 +169,8 @@ class Identityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Identity`. """ + backend = Identity.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -180,9 +182,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -194,6 +194,8 @@ class AsChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`. """ + backend = AsChannelFirst.backend + def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: """ Args: @@ -205,7 +207,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -217,6 +219,8 @@ class AsChannelLastd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`. """ + backend = AsChannelLast.backend + def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -228,7 +232,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -240,6 +244,8 @@ class AddChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`. """ + backend = AddChannel.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -250,7 +256,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -262,6 +268,8 @@ class EnsureChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. """ + backend = EnsureChannelFirst.backend + def __init__( self, keys: KeysCollection, @@ -289,7 +297,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"]) @@ -301,6 +309,8 @@ class RepeatChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. """ + backend = RepeatChannel.backend + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: @@ -312,7 +322,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..e6dc151596 --- /dev/null +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -0,0 +1,41 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from monai.config.type_definitions import NdarrayOrTensor + +__all__ = [ + "moveaxis", +] + + +def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: + if isinstance(x, torch.Tensor): + if hasattr(torch, "moveaxis"): + return torch.moveaxis(x, src, dst) + # moveaxis only available in pytorch as of 1.8.0 + else: + # get original indices + indices = list(range(x.ndim)) + # make src and dst positive + if src < 0: + src = len(indices) + src + if dst < 0: + dst = len(indices) + dst + # remove desired index and insert it in new position + indices.pop(src) + indices.insert(dst, src) + return x.permute(indices) + elif isinstance(x, np.ndarray): + return np.moveaxis(x, src, dst) + raise RuntimeError() diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index ca4af37271..8bdd89a4ae 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import AddChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"]}, - {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[0, 1], [1, 2]])}, - (1, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"]}, + {"img": p(np.array([[0, 1], [1, 2]])), "seg": p(np.array([[0, 1], [1, 2]]))}, + (1, 2, 2), + ] + ) class TestAddChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = AddChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index e7d9866ae1..0d1b1c7d3a 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -12,23 +12,29 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AsChannelFirst +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirst(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) + if isinstance(test_data, torch.Tensor): + test_data = test_data.cpu().numpy() + expected = np.moveaxis(test_data, input_param["channel_dim"], 0) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index e70c2e1b47..68d33434c1 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelFirstd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 6ec6c8d6e6..55a7a08676 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelLast +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLast(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelLast(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 2ef4dd4da1..350f639f3f 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelLastd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLastd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelLastd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 6b9def1cea..23126d326f 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -21,6 +21,7 @@ from monai.data import ITKReader from monai.transforms import EnsureChannelFirst, LoadImage +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] @@ -61,9 +62,10 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result, header = LoadImage(**input_param)(filenames) - result = EnsureChannelFirst()(result, header) - self.assertEqual(result.shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(p(result), header) + self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 59eb32c576..b4cde02a8f 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -19,6 +19,7 @@ from PIL import Image from monai.transforms import EnsureChannelFirstd, LoadImaged +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] @@ -43,9 +44,11 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadImaged(**input_param)({"img": filenames}) - result = EnsureChannelFirstd(**input_param)(result) - self.assertEqual(result["img"].shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result = LoadImaged(**input_param)({"img": filenames}) + result["img"] = p(result["img"]) + result = EnsureChannelFirstd(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) def test_load_png(self): spatial_size = (256, 256, 3) diff --git a/tests/test_identity.py b/tests/test_identity.py index 2dff2bb13d..172860668c 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -11,17 +11,16 @@ import unittest -import numpy as np - from monai.transforms.utility.array import Identity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentity(NumpyImageTestCase2D): def test_identity(self): - img = self.imt - identity = Identity() - self.assertTrue(np.allclose(img, identity(img))) + for p in TEST_NDARRAYS: + img = p(self.imt) + identity = Identity() + assert_allclose(img, identity(img)) if __name__ == "__main__": diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 8796f28da8..665b7d5d1c 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -12,16 +12,17 @@ import unittest from monai.transforms.utility.dictionary import Identityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentityd(NumpyImageTestCase2D): def test_identityd(self): - img = self.imt - data = {} - data["img"] = img - identity = Identityd(keys=data.keys()) - self.assertEqual(data, identity(data)) + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {} + data["img"] = img + identity = Identityd(keys=data.keys()) + assert_allclose(img, identity(data)["img"]) if __name__ == "__main__": diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 643ebc64de..e246dd1212 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -11,16 +11,18 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import RepeatChannel +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"repeats": 3}, np.array([[[0, 1], [1, 2]]]), (3, 2, 2)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"repeats": 3}, p([[[0, 1], [1, 2]]]), (3, 2, 2)]) class TestRepeatChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 7bd58bd1fe..3b73962bb9 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import RepeatChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img"], "repeats": 3}, - {"img": np.array([[[0, 1], [1, 2]]]), "seg": np.array([[[0, 1], [1, 2]]])}, - (3, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "repeats": 3}, + {"img": p(np.array([[[0, 1], [1, 2]]])), "seg": p(np.array([[[0, 1], [1, 2]]]))}, + (3, 2, 2), + ] + ) class TestRepeatChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape)