diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 580c6c8b3c..f38a94302e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,15 +32,7 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import moveaxis -from monai.utils import ( - convert_to_numpy, - convert_to_tensor, - ensure_tuple, - issequenceiterable, - look_up_option, - min_version, - optional_import, -) +from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type @@ -255,20 +247,22 @@ class RemoveRepeatedChannel(Transform): repeats: the number of repetitions to be deleted for each element. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, repeats: int) -> None: if repeats <= 0: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - if np.shape(img)[0] < 2: + if img.shape[0] < 2: raise AssertionError("Image must have more than one channel") - return np.array(img[:: self.repeats, :]) + return img[:: self.repeats, :] class SplitChannel(Transform): @@ -281,10 +275,12 @@ class SplitChannel(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = 0) -> None: self.channel_dim = channel_dim - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]: + def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: n_classes = img.shape[self.channel_dim] if n_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") @@ -335,18 +331,13 @@ class ToTensor(Transform): Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img) -> torch.Tensor: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - return img.contiguous() - if issequenceiterable(img): - # numpy array with 0 dims is also sequence iterable - if not (isinstance(img, np.ndarray) and img.ndim == 0): - # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - img = np.ascontiguousarray(img) - return torch.as_tensor(img) + return convert_to_tensor(img, wrap_sequence=True) # type: ignore class EnsureType(Transform): @@ -361,6 +352,8 @@ class EnsureType(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, data_type: str = "tensor") -> None: data_type = data_type.lower() if data_type not in ("tensor", "numpy"): @@ -368,7 +361,7 @@ def __init__(self, data_type: str = "tensor") -> None: self.data_type = data_type - def __call__(self, data): + def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -377,7 +370,7 @@ def __call__(self, data): if applicable. 
""" - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) + return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore class ToNumpy(Transform): @@ -385,17 +378,13 @@ class ToNumpy(Transform): Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. """ - def __call__(self, img) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - elif has_cp and isinstance(img, cp_ndarray): - img = cp.asnumpy(img) - - array: np.ndarray = np.asarray(img) - return np.ascontiguousarray(array) if array.ndim > 0 else array + return convert_to_numpy(img) # type: ignore class ToCupy(Transform): @@ -403,13 +392,15 @@ class ToCupy(Transform): Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. """ - def __call__(self, img): + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img` and make it contiguous. """ if isinstance(img, torch.Tensor): img = img.detach().cpu().numpy() - return cp.ascontiguousarray(cp.asarray(img)) + return cp.ascontiguousarray(cp.asarray(img)) # type: ignore class ToPIL(Transform): @@ -417,6 +408,8 @@ class ToPIL(Transform): Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __call__(self, img): """ Apply the transform to `img`. @@ -433,13 +426,17 @@ class Transpose(Transform): Transposes the input image based on the given `indices` dimension ordering. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, indices: Optional[Sequence[int]]) -> None: self.indices = None if indices is None else tuple(indices) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ + if isinstance(img, torch.Tensor): + return img.permute(self.indices or tuple(range(img.ndim)[::-1])) return img.transpose(self.indices) # type: ignore diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 41c2a1b9b9..1b63b308d9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -334,6 +334,8 @@ class RemoveRepeatedChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`. """ + backend = RemoveRepeatedChannel.backend + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: @@ -345,7 +347,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -356,9 +358,10 @@ class SplitChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. All the input specified by `keys` should be split into same count of data. 
- """ + backend = SplitChannel.backend + def __init__( self, keys: KeysCollection, @@ -382,9 +385,7 @@ def __init__( self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): rets = self.splitter(d[key]) @@ -439,6 +440,8 @@ class ToTensord(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`. """ + backend = ToTensor.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -449,14 +452,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.converter(d[key]) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): # Create inverse transform @@ -481,6 +484,8 @@ class EnsureTyped(MapTransform, InvertibleTransform): """ + backend = EnsureType.backend + def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: """ Args: @@ -492,7 +497,7 @@ def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missin super().__init__(keys, allow_missing_keys) self.converter = EnsureType(data_type=data_type) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) @@ -515,6 +520,8 @@ class ToNumpyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ + backend = ToNumpy.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -537,6 +544,8 @@ class ToCupyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`. """ + backend = ToCupy.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -547,7 +556,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToCupy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -559,6 +568,8 @@ class ToPILd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ + backend = ToPIL.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -581,13 +592,15 @@ class Transposed(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Transpose`. 
""" + backend = Transpose.backend + def __init__( self, keys: KeysCollection, indices: Optional[Sequence[int]], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.transform = Transpose(indices) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e6df607764..14300eeca0 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -83,7 +83,7 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data): +def convert_to_tensor(data, wrap_sequence: bool = False): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -92,6 +92,8 @@ def convert_to_tensor(data): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. + If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): @@ -105,17 +107,19 @@ def convert_to_tensor(data): return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) elif isinstance(data, (float, int, bool)): return torch.as_tensor(data) - elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + elif isinstance(data, Sequence) and wrap_sequence: + return torch.as_tensor(data) elif isinstance(data, list): return [convert_to_tensor(i) for i in data] elif isinstance(data, tuple): return tuple(convert_to_tensor(i) for i in data) + elif isinstance(data, dict): + return {k: convert_to_tensor(v) for k, v in data.items()} return data -def convert_to_numpy(data): +def convert_to_numpy(data, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -124,7 +128,8 @@ def convert_to_numpy(data): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. - + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. 
""" if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() @@ -132,12 +137,14 @@ def convert_to_numpy(data): data = cp.asnumpy(data) elif isinstance(data, (float, int, bool)): data = np.asarray(data) - elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + elif isinstance(data, Sequence) and wrap_sequence: + return np.asarray(data) elif isinstance(data, list): return [convert_to_numpy(i) for i in data] elif isinstance(data, tuple): return tuple(convert_to_numpy(i) for i in data) + elif isinstance(data, dict): + return {k: convert_to_numpy(v) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 11cf6760fb..8feb96ed37 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -15,26 +15,33 @@ import torch from monai.transforms import EnsureType +from tests.utils import assert_allclose class TestEnsureType(unittest.TestCase): def test_array_input(self): - for test_data in (np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])): + test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "NUMPY"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "numpy"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) def test_string(self): diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index c5f588d423..96f482afc2 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -15,26 +15,33 @@ import torch from monai.transforms import EnsureTyped +from tests.utils import assert_allclose class TestEnsureTyped(unittest.TestCase): def test_array_input(self): - for test_data in (np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])): + test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "NUMPY"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "numpy"): result = 
EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) def test_string(self): diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index 070e0e2b8d..ebbe6c730c 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -12,15 +12,18 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RemoveRepeatedChannel -TEST_CASE_1 = [{"repeats": 2}, np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] +TEST_CASES = [] +for q in (torch.Tensor, np.array): + TEST_CASES.append([{"repeats": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)]) # type: ignore class TestRemoveRepeatedChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 46c68bbdc2..9d4812791e 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -15,16 +15,24 @@ from parameterized import parameterized from monai.transforms import RemoveRepeatedChanneld - -TEST_CASE_1 = [ - {"keys": ["img"], "repeats": 2}, - {"img": np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), "seg": np.array([[1, 2], [1, 2], [3, 4], [3, 4]])}, - (2, 2), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "repeats": 2}, + { + "img": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])), + "seg": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])), + }, + (2, 2), + ] + ) class TestRemoveRepeatedChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 91e93aedcc..38315a102c 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -12,22 +12,21 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SplitChannel +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 1}, torch.randint(0, 2, size=(4, 3, 3, 4)), (4, 1, 3, 4)] - -TEST_CASE_2 = [{"channel_dim": 0}, np.random.randint(2, size=(3, 3, 4)), (1, 3, 4)] - -TEST_CASE_3 = [{"channel_dim": 2}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] - -TEST_CASE_4 = [{"channel_dim": -1}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"channel_dim": 1}, p(np.random.randint(2, size=(4, 3, 3, 4))), (4, 1, 3, 4)]) + TESTS.append([{"channel_dim": 0}, p(np.random.randint(2, size=(3, 3, 4))), (1, 3, 4)]) + TESTS.append([{"channel_dim": 2}, p(np.random.randint(2, size=(3, 2, 4))), (3, 2, 1)]) + TESTS.append([{"channel_dim": -1}, p(np.random.randint(2, size=(3, 2, 4))), (3, 2, 1)]) class TestSplitChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, 
TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SplitChannel(**input_param)(test_data) for data in result: diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index 57c7099b9f..f1df24364d 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -12,44 +12,56 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SplitChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 1}, - {"pred": torch.randint(0, 2, size=(4, 3, 3, 4))}, - (4, 1, 3, 4), -] - -TEST_CASE_2 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 0}, - {"pred": np.random.randint(2, size=(3, 3, 4))}, - (1, 3, 4), -] - -TEST_CASE_3 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": 2}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 2, 1), -] - -TEST_CASE_4 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": -1}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 2, 1), -] - -TEST_CASE_5 = [ - {"keys": "pred", "channel_dim": 1}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 1, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 1}, + {"pred": p(np.random.randint(2, size=(4, 3, 3, 4)))}, + (4, 1, 3, 4), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 0}, + {"pred": p(np.random.randint(2, size=(3, 3, 4)))}, + (1, 3, 4), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": 2}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 2, 1), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": -1}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 2, 1), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "channel_dim": 1}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 1, 4), + ] + ) class TestSplitChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SplitChanneld(**input_param)(test_data) for k, v in result.items(): diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 76c9464b20..a9460bc825 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -17,6 +17,7 @@ from monai.transforms import ToCupy from monai.utils import optional_import +from tests.utils import skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -52,6 +53,17 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) + @skipUnless(has_cp, "CuPy is required.") + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToCupy()(test_data) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data.cpu().numpy()) + @skipUnless(has_cp, "CuPy is required.") def 
test_list_tuple(self): test_data = [[1, 2], [3, 4]] diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index b869bedc96..2f3c42dd1f 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -17,6 +17,7 @@ from monai.transforms import ToCupyd from monai.utils import optional_import +from tests.utils import skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -52,6 +53,17 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) + @skipUnless(has_cp, "CuPy is required.") + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToCupyd(keys="img")({"img": test_data})["img"] + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data.cpu().numpy()) + @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index 291601ffeb..fd49a3d473 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -17,6 +17,7 @@ from monai.transforms import ToNumpy from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -30,7 +31,7 @@ def test_cumpy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get()) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -39,7 +40,7 @@ def test_numpy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -48,21 +49,31 @@ def test_tensor_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.numpy()) + assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToNumpy()(test_data) + self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + assert_allclose(result, test_data) def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToNumpy()(test_data) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) test_data = ((1, 2), (3, 4)) result = ToNumpy()(test_data) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) self.assertEqual(result.ndim, 0) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 1fb43ea2ac..adfab65904 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -17,6 +17,7 @@ from 
monai.transforms import ToNumpyd from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -30,7 +31,7 @@ def test_cumpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get()) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -39,7 +40,7 @@ def test_numpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -48,7 +49,17 @@ def test_tensor_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.numpy()) + assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToNumpyd(keys="img")({"img": test_data})["img"] + self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + assert_allclose(result, test_data) if __name__ == "__main__": diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index ec63750ce4..5690645dd8 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -14,11 +14,11 @@ from unittest import skipUnless import numpy as np -import torch from parameterized import parameterized from monai.transforms import ToPIL from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose if TYPE_CHECKING: from PIL.Image import Image as PILImageImage @@ -29,35 +29,21 @@ pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") PILImageImage, _ = optional_import("PIL.Image", name="Image") -TEST_CASE_ARRAY_1 = [np.array([[1.0, 2.0], [3.0, 4.0]])] -TEST_CASE_TENSOR_1 = [torch.tensor([[1.0, 2.0], [3.0, 4.0]])] +im = [[1.0, 2.0], [3.0, 4.0]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p(im)]) +if has_pil: + TESTS.append([pil_image_fromarray(np.array(im))]) class TestToPIL(unittest.TestCase): - @parameterized.expand([TEST_CASE_ARRAY_1]) + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") - def test_numpy_input(self, test_data): - self.assertTrue(isinstance(test_data, np.ndarray)) + def test_value(self, test_data): result = ToPIL()(test_data) self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data) - - @parameterized.expand([TEST_CASE_TENSOR_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_tensor_input(self, test_data): - self.assertTrue(isinstance(test_data, torch.Tensor)) - result = ToPIL()(test_data) - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data.numpy()) - - @parameterized.expand([TEST_CASE_ARRAY_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_pil_input(self, test_data): - test_data_pil = pil_image_fromarray(test_data) - self.assertTrue(isinstance(test_data_pil, PILImageImage)) - result = ToPIL()(test_data_pil) - 
self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data) + assert_allclose(np.array(result), test_data) if __name__ == "__main__": diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 43778022ee..3a15b1e507 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -14,11 +14,11 @@ from unittest import skipUnless import numpy as np -import torch from parameterized import parameterized from monai.transforms import ToPILd from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose if TYPE_CHECKING: from PIL.Image import Image as PILImageImage @@ -29,36 +29,21 @@ pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") PILImageImage, _ = optional_import("PIL.Image", name="Image") -TEST_CASE_ARRAY_1 = [{"keys": "image"}, {"image": np.array([[1.0, 2.0], [3.0, 4.0]])}] -TEST_CASE__TENSOR_1 = [{"keys": "image"}, {"image": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}] +im = [[1.0, 2.0], [3.0, 4.0]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"keys": "image"}, {"image": p(im)}]) +if has_pil: + TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) class TestToPIL(unittest.TestCase): - @parameterized.expand([TEST_CASE_ARRAY_1]) + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") - def test_numpy_input(self, input_param, test_data): - self.assertTrue(isinstance(test_data[input_param["keys"]], np.ndarray)) + def test_values(self, input_param, test_data): result = ToPILd(**input_param)(test_data)[input_param["keys"]] self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) - - @parameterized.expand([TEST_CASE__TENSOR_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_tensor_input(self, input_param, test_data): - self.assertTrue(isinstance(test_data[input_param["keys"]], torch.Tensor)) - result = ToPILd(**input_param)(test_data)[input_param["keys"]] - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]].numpy()) - - @parameterized.expand([TEST_CASE_ARRAY_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_pil_input(self, input_param, test_data): - input_array = test_data[input_param["keys"]] - test_data[input_param["keys"]] = pil_image_fromarray(input_array) - self.assertTrue(isinstance(test_data[input_param["keys"]], PILImageImage)) - result = ToPILd(**input_param)(test_data)[input_param["keys"]] - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) + assert_allclose(np.array(result), test_data[input_param["keys"]]) if __name__ == "__main__": diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 4a36254743..6ac06983f6 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -11,24 +11,36 @@ import unittest -import numpy as np -import torch +from parameterized import parameterized from monai.transforms import ToTensor +from tests.utils import TEST_NDARRAYS, assert_allclose + +im = [[1, 2], [3, 4]] + +TESTS = [] +TESTS.append((im, (2, 2))) +for p in TEST_NDARRAYS: + TESTS.append((p(im), (2, 2))) + +TESTS_SINGLE = [] +TESTS_SINGLE.append([5]) +for p in TEST_NDARRAYS: + TESTS_SINGLE.append([p(5)]) class TestToTensor(unittest.TestCase): - def test_array_input(self): - for test_data in ([[1, 2], [3, 4]], np.array([[1, 2], [3, 4]]), 
torch.as_tensor([[1, 2], [3, 4]])): - result = ToTensor()(test_data) - torch.testing.assert_allclose(result, test_data) - self.assertTupleEqual(result.shape, (2, 2)) - - def test_single_input(self): - for test_data in (5, np.asarray(5), torch.tensor(5)): - result = ToTensor()(test_data) - torch.testing.assert_allclose(result, test_data) - self.assertEqual(result.ndim, 0) + @parameterized.expand(TESTS) + def test_array_input(self, test_data, expected_shape): + result = ToTensor()(test_data) + assert_allclose(result, test_data) + self.assertTupleEqual(result.shape, expected_shape) + + @parameterized.expand(TESTS_SINGLE) + def test_single_input(self, test_data): + result = ToTensor()(test_data) + assert_allclose(result, test_data) + self.assertEqual(result.ndim, 0) if __name__ == "__main__": diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 3b758b5aa2..10882c9dd8 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -12,28 +12,37 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transpose - -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_1 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): tr = Transpose(indices) out1 = tr(im) + if isinstance(im, torch.Tensor): + im = im.cpu().numpy() out2 = np.transpose(im, indices) - np.testing.assert_array_equal(out1, out2) + assert_allclose(out1, out2) if __name__ == "__main__": diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 56375f3981..88ecd0c872 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -13,44 +13,57 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transposed +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - [1, 0], -] -TEST_CASE_1 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_2 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASE_3 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - None, -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + [1, 0], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + None, + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): data = {"i": deepcopy(im), "j": deepcopy(im)} tr = Transposed(["i", "j"], indices) out_data = tr(data) out_im1, out_im2 = out_data["i"], out_data["j"] + if isinstance(im, torch.Tensor): + im = im.cpu().numpy() out_gt = np.transpose(im, indices) - np.testing.assert_array_equal(out_im1, out_gt) - np.testing.assert_array_equal(out_im2, out_gt) + assert_allclose(out_im1, out_gt) + assert_allclose(out_im2, out_gt) # test inverse 
fwd_inv_data = tr.inverse(out_data) for i, j in zip(data.values(), fwd_inv_data.values()): - np.testing.assert_array_equal(i, j) + assert_allclose(i, j) if __name__ == "__main__":
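
Usage note (not part of the patch): the snippet below is a minimal sketch of the behaviour introduced above, assuming this patched MONAI branch is installed. Every name it uses — convert_to_tensor/convert_to_numpy with the new wrap_sequence flag, the tensor-aware utility transforms, and their backend class attribute — is taken from the hunks in this diff; nothing beyond that is assumed.

import numpy as np
import torch

from monai.transforms import EnsureType, RemoveRepeatedChannel, ToTensor, Transpose
from monai.utils import convert_to_numpy, convert_to_tensor
from monai.utils.enums import TransformBackends

# wrap_sequence controls how plain Python sequences are converted:
# element by element (default) or wrapped into a single tensor/array.
print(convert_to_tensor([1, 2]))                      # [tensor(1), tensor(2)]
print(convert_to_tensor([1, 2], wrap_sequence=True))  # tensor([1, 2])
print(convert_to_numpy([1, 2], wrap_sequence=True))   # array([1, 2])

# ToTensor now delegates to convert_to_tensor(..., wrap_sequence=True), so
# lists, numpy arrays and tensors all come back as contiguous tensors.
t = ToTensor()([[1, 2], [3, 4]])
assert isinstance(t, torch.Tensor) and t.shape == (2, 2)

# The utility transforms accept torch.Tensor as well as np.ndarray and
# advertise that through the new `backend` class attribute.
assert TransformBackends.TORCH in ToTensor.backend

x = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
assert RemoveRepeatedChannel(repeats=2)(x).shape == (2, 2)

# Transpose uses Tensor.permute for torch inputs, np.transpose otherwise.
y = Transpose([1, 0])(torch.arange(6).reshape(2, 3))
assert y.shape == (3, 2)

# EnsureType converts between backends without changing values.
n = EnsureType(data_type="numpy")(torch.tensor([1.0, 2.0]))
assert isinstance(n, np.ndarray)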