From beab17dd107496d002da48afe32623623f1269aa Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:04:12 +0100 Subject: [PATCH 01/17] backends -> backend Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/intensity/array.py | 10 +-- monai/transforms/utility/array.py | 2 +- monai/transforms/utils.py | 95 +++++++++++++++++--------- tests/test_print_transform_backends.py | 8 ++- 5 files changed, 74 insertions(+), 42 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 4203724a7d..5267af4048 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -500,6 +500,7 @@ get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, + get_transform_backends, img_bounds, in_bounds, is_empty, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 46c512c96c..b36c7adf96 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -131,7 +131,7 @@ class RandRicianNoise(RandomizableTransform): uniformly from 0 to std. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, @@ -197,7 +197,7 @@ class ShiftIntensity(Transform): offset: offset value to shift the intensity of image. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offset: float) -> None: self.offset = offset @@ -219,7 +219,7 @@ class RandShiftIntensity(RandomizableTransform): Randomly shift intensity with randomly picked offset. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None: """ @@ -273,7 +273,7 @@ class StdShiftIntensity(Transform): dtype: output data type, defaults to float32. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32 @@ -318,7 +318,7 @@ class RandStdShiftIntensity(RandomizableTransform): by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3ef6413090..d56bca0d8d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -291,7 +291,7 @@ class CastToType(Transform): specified PyTorch data type. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, dtype=np.float32) -> None: """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e3e61b6c97..6281319973 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,6 +10,7 @@ # limitations under the License. 
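The renames above all follow one convention: each transform declares the input types its __call__ accepts through a class-level `backend` list (singular), which the utility functions below then introspect. A minimal sketch of that convention follows; it assumes `TransformBackends` is a string-valued enum with TORCH == "torch" and NUMPY == "numpy", inferred from the `obj.backend == ["torch", "numpy"]` comparison removed later in this patch, and `MyIdentity` is a made-up example class, not part of the change set.

    from monai.transforms import Transform
    from monai.utils.enums import TransformBackends

    class MyIdentity(Transform):
        # declares that __call__ handles both torch.Tensor and np.ndarray without conversion
        backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

        def __call__(self, img):
            return img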
import itertools +from monai.utils.enums import TransformBackends import random import warnings from contextlib import contextmanager @@ -81,6 +82,7 @@ "zero_margins", "equalize_hist", "get_number_image_type_conversions", + "get_transform_backends", "print_transform_backends", ] @@ -1157,23 +1159,17 @@ def _get_data(obj, key): num_conversions += 1 return num_conversions +def get_transform_backends() -> dict[str, List]: + """Get the backends of all MONAI transforms. -def print_transform_backends(): - """Prints a list of backends of all MONAI transforms.""" - - class Colours: - red = "91" - green = "92" - yellow = "93" - - def print_colour(t, colour): - print(f"\033[{colour}m{t}\033[00m") - - tr_total = 0 - tr_t_or_np = 0 - tr_t = 0 - tr_np = 0 - tr_uncategorised = 0 + Returns: + Dictionary, where each key is a transform, and its + corresponding values are a boolean list, stating + whether that transform supports (1) `torch.Tensor`, + and (2) `np.ndarray` as input without needing to + convert. + """ + backends = {} unique_transforms = [] for n, obj in getmembers(monai.transforms): # skip aliases @@ -1194,21 +1190,52 @@ def print_colour(t, colour): "InverteD", ]: continue - tr_total += 1 - if obj.backend == ["torch", "numpy"]: - tr_t_or_np += 1 - print_colour(f"TorchOrNumpy: {n}", Colours.green) - elif obj.backend == ["torch"]: - tr_t += 1 - print_colour(f"Torch: {n}", Colours.green) - elif obj.backend == ["numpy"]: - tr_np += 1 - print_colour(f"Numpy: {n}", Colours.yellow) - else: - tr_uncategorised += 1 - print_colour(f"Uncategorised: {n}", Colours.red) - print("Total number of transforms:", tr_total) - print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green) - print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) - print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) - print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) + + backends[n] = [ + TransformBackends.TORCH in obj.backend, + TransformBackends.NUMPY in obj.backend, + ] + return backends + + +def print_transform_backends(): + """Prints a list of backends of all MONAI transforms.""" + class Colors: + none = "" + red = "91" + green = "92" + yellow = "93" + + def print_color(t, color): + print(f"\033[{color}m{t}\033[00m") + + def print_table_column(name, torch, numpy, color=Colors.none): + print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color) + + backends = get_transform_backends() + n_total = len(backends) + n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0 + print_table_column("Transform", "Torch?", "Numpy?") + for k, v in backends.items(): + if all(v): + color = Colors.green + n_t_or_np += 1 + elif v[0]: + color = Colors.green + n_t += 1 + elif v[1]: + color = Colors.yellow + n_np += 1 + else: + color = Colors.red + n_uncategorized += 1 + print_table_column(k, *v, color) + + print("Total number of transforms:", n_total) + print_color(f"Number transforms allowing both torch and numpy: {n_t_or_np}", Colors.green) + print_color(f"Number of TorchTransform: {n_t}", Colors.green) + print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow) + print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) + +if __name__ == "__main__": + print_transform_backends() diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 09828f0a27..4164687f01 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -11,13 +11,17 @@ import unittest -from 
monai.transforms.utils import print_transform_backends +from monai.transforms.utils import get_transform_backends, print_transform_backends class TestPrintTransformBackends(unittest.TestCase): def test_get_number_of_conversions(self): + tr_t_or_np, *_ = get_transform_backends() + self.assertGreater(len(tr_t_or_np), 0) print_transform_backends() if __name__ == "__main__": - unittest.main() + # unittest.main() + a = TestPrintTransformBackends() + a.test_get_number_of_conversions() From c3192bad9b6585676a06966d57930790f89c4988 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:10:42 +0100 Subject: [PATCH 02/17] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6281319973..8bb4a38310 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,7 +10,6 @@ # limitations under the License. import itertools -from monai.utils.enums import TransformBackends import random import warnings from contextlib import contextmanager @@ -39,6 +38,7 @@ min_version, optional_import, ) +from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -1159,6 +1159,7 @@ def _get_data(obj, key): num_conversions += 1 return num_conversions + def get_transform_backends() -> dict[str, List]: """Get the backends of all MONAI transforms. @@ -1200,6 +1201,7 @@ def get_transform_backends() -> dict[str, List]: def print_transform_backends(): """Prints a list of backends of all MONAI transforms.""" + class Colors: none = "" red = "91" @@ -1237,5 +1239,6 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow) print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) + if __name__ == "__main__": print_transform_backends() From f790ad7acc151e33aabe88dd8ba43ab87dd057c0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:30:33 +0100 Subject: [PATCH 03/17] code format2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8bb4a38310..30aa5e7b99 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1160,7 +1160,7 @@ def _get_data(obj, key): return num_conversions -def get_transform_backends() -> dict[str, List]: +def get_transform_backends(): """Get the backends of all MONAI transforms. 
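For illustration, a minimal usage sketch of the new helper; the transform name and flag values in the comments are assumed examples rather than output from an actual run.

    from monai.transforms.utils import get_transform_backends

    backends = get_transform_backends()
    # each entry maps a transform name to [supports_torch, supports_numpy], e.g.
    # backends.get("ShiftIntensity") -> [True, True]   (hypothetical values)
    torch_only = [name for name, (torch_ok, numpy_ok) in backends.items() if torch_ok and not numpy_ok]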
Returns: From 1d2870481031e0dd7df79caf0daf7c3191f3cc13 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:55:22 +0100 Subject: [PATCH 04/17] AddChannel, AsChannelFirst, AsChannelLast, EnsureChannelFirst, Identity, RepeatChannel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 46 +++++++++++++++++++------- monai/transforms/utility/dictionary.py | 26 ++++++++++----- tests/test_add_channeld.py | 17 ++++++---- tests/test_as_channel_first.py | 17 +++++----- tests/test_as_channel_firstd.py | 21 ++++++------ tests/test_as_channel_last.py | 17 +++++----- tests/test_as_channel_lastd.py | 21 ++++++------ tests/test_ensure_channel_first.py | 8 +++-- tests/test_ensure_channel_firstd.py | 9 +++-- tests/test_identity.py | 11 +++--- tests/test_identityd.py | 13 ++++---- tests/test_repeat_channel.py | 8 +++-- tests/test_repeat_channeld.py | 17 ++++++---- 13 files changed, 142 insertions(+), 89 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d56bca0d8d..fc233b50cc 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -82,17 +82,18 @@ class Identity(Transform): """ - Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data. + Do nothing to the data. As the output value is same as input, it can be used as a testing tool to verify the transform chain, Compose or transform adaptor, etc. - """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asanyarray(img) + return img class AsChannelFirst(Transform): @@ -111,16 +112,23 @@ class AsChannelFirst(Transform): channel_dim: which dimension of input image is the channel, default is the last dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = -1) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, 0) + # old versions of pytorch don't have moveaxis + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, 0) + + img, *_ = convert_data_type(img, np.ndarray) + return np.moveaxis(img, self.channel_dim, 0) # type: ignore class AsChannelLast(Transform): @@ -138,16 +146,23 @@ class AsChannelLast(Transform): channel_dim: which dimension of input image is the channel, default is the first dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = 0) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. 
""" - return np.moveaxis(img, self.channel_dim, -1) + # old versions of pytorch don't have moveaxis + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, -1) + + img, *_ = convert_data_type(img, np.ndarray) + return np.moveaxis(img, self.channel_dim, -1) # type: ignore class AddChannel(Transform): @@ -164,7 +179,9 @@ class AddChannel(Transform): transforms. """ - def __call__(self, img: NdarrayTensor): + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -179,6 +196,8 @@ class EnsureChannelFirst(Transform): Convert the data to `channel_first` based on the `original_channel_dim` information. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, strict_check: bool = True): """ Args: @@ -186,7 +205,7 @@ def __init__(self, strict_check: bool = True): """ self.strict_check = strict_check - def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None): + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -220,16 +239,19 @@ class RepeatChannel(Transform): repeats: the number of repetitions for each element. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, repeats: int) -> None: if repeats <= 0: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - return np.repeat(img, self.repeats, 0) + repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeeat_fn(img, self.repeats, 0) # type: ignore class RemoveRepeatedChannel(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a53e4f3235..41c2a1b9b9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -169,6 +169,8 @@ class Identityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Identity`. """ + backend = Identity.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -180,9 +182,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -194,6 +194,8 @@ class AsChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`. 
""" + backend = AsChannelFirst.backend + def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: """ Args: @@ -205,7 +207,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -217,6 +219,8 @@ class AsChannelLastd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`. """ + backend = AsChannelLast.backend + def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -228,7 +232,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -240,6 +244,8 @@ class AddChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`. """ + backend = AddChannel.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -250,7 +256,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -262,6 +268,8 @@ class EnsureChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. """ + backend = EnsureChannelFirst.backend + def __init__( self, keys: KeysCollection, @@ -289,7 +297,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"]) @@ -301,6 +309,8 @@ class RepeatChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. 
""" + backend = RepeatChannel.backend + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: @@ -312,7 +322,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index ca4af37271..8bdd89a4ae 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import AddChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"]}, - {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[0, 1], [1, 2]])}, - (1, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"]}, + {"img": p(np.array([[0, 1], [1, 2]])), "seg": p(np.array([[0, 1], [1, 2]]))}, + (1, 2, 2), + ] + ) class TestAddChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = AddChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index e7d9866ae1..bc9158f277 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelFirst +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirst(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index e70c2e1b47..68d33434c1 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelFirstd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], 
"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 6ec6c8d6e6..55a7a08676 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelLast +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLast(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelLast(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 2ef4dd4da1..350f639f3f 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelLastd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLastd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelLastd(**input_param)(test_data) 
self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 6b9def1cea..23126d326f 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -21,6 +21,7 @@ from monai.data import ITKReader from monai.transforms import EnsureChannelFirst, LoadImage +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] @@ -61,9 +62,10 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result, header = LoadImage(**input_param)(filenames) - result = EnsureChannelFirst()(result, header) - self.assertEqual(result.shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(p(result), header) + self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 59eb32c576..b4cde02a8f 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -19,6 +19,7 @@ from PIL import Image from monai.transforms import EnsureChannelFirstd, LoadImaged +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] @@ -43,9 +44,11 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadImaged(**input_param)({"img": filenames}) - result = EnsureChannelFirstd(**input_param)(result) - self.assertEqual(result["img"].shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result = LoadImaged(**input_param)({"img": filenames}) + result["img"] = p(result["img"]) + result = EnsureChannelFirstd(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) def test_load_png(self): spatial_size = (256, 256, 3) diff --git a/tests/test_identity.py b/tests/test_identity.py index 2dff2bb13d..172860668c 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -11,17 +11,16 @@ import unittest -import numpy as np - from monai.transforms.utility.array import Identity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentity(NumpyImageTestCase2D): def test_identity(self): - img = self.imt - identity = Identity() - self.assertTrue(np.allclose(img, identity(img))) + for p in TEST_NDARRAYS: + img = p(self.imt) + identity = Identity() + assert_allclose(img, identity(img)) if __name__ == "__main__": diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 8796f28da8..665b7d5d1c 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -12,16 +12,17 @@ import unittest from monai.transforms.utility.dictionary import Identityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentityd(NumpyImageTestCase2D): def test_identityd(self): - img = self.imt - data = {} - data["img"] = img - identity = Identityd(keys=data.keys()) - self.assertEqual(data, identity(data)) + for p in 
TEST_NDARRAYS: + img = p(self.imt) + data = {} + data["img"] = img + identity = Identityd(keys=data.keys()) + assert_allclose(img, identity(data)["img"]) if __name__ == "__main__": diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 643ebc64de..e246dd1212 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -11,16 +11,18 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import RepeatChannel +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"repeats": 3}, np.array([[[0, 1], [1, 2]]]), (3, 2, 2)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"repeats": 3}, p([[[0, 1], [1, 2]]]), (3, 2, 2)]) class TestRepeatChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 7bd58bd1fe..3b73962bb9 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import RepeatChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img"], "repeats": 3}, - {"img": np.array([[[0, 1], [1, 2]]]), "seg": np.array([[[0, 1], [1, 2]]])}, - (3, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "repeats": 3}, + {"img": p(np.array([[[0, 1], [1, 2]]])), "seg": p(np.array([[[0, 1], [1, 2]]]))}, + (3, 2, 2), + ] + ) class TestRepeatChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) From be7eac206da7329ddffcdc585ed8b68499d63858 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 15:59:31 +0100 Subject: [PATCH 05/17] moveaxis backwards compatible Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 ++ monai/transforms/utility/array.py | 18 +++------ .../utils_pytorch_numpy_unification.py | 37 +++++++++++++++++++ tests/test_as_channel_first.py | 10 ++++- 4 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 monai/transforms/utils_pytorch_numpy_unification.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5267af4048..b648b3331b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -517,3 +517,7 @@ weighted_patch_samples, zero_margins, ) + +from .utils_pytorch_numpy_unification import ( + moveaxis, +) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fc233b50cc..58d1bec34d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,6 +31,9 @@ map_binary_to_indices, map_classes_to_indices, ) +from monai.transforms.utils_pytorch_numpy_unification import ( + moveaxis, +) from monai.utils import ( convert_to_numpy, convert_to_tensor, @@ -123,12 +126,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. 
""" - # old versions of pytorch don't have moveaxis - if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): - return torch.moveaxis(img, self.channel_dim, 0) - - img, *_ = convert_data_type(img, np.ndarray) - return np.moveaxis(img, self.channel_dim, 0) # type: ignore + return moveaxis(img, self.channel_dim, 0) class AsChannelLast(Transform): @@ -157,13 +155,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - # old versions of pytorch don't have moveaxis - if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): - return torch.moveaxis(img, self.channel_dim, -1) - - img, *_ = convert_data_type(img, np.ndarray) - return np.moveaxis(img, self.channel_dim, -1) # type: ignore - + return moveaxis(img, self.channel_dim, -1) class AddChannel(Transform): """ diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..6eab76678c --- /dev/null +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +from monai.config.type_definitions import NdarrayOrTensor + + +__all__ = [ + "moveaxis", +] + + +def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: + if isinstance(x, torch.Tensor): + if hasattr(torch, "moveaxis"): + return torch.moveaxis(x, src, dst) + # moveaxis only available in pytorch as of 1.8.0 + else: + # get original indices, remove desired index and insert it in new position + indices = list(range(x.ndim)) + indices.pop(src) + indices.insert(dst, src) + return x.permute(indices) + elif isinstance(x, np.ndarray): + return np.moveaxis(x, src, dst) + raise RuntimeError() + diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index bc9158f277..ffa66c5ec5 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -13,9 +13,10 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import AsChannelFirst -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] for p in TEST_NDARRAYS: @@ -26,10 +27,15 @@ class TestAsChannelFirst(unittest.TestCase): @parameterized.expand(TESTS) - def test_shape(self, in_type, input_param, expected_shape): + def test_value(self, in_type, input_param, expected_shape): test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) + if isinstance(test_data, torch.Tensor): + test_data = test_data.cpu().numpy() + expected = np.moveaxis(test_data, input_param["channel_dim"], 0) + assert_allclose(expected, result) + if __name__ == "__main__": From 0d27527cbae154cb9c37dc9f06889e38c2f75a24 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 16:39:24 +0100 
Subject: [PATCH 06/17] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 5 +---- monai/transforms/utility/array.py | 5 ++--- monai/transforms/utils_pytorch_numpy_unification.py | 4 +--- tests/test_as_channel_first.py | 3 +-- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b648b3331b..28379aadd8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -517,7 +517,4 @@ weighted_patch_samples, zero_margins, ) - -from .utils_pytorch_numpy_unification import ( - moveaxis, -) +from .utils_pytorch_numpy_unification import moveaxis diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 58d1bec34d..580c6c8b3c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,9 +31,7 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import ( - moveaxis, -) +from monai.transforms.utils_pytorch_numpy_unification import moveaxis from monai.utils import ( convert_to_numpy, convert_to_tensor, @@ -157,6 +155,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ return moveaxis(img, self.channel_dim, -1) + class AddChannel(Transform): """ Adds a 1-length channel dimension to the input image. diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 6eab76678c..dd9e54fa45 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -9,12 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch import numpy as np +import torch from monai.config.type_definitions import NdarrayOrTensor - __all__ = [ "moveaxis", ] @@ -34,4 +33,3 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: elif isinstance(x, np.ndarray): return np.moveaxis(x, src, dst) raise RuntimeError() - diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index ffa66c5ec5..0d1b1c7d3a 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -12,8 +12,8 @@ import unittest import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import AsChannelFirst from tests.utils import TEST_NDARRAYS, assert_allclose @@ -37,6 +37,5 @@ def test_value(self, in_type, input_param, expected_shape): assert_allclose(expected, result) - if __name__ == "__main__": unittest.main() From c9dcd8da725d08ee382fa18d01fc3e21d65c5633 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 26 Aug 2021 12:23:35 +0100 Subject: [PATCH 07/17] EnsureType, RemoveRepeatedChannel, SplitChannel, ToCupy, ToNumpy, ToPil, ToTensor, Transpose Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 67 +++++++++++------------ monai/transforms/utility/dictionary.py | 33 ++++++++---- monai/utils/type_conversion.py | 21 +++++--- tests/test_ensure_type.py | 15 ++++-- tests/test_ensure_typed.py | 15 ++++-- tests/test_remove_repeated_channel.py | 7 ++- tests/test_remove_repeated_channeld.py | 22 +++++--- tests/test_split_channel.py | 17 +++--- tests/test_split_channeld.py | 74 +++++++++++++++----------- tests/test_to_cupy.py | 12 +++++ tests/test_to_cupyd.py | 12 +++++ tests/test_to_numpy.py | 23 +++++--- tests/test_to_numpyd.py | 17 ++++-- tests/test_to_pil.py | 33 ++++-------- tests/test_to_pild.py | 34 ++++-------- tests/test_to_tensor.py | 38 ++++++++----- tests/test_transpose.py | 33 +++++++----- tests/test_transposed.py | 55 +++++++++++-------- 18 files changed, 315 insertions(+), 213 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 580c6c8b3c..f38a94302e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,15 +32,7 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import moveaxis -from monai.utils import ( - convert_to_numpy, - convert_to_tensor, - ensure_tuple, - issequenceiterable, - look_up_option, - min_version, - optional_import, -) +from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type @@ -255,20 +247,22 @@ class RemoveRepeatedChannel(Transform): repeats: the number of repetitions to be deleted for each element. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, repeats: int) -> None: if repeats <= 0: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. 
""" - if np.shape(img)[0] < 2: + if img.shape[0] < 2: raise AssertionError("Image must have more than one channel") - return np.array(img[:: self.repeats, :]) + return img[:: self.repeats, :] class SplitChannel(Transform): @@ -281,10 +275,12 @@ class SplitChannel(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = 0) -> None: self.channel_dim = channel_dim - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]: + def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: n_classes = img.shape[self.channel_dim] if n_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") @@ -335,18 +331,13 @@ class ToTensor(Transform): Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img) -> torch.Tensor: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - return img.contiguous() - if issequenceiterable(img): - # numpy array with 0 dims is also sequence iterable - if not (isinstance(img, np.ndarray) and img.ndim == 0): - # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - img = np.ascontiguousarray(img) - return torch.as_tensor(img) + return convert_to_tensor(img, wrap_sequence=True) # type: ignore class EnsureType(Transform): @@ -361,6 +352,8 @@ class EnsureType(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, data_type: str = "tensor") -> None: data_type = data_type.lower() if data_type not in ("tensor", "numpy"): @@ -368,7 +361,7 @@ def __init__(self, data_type: str = "tensor") -> None: self.data_type = data_type - def __call__(self, data): + def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. @@ -377,7 +370,7 @@ def __call__(self, data): if applicable. """ - return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) + return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore class ToNumpy(Transform): @@ -385,17 +378,13 @@ class ToNumpy(Transform): Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. """ - def __call__(self, img) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - elif has_cp and isinstance(img, cp_ndarray): - img = cp.asnumpy(img) - - array: np.ndarray = np.asarray(img) - return np.ascontiguousarray(array) if array.ndim > 0 else array + return convert_to_numpy(img) # type: ignore class ToCupy(Transform): @@ -403,13 +392,15 @@ class ToCupy(Transform): Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. """ - def __call__(self, img): + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img` and make it contiguous. 
""" if isinstance(img, torch.Tensor): img = img.detach().cpu().numpy() - return cp.ascontiguousarray(cp.asarray(img)) + return cp.ascontiguousarray(cp.asarray(img)) # type: ignore class ToPIL(Transform): @@ -417,6 +408,8 @@ class ToPIL(Transform): Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __call__(self, img): """ Apply the transform to `img`. @@ -433,13 +426,17 @@ class Transpose(Transform): Transposes the input image based on the given `indices` dimension ordering. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, indices: Optional[Sequence[int]]) -> None: self.indices = None if indices is None else tuple(indices) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ + if isinstance(img, torch.Tensor): + return img.permute(self.indices or tuple(range(img.ndim)[::-1])) return img.transpose(self.indices) # type: ignore diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 41c2a1b9b9..1b63b308d9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -334,6 +334,8 @@ class RemoveRepeatedChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`. """ + backend = RemoveRepeatedChannel.backend + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: @@ -345,7 +347,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -356,9 +358,10 @@ class SplitChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. All the input specified by `keys` should be split into same count of data. - """ + backend = SplitChannel.backend + def __init__( self, keys: KeysCollection, @@ -382,9 +385,7 @@ def __init__( self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): rets = self.splitter(d[key]) @@ -439,6 +440,8 @@ class ToTensord(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`. 
""" + backend = ToTensor.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -449,14 +452,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.converter(d[key]) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): # Create inverse transform @@ -481,6 +484,8 @@ class EnsureTyped(MapTransform, InvertibleTransform): """ + backend = EnsureType.backend + def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None: """ Args: @@ -492,7 +497,7 @@ def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missin super().__init__(keys, allow_missing_keys) self.converter = EnsureType(data_type=data_type) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) @@ -515,6 +520,8 @@ class ToNumpyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ + backend = ToNumpy.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -537,6 +544,8 @@ class ToCupyd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`. """ + backend = ToCupy.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -547,7 +556,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToCupy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -559,6 +568,8 @@ class ToPILd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`. """ + backend = ToPIL.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -581,13 +592,15 @@ class Transposed(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Transpose`. 
""" + backend = Transpose.backend + def __init__( self, keys: KeysCollection, indices: Optional[Sequence[int]], allow_missing_keys: bool = False ) -> None: super().__init__(keys, allow_missing_keys) self.transform = Transpose(indices) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e6df607764..14300eeca0 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -83,7 +83,7 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data): +def convert_to_tensor(data, wrap_sequence: bool = False): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -92,6 +92,8 @@ def convert_to_tensor(data): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. + If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ if isinstance(data, torch.Tensor): @@ -105,17 +107,19 @@ def convert_to_tensor(data): return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) elif isinstance(data, (float, int, bool)): return torch.as_tensor(data) - elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + elif isinstance(data, Sequence) and wrap_sequence: + return torch.as_tensor(data) elif isinstance(data, list): return [convert_to_tensor(i) for i in data] elif isinstance(data, tuple): return tuple(convert_to_tensor(i) for i in data) + elif isinstance(data, dict): + return {k: convert_to_tensor(v) for k, v in data.items()} return data -def convert_to_numpy(data): +def convert_to_numpy(data, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -124,7 +128,8 @@ def convert_to_numpy(data): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. - + wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. + If `True`, then `[1, 2]` -> `array([1, 2])`. 
""" if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() @@ -132,12 +137,14 @@ def convert_to_numpy(data): data = cp.asnumpy(data) elif isinstance(data, (float, int, bool)): data = np.asarray(data) - elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + elif isinstance(data, Sequence) and wrap_sequence: + return np.asarray(data) elif isinstance(data, list): return [convert_to_numpy(i) for i in data] elif isinstance(data, tuple): return tuple(convert_to_numpy(i) for i in data) + elif isinstance(data, dict): + return {k: convert_to_numpy(v) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 11cf6760fb..8feb96ed37 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -15,26 +15,33 @@ import torch from monai.transforms import EnsureType +from tests.utils import assert_allclose class TestEnsureType(unittest.TestCase): def test_array_input(self): - for test_data in (np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])): + test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "NUMPY"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "numpy"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) def test_string(self): diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index c5f588d423..96f482afc2 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -15,26 +15,33 @@ import torch from monai.transforms import EnsureTyped +from tests.utils import assert_allclose class TestEnsureTyped(unittest.TestCase): def test_array_input(self): - for test_data in (np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])): + test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "NUMPY"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)] + if torch.cuda.is_available(): + test_datas.append(test_datas[-1].cuda()) + for test_data in test_datas: for dtype in ("tensor", "numpy"): result = 
EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) if isinstance(test_data, bool): self.assertFalse(result) else: - torch.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) def test_string(self): diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index 070e0e2b8d..ebbe6c730c 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -12,15 +12,18 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RemoveRepeatedChannel -TEST_CASE_1 = [{"repeats": 2}, np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] +TEST_CASES = [] +for q in (torch.Tensor, np.array): + TEST_CASES.append([{"repeats": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)]) # type: ignore class TestRemoveRepeatedChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py index 46c68bbdc2..9d4812791e 100644 --- a/tests/test_remove_repeated_channeld.py +++ b/tests/test_remove_repeated_channeld.py @@ -15,16 +15,24 @@ from parameterized import parameterized from monai.transforms import RemoveRepeatedChanneld - -TEST_CASE_1 = [ - {"keys": ["img"], "repeats": 2}, - {"img": np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), "seg": np.array([[1, 2], [1, 2], [3, 4], [3, 4]])}, - (2, 2), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "repeats": 2}, + { + "img": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])), + "seg": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])), + }, + (2, 2), + ] + ) class TestRemoveRepeatedChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 91e93aedcc..38315a102c 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -12,22 +12,21 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SplitChannel +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 1}, torch.randint(0, 2, size=(4, 3, 3, 4)), (4, 1, 3, 4)] - -TEST_CASE_2 = [{"channel_dim": 0}, np.random.randint(2, size=(3, 3, 4)), (1, 3, 4)] - -TEST_CASE_3 = [{"channel_dim": 2}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] - -TEST_CASE_4 = [{"channel_dim": -1}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"channel_dim": 1}, p(np.random.randint(2, size=(4, 3, 3, 4))), (4, 1, 3, 4)]) + TESTS.append([{"channel_dim": 0}, p(np.random.randint(2, size=(3, 3, 4))), (1, 3, 4)]) + TESTS.append([{"channel_dim": 2}, p(np.random.randint(2, size=(3, 2, 4))), (3, 2, 1)]) + TESTS.append([{"channel_dim": -1}, p(np.random.randint(2, size=(3, 2, 4))), (3, 2, 1)]) class TestSplitChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, 
TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SplitChannel(**input_param)(test_data) for data in result: diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index 57c7099b9f..f1df24364d 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -12,44 +12,56 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SplitChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 1}, - {"pred": torch.randint(0, 2, size=(4, 3, 3, 4))}, - (4, 1, 3, 4), -] - -TEST_CASE_2 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 0}, - {"pred": np.random.randint(2, size=(3, 3, 4))}, - (1, 3, 4), -] - -TEST_CASE_3 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": 2}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 2, 1), -] - -TEST_CASE_4 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": -1}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 2, 1), -] - -TEST_CASE_5 = [ - {"keys": "pred", "channel_dim": 1}, - {"pred": np.random.randint(2, size=(3, 2, 4))}, - (3, 1, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 1}, + {"pred": p(np.random.randint(2, size=(4, 3, 3, 4)))}, + (4, 1, 3, 4), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 0}, + {"pred": p(np.random.randint(2, size=(3, 3, 4)))}, + (1, 3, 4), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": 2}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 2, 1), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": -1}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 2, 1), + ] + ) + + TESTS.append( + [ + {"keys": "pred", "channel_dim": 1}, + {"pred": p(np.random.randint(2, size=(3, 2, 4)))}, + (3, 1, 4), + ] + ) class TestSplitChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SplitChanneld(**input_param)(test_data) for k, v in result.items(): diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 76c9464b20..a9460bc825 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -17,6 +17,7 @@ from monai.transforms import ToCupy from monai.utils import optional_import +from tests.utils import skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -52,6 +53,17 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) + @skipUnless(has_cp, "CuPy is required.") + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToCupy()(test_data) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data.cpu().numpy()) + @skipUnless(has_cp, "CuPy is required.") def 
test_list_tuple(self): test_data = [[1, 2], [3, 4]] diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index b869bedc96..2f3c42dd1f 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -17,6 +17,7 @@ from monai.transforms import ToCupyd from monai.utils import optional_import +from tests.utils import skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -52,6 +53,17 @@ def test_tensor_input(self): self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data.numpy()) + @skipUnless(has_cp, "CuPy is required.") + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToCupyd(keys="img")({"img": test_data})["img"] + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + cp.testing.assert_allclose(result, test_data.cpu().numpy()) + @skipUnless(has_cp, "CuPy is required.") def test_list_tuple(self): test_data = [[1, 2], [3, 4]] diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index 291601ffeb..fd49a3d473 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -17,6 +17,7 @@ from monai.transforms import ToNumpy from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -30,7 +31,7 @@ def test_cumpy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get()) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -39,7 +40,7 @@ def test_numpy_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -48,21 +49,31 @@ def test_tensor_input(self): result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.numpy()) + assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToNumpy()(test_data) + self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + assert_allclose(result, test_data) def test_list_tuple(self): test_data = [[1, 2], [3, 4]] result = ToNumpy()(test_data) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) test_data = ((1, 2), (3, 4)) result = ToNumpy()(test_data) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: result = ToNumpy()(test_data) self.assertTrue(isinstance(result, np.ndarray)) - np.testing.assert_allclose(result, np.asarray(test_data)) + assert_allclose(result, np.asarray(test_data)) self.assertEqual(result.ndim, 0) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index 1fb43ea2ac..adfab65904 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -17,6 +17,7 @@ from 
monai.transforms import ToNumpyd from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_no_cuda cp, has_cp = optional_import("cupy") @@ -30,7 +31,7 @@ def test_cumpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.get()) + assert_allclose(result, test_data.get()) def test_numpy_input(self): test_data = np.array([[1, 2], [3, 4]]) @@ -39,7 +40,7 @@ def test_numpy_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data) + assert_allclose(result, test_data) def test_tensor_input(self): test_data = torch.tensor([[1, 2], [3, 4]]) @@ -48,7 +49,17 @@ def test_tensor_input(self): result = ToNumpyd(keys="img")({"img": test_data})["img"] self.assertTrue(isinstance(result, np.ndarray)) self.assertTrue(result.flags["C_CONTIGUOUS"]) - np.testing.assert_allclose(result, test_data.numpy()) + assert_allclose(result, test_data) + + @skip_if_no_cuda + def test_tensor_cuda_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]).cuda() + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToNumpyd(keys="img")({"img": test_data})["img"] + self.assertTrue(isinstance(result, np.ndarray)) + self.assertTrue(result.flags["C_CONTIGUOUS"]) + assert_allclose(result, test_data) if __name__ == "__main__": diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index ec63750ce4..f6e95a5834 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -14,11 +14,11 @@ from unittest import skipUnless import numpy as np -import torch from parameterized import parameterized from monai.transforms import ToPIL from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose if TYPE_CHECKING: from PIL.Image import Image as PILImageImage @@ -29,35 +29,20 @@ pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") PILImageImage, _ = optional_import("PIL.Image", name="Image") -TEST_CASE_ARRAY_1 = [np.array([[1.0, 2.0], [3.0, 4.0]])] -TEST_CASE_TENSOR_1 = [torch.tensor([[1.0, 2.0], [3.0, 4.0]])] +im = [[1.0, 2.0], [3.0, 4.0]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p(im)]) +TESTS.append([pil_image_fromarray(np.array(im))]) class TestToPIL(unittest.TestCase): - @parameterized.expand([TEST_CASE_ARRAY_1]) + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") - def test_numpy_input(self, test_data): - self.assertTrue(isinstance(test_data, np.ndarray)) + def test_value(self, test_data): result = ToPIL()(test_data) self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data) - - @parameterized.expand([TEST_CASE_TENSOR_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_tensor_input(self, test_data): - self.assertTrue(isinstance(test_data, torch.Tensor)) - result = ToPIL()(test_data) - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data.numpy()) - - @parameterized.expand([TEST_CASE_ARRAY_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_pil_input(self, test_data): - test_data_pil = pil_image_fromarray(test_data) - self.assertTrue(isinstance(test_data_pil, PILImageImage)) - result = ToPIL()(test_data_pil) - 
self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data) + assert_allclose(np.array(result), test_data) if __name__ == "__main__": diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 43778022ee..2fb9358b1d 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -14,11 +14,11 @@ from unittest import skipUnless import numpy as np -import torch from parameterized import parameterized from monai.transforms import ToPILd from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose if TYPE_CHECKING: from PIL.Image import Image as PILImageImage @@ -29,36 +29,20 @@ pil_image_fromarray, has_pil = optional_import("PIL.Image", name="fromarray") PILImageImage, _ = optional_import("PIL.Image", name="Image") -TEST_CASE_ARRAY_1 = [{"keys": "image"}, {"image": np.array([[1.0, 2.0], [3.0, 4.0]])}] -TEST_CASE__TENSOR_1 = [{"keys": "image"}, {"image": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}] +im = [[1.0, 2.0], [3.0, 4.0]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"keys": "image"}, {"image": p(im)}]) +TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) class TestToPIL(unittest.TestCase): - @parameterized.expand([TEST_CASE_ARRAY_1]) + @parameterized.expand(TESTS) @skipUnless(has_pil, "Requires `pillow` package.") - def test_numpy_input(self, input_param, test_data): - self.assertTrue(isinstance(test_data[input_param["keys"]], np.ndarray)) + def test_values(self, input_param, test_data): result = ToPILd(**input_param)(test_data)[input_param["keys"]] self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) - - @parameterized.expand([TEST_CASE__TENSOR_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_tensor_input(self, input_param, test_data): - self.assertTrue(isinstance(test_data[input_param["keys"]], torch.Tensor)) - result = ToPILd(**input_param)(test_data)[input_param["keys"]] - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]].numpy()) - - @parameterized.expand([TEST_CASE_ARRAY_1]) - @skipUnless(has_pil, "Requires `pillow` package.") - def test_pil_input(self, input_param, test_data): - input_array = test_data[input_param["keys"]] - test_data[input_param["keys"]] = pil_image_fromarray(input_array) - self.assertTrue(isinstance(test_data[input_param["keys"]], PILImageImage)) - result = ToPILd(**input_param)(test_data)[input_param["keys"]] - self.assertTrue(isinstance(result, PILImageImage)) - np.testing.assert_allclose(np.array(result), test_data[input_param["keys"]]) + assert_allclose(np.array(result), test_data[input_param["keys"]]) if __name__ == "__main__": diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 4a36254743..6ac06983f6 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -11,24 +11,36 @@ import unittest -import numpy as np -import torch +from parameterized import parameterized from monai.transforms import ToTensor +from tests.utils import TEST_NDARRAYS, assert_allclose + +im = [[1, 2], [3, 4]] + +TESTS = [] +TESTS.append((im, (2, 2))) +for p in TEST_NDARRAYS: + TESTS.append((p(im), (2, 2))) + +TESTS_SINGLE = [] +TESTS_SINGLE.append([5]) +for p in TEST_NDARRAYS: + TESTS_SINGLE.append([p(5)]) class TestToTensor(unittest.TestCase): - def test_array_input(self): - for test_data in ([[1, 2], [3, 4]], np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 
2], [3, 4]])): - result = ToTensor()(test_data) - torch.testing.assert_allclose(result, test_data) - self.assertTupleEqual(result.shape, (2, 2)) - - def test_single_input(self): - for test_data in (5, np.asarray(5), torch.tensor(5)): - result = ToTensor()(test_data) - torch.testing.assert_allclose(result, test_data) - self.assertEqual(result.ndim, 0) + @parameterized.expand(TESTS) + def test_array_input(self, test_data, expected_shape): + result = ToTensor()(test_data) + assert_allclose(result, test_data) + self.assertTupleEqual(result.shape, expected_shape) + + @parameterized.expand(TESTS_SINGLE) + def test_single_input(self, test_data): + result = ToTensor()(test_data) + assert_allclose(result, test_data) + self.assertEqual(result.ndim, 0) if __name__ == "__main__": diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 3b758b5aa2..10882c9dd8 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -12,28 +12,37 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transpose - -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_1 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): tr = Transpose(indices) out1 = tr(im) + if isinstance(im, torch.Tensor): + im = im.cpu().numpy() out2 = np.transpose(im, indices) - np.testing.assert_array_equal(out1, out2) + assert_allclose(out1, out2) if __name__ == "__main__": diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 56375f3981..88ecd0c872 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -13,44 +13,57 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transposed +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - [1, 0], -] -TEST_CASE_1 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_2 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASE_3 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - None, -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + [1, 0], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + None, + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): data = {"i": deepcopy(im), "j": deepcopy(im)} tr = Transposed(["i", "j"], indices) out_data = tr(data) out_im1, out_im2 = out_data["i"], out_data["j"] + if isinstance(im, torch.Tensor): + im = im.cpu().numpy() out_gt = np.transpose(im, indices) - np.testing.assert_array_equal(out_im1, out_gt) - np.testing.assert_array_equal(out_im2, out_gt) + assert_allclose(out_im1, out_gt) + assert_allclose(out_im2, out_gt) # test inverse fwd_inv_data = 
tr.inverse(out_data) for i, j in zip(data.values(), fwd_inv_data.values()): - np.testing.assert_array_equal(i, j) + assert_allclose(i, j) if __name__ == "__main__": From 754a6843ff6a99ffe9e411e68caf285fb5f034eb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 26 Aug 2021 12:25:46 +0100 Subject: [PATCH 08/17] trigger ci/cd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> From e38e7a3fc048b7c27d7b93c45c3b40b150062e9e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:42:55 +0100 Subject: [PATCH 09/17] permute requires positive indices Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index dd9e54fa45..bb16369cf0 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -25,8 +25,14 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: return torch.moveaxis(x, src, dst) # moveaxis only available in pytorch as of 1.8.0 else: - # get original indices, remove desired index and insert it in new position + # get original indices indices = list(range(x.ndim)) + # make src and dst positive + if src < 0: + src = len(indices) - src + if dst < 0: + dst = len(indices) - dst + # remove desired index and insert it in new position indices.pop(src) indices.insert(dst, src) return x.permute(indices) From af2d2eccfc471544ef4b43a833f88db536f2beec Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 10:41:09 +0100 Subject: [PATCH 10/17] correct permute Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index bb16369cf0..e6dc151596 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -29,9 +29,9 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: indices = list(range(x.ndim)) # make src and dst positive if src < 0: - src = len(indices) - src + src = len(indices) + src if dst < 0: - dst = len(indices) - dst + dst = len(indices) + dst # remove desired index and insert it in new position indices.pop(src) indices.insert(dst, src) From b1e476d267f4a8e22f7fa4cff0fc2b1fbc93662e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 10:41:09 +0100 Subject: [PATCH 11/17] correct permute Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index bb16369cf0..e6dc151596 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -29,9 +29,9 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: indices = list(range(x.ndim)) # make src and dst positive if 
src < 0: - src = len(indices) - src + src = len(indices) + src if dst < 0: - dst = len(indices) - dst + dst = len(indices) + dst # remove desired index and insert it in new position indices.pop(src) indices.insert(dst, src) From 3a9e170eb91b57067411dfd94b39b4a7ea570107 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 14:33:59 +0100 Subject: [PATCH 12/17] has_pil Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_to_pil.py | 3 ++- tests/test_to_pild.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py index f6e95a5834..5690645dd8 100644 --- a/tests/test_to_pil.py +++ b/tests/test_to_pil.py @@ -33,7 +33,8 @@ TESTS = [] for p in TEST_NDARRAYS: TESTS.append([p(im)]) -TESTS.append([pil_image_fromarray(np.array(im))]) +if has_pil: + TESTS.append([pil_image_fromarray(np.array(im))]) class TestToPIL(unittest.TestCase): diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py index 2fb9358b1d..3a15b1e507 100644 --- a/tests/test_to_pild.py +++ b/tests/test_to_pild.py @@ -33,7 +33,8 @@ TESTS = [] for p in TEST_NDARRAYS: TESTS.append([{"keys": "image"}, {"image": p(im)}]) -TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) +if has_pil: + TESTS.append([{"keys": "image"}, {"image": pil_image_fromarray(np.array(im))}]) class TestToPIL(unittest.TestCase): From 162e7d5a4176a394a352643a456d06bdd68d469e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 14:52:53 +0000 Subject: [PATCH 13/17] DataStats, LabelToMask, Lambda, RandLambda, SqueezeDim, is_module_ver_at_least Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/utility/array.py | 85 +++++++++++------ monai/transforms/utility/dictionary.py | 50 ++++++---- .../utils_pytorch_numpy_unification.py | 10 ++ monai/utils/__init__.py | 1 + monai/utils/misc.py | 16 +++- tests/test_data_stats.py | 2 +- tests/test_data_statsd.py | 2 +- tests/test_label_to_mask.py | 75 ++++++++------- tests/test_label_to_maskd.py | 80 +++++++++------- tests/test_lambda.py | 26 +++--- tests/test_lambdad.py | 39 ++++---- tests/test_squeezedim.py | 28 +++--- tests/test_squeezedimd.py | 92 +++++++++++-------- tests/utils.py | 4 +- 15 files changed, 312 insertions(+), 200 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b0ba1e39d9..2ea7e3aa63 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,4 +518,4 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import moveaxis +from .utils_pytorch_numpy_unification import in1d, moveaxis diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f38a94302e..a09d7ff641 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,7 @@ import numpy as np import torch -from monai.config import DtypeLike, NdarrayTensor +from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( @@ -31,9 +31,10 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis from 
monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import from monai.utils.enums import TransformBackends +from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_data_type PILImageImage, has_pil = optional_import("PIL.Image", name="Image") @@ -445,6 +446,8 @@ class SqueezeDim(Transform): Squeeze a unitary dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, dim: Optional[int] = 0) -> None: """ Args: @@ -459,12 +462,17 @@ def __init__(self, dim: Optional[int] = 0) -> None: raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.") self.dim = dim - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: numpy arrays with required dimension `dim` removed """ - return img.squeeze(self.dim) # type: ignore + if self.dim is None: + return img.squeeze() + # for pytorch/numpy unification + if img.shape[self.dim] != 1: + raise ValueError("Can only squeeze singleton dimension") + return img.squeeze(self.dim) class DataStats(Transform): @@ -475,6 +483,8 @@ class DataStats(Transform): so it can be used in pre-processing and post-processing. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, prefix: str = "Data", @@ -523,14 +533,14 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, prefix: Optional[str] = None, data_type: Optional[bool] = None, data_shape: Optional[bool] = None, value_range: Optional[bool] = None, data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Apply the transform to `img`, optionally take arguments similar to the class constructor. """ @@ -570,6 +580,8 @@ class SimulateDelay(Transform): to sub-optimal design choices. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, delay_time: float = 0.0) -> None: """ Args: @@ -579,7 +591,7 @@ def __init__(self, delay_time: float = 0.0) -> None: super().__init__() self.delay_time: float = delay_time - def __call__(self, img: NdarrayTensor, delay_time: Optional[float] = None) -> NdarrayTensor: + def __call__(self, img: NdarrayOrTensor, delay_time: Optional[float] = None) -> NdarrayOrTensor: """ Args: img: data remain unchanged throughout this transform. @@ -612,12 +624,14 @@ class Lambda(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, func: Optional[Callable] = None) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): """ Apply `self.func` to `img`. @@ -648,14 +662,15 @@ class RandLambda(Lambda, RandomizableTransform): prob: probability of executing the random function, default to 1.0, with 100% probability to execute. For more details, please check :py:class:`monai.transforms.Lambda`. 
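    A minimal usage sketch (illustrative, not part of the patch; the lambdas and shapes are
    made up) of what the `NdarrayOrTensor` annotations above imply: the transform accepts
    either backend and returns the same container type it was given:

        import numpy as np
        import torch
        from monai.transforms import Lambda, RandLambda

        first_channel = Lambda(func=lambda x: x[:1])
        out_np = first_channel(np.zeros((3, 4, 4)))        # numpy in  -> numpy out
        out_pt = first_channel(torch.zeros(3, 4, 4))       # torch in  -> torch out
        rand = RandLambda(func=lambda x: x + 1, prob=0.5)  # applied with probability 0.5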
- """ + backend = Lambda.backend + def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None: Lambda.__init__(self=self, func=func) RandomizableTransform.__init__(self=self, prob=prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None): self.randomize(img) return super().__call__(img=img, func=func) if self._do_transform else img @@ -679,6 +694,8 @@ class LabelToMask(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( # pytype: disable=annotation-type-mismatch self, select_labels: Union[Sequence[int], int], @@ -688,8 +705,11 @@ def __init__( # pytype: disable=annotation-type-mismatch self.merge_channels = merge_channels def __call__( - self, img: np.ndarray, select_labels: Optional[Union[Sequence[int], int]] = None, merge_channels: bool = False - ): + self, + img: NdarrayOrTensor, + select_labels: Optional[Union[Sequence[int], int]] = None, + merge_channels: bool = False, + ) -> NdarrayOrTensor: """ Args: select_labels: labels to generate mask from. for 1 channel label, the `select_labels` @@ -706,26 +726,39 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) + where = np.where if isinstance(img, np.ndarray) else torch.where + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + data = where(in1d(img, select_labels), True, False).reshape(img.shape) + # pre pytorch 1.8.0, need to use 1/0 instead of True/False + else: + data = where(in1d(img, select_labels), 1, 0).reshape(img.shape) + + if merge_channels or self.merge_channels: + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + return data.any(0)[None] + # pre pytorch 1.8.0 compatibility + else: + return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore - return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + return data class FgBgToIndices(Transform): - def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: - """ - Compute foreground and background of the input label data, return the indices. - If no output_shape specified, output data will be 1 dim indices after flattening. - This transform can help pre-compute foreground and background regions for other transforms. - A typical usage is to randomly select foreground and background to crop. - The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`. + """ + Compute foreground and background of the input label data, return the indices. + If no output_shape specified, output data will be 1 dim indices after flattening. + This transform can help pre-compute foreground and background regions for other transforms. + A typical usage is to randomly select foreground and background to crop. + The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`. - Args: - image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to - determine the valid image content area and select background only in this area. - output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + Args: + image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to + determine the valid image content area and select background only in this area. 
+ output_shape: expected shape of output indices. if not None, unravel indices to specified shape. - """ + """ + + def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: self.image_threshold = image_threshold self.output_shape = output_shape diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1b63b308d9..e9bcce93b0 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform @@ -59,7 +59,7 @@ ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import InverseKeys, TransformBackends __all__ = [ "AddChannelD", @@ -650,6 +650,8 @@ class SqueezeDimd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`. """ + backend = SqueezeDim.backend + def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -661,7 +663,7 @@ def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -673,6 +675,8 @@ class DataStatsd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.DataStats`. """ + backend = DataStats.backend + def __init__( self, keys: KeysCollection, @@ -719,7 +723,7 @@ def __init__( self.logger_handler = logger_handler self.printer = DataStats(logger_handler=logger_handler) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info @@ -741,6 +745,8 @@ class SimulateDelayd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`. """ + backend = SimulateDelay.backend + def __init__( self, keys: KeysCollection, delay_time: Union[Sequence[float], float] = 0.0, allow_missing_keys: bool = False ) -> None: @@ -757,7 +763,7 @@ def __init__( self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, delay_time in self.key_iterator(d, self.delay_time): d[key] = self.delayer(d[key], delay_time=delay_time) @@ -768,9 +774,10 @@ class CopyItemsd(MapTransform): """ Copy specified items from data dictionary and save with different key names. It can copy several items together and copy several times. 
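    For instance (a hedged sketch, not part of the patch; the key names are invented), a
    tensor value is cloned via `detach().clone()` while other types fall back to
    `copy.deepcopy`:

        import torch
        from monai.transforms import CopyItemsd

        data = {"img": torch.zeros(1, 4, 4)}
        out = CopyItemsd(keys="img", times=1, names="img_copy")(data)
        # "img_copy" is an independent copy; modifying it leaves "img" unchanged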
- """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, keys: KeysCollection, times: int, names: KeysCollection, allow_missing_keys: bool = False ) -> None: @@ -802,7 +809,7 @@ def __init__( ) self.names = names - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Raises: KeyError: When a key in ``self.names`` already exists in ``data``. @@ -814,10 +821,11 @@ def __call__(self, data): for key, new_key in self.key_iterator(d, self.names[i * key_len : (i + 1) * key_len]): if new_key in d: raise KeyError(f"Key {new_key} already exists in data.") - if isinstance(d[key], torch.Tensor): - d[new_key] = d[key].detach().clone() + val = d[key] + if isinstance(val, torch.Tensor): + d[new_key] = val.detach().clone() else: - d[new_key] = copy.deepcopy(d[key]) + d[new_key] = copy.deepcopy(val) return d @@ -825,9 +833,10 @@ class ConcatItemsd(MapTransform): """ Concatenate specified items from data dictionary together on the first dim to construct a big array. Expect all the items are numpy array or PyTorch Tensor. - """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -841,7 +850,7 @@ def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_ self.name = name self.dim = dim - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Raises: TypeError: When items in ``data`` differ in type. @@ -857,10 +866,10 @@ def __call__(self, data): elif not isinstance(d[key], data_type): raise TypeError("All items in data must have the same type.") output.append(d[key]) - if data_type == np.ndarray: + if data_type is np.ndarray: d[self.name] = np.concatenate(output, axis=self.dim) - elif data_type == torch.Tensor: - d[self.name] = torch.cat(output, dim=self.dim) + elif data_type is torch.Tensor: + d[self.name] = torch.cat(output, dim=self.dim) # type: ignore else: raise TypeError(f"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor).") return d @@ -896,6 +905,8 @@ class Lambdad(MapTransform, InvertibleTransform): """ + backend = Lambda.backend + def __init__( self, keys: KeysCollection, @@ -913,7 +924,7 @@ def __init__( def _transform(self, data: Any, func: Callable): return self._lambd(data, func=func) - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): ret = self._transform(data=d[key], func=func) @@ -958,9 +969,10 @@ class RandLambdad(Lambdad, RandomizableTransform): Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the image's original size. If need these complicated information, please write a new InvertibleTransform directly. 
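    A short sketch of the dictionary version (illustrative only; the keys and the lambda are
    assumptions, not taken from the patch):

        import numpy as np
        from monai.transforms import RandLambdad

        data = {"img": np.ones((1, 2, 2)), "seg": np.zeros((1, 2, 2))}
        out = RandLambdad(keys="img", func=lambda x: x * 2, prob=1.0)(data)
        # out["img"] is doubled, out["seg"] is left untouched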
- """ + backend = Lambda.backend + def __init__( self, keys: KeysCollection, @@ -1007,6 +1019,8 @@ class LabelToMaskd(MapTransform): """ + backend = LabelToMask.backend + def __init__( # pytype: disable=annotation-type-mismatch self, keys: KeysCollection, @@ -1017,7 +1031,7 @@ def __init__( # pytype: disable=annotation-type-mismatch super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index e6dc151596..b2179b584e 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -16,10 +16,12 @@ __all__ = [ "moveaxis", + "in1d", ] def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: + """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8""" if isinstance(x, torch.Tensor): if hasattr(torch, "moveaxis"): return torch.moveaxis(x, src, dst) @@ -39,3 +41,11 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: elif isinstance(x, np.ndarray): return np.moveaxis(x, src, dst) raise RuntimeError() + + +def in1d(x, y): + """`np.in1d` with equivalent implementation for torch.""" + if isinstance(x, np.ndarray): + return np.in1d(x, y) + else: + return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 0ea5afc40c..aa8f02f815 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -46,6 +46,7 @@ first, get_seed, has_option, + is_module_ver_at_least, is_scalar, is_scalar_tensor, issequenceiterable, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 66f6557032..3b287b3fe4 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,7 +22,7 @@ import numpy as np import torch -from monai.utils.module import get_torch_version_tuple +from monai.utils.module import get_torch_version_tuple, version_leq __all__ = [ "zip_with", @@ -42,6 +42,7 @@ "MAX_SEED", "copy_to_device", "ImageMetaKey", + "is_module_ver_at_least", ] _seed = None @@ -355,3 +356,16 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: return False sig = inspect.signature(obj) return all(key in sig.parameters for key in ensure_tuple(keywords)) + + +def is_module_ver_at_least(module, version): + """Determine if a module's version is at least equal to the given value. + + Args: + module: imported module's name, e.g., `np` or `torch`. + version: required version, given as a tuple, e.g., `(1, 8, 0)`. + Returns: + `True` if module is the given version or newer. 
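    Example of the intended use (a sketch, mirroring the version checks used elsewhere in
    this patch):

        import torch
        from monai.utils import is_module_ver_at_least

        if is_module_ver_at_least(torch, (1, 8, 0)):
            ...  # e.g. safe to rely on torch.moveaxis or boolean torch.where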
+ """ + test_ver = ".".join(map(str, version)) + return module.__version__ != test_ver and version_leq(test_ver, module.__version__) diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 43068797a3..50536f2a5c 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -117,7 +117,7 @@ "additional_info": lambda x: torch.mean(x.float()), "logger_handler": None, }, - torch.tensor([[0, 1], [1, 2]]), + torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu"), ( "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index be7e54bc25..aea0f1e721 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -124,7 +124,7 @@ "additional_info": lambda x: torch.mean(x.float()), "logger_handler": None, }, - {"img": torch.tensor([[0, 1], [1, 2]])}, + {"img": torch.tensor([[0, 1], [1, 2]]).to("cuda" if torch.cuda.is_available() else "cpu")}, ( "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 2a84c7bea6..9caa7252f3 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -12,46 +12,59 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMask +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"select_labels": [2, 3], "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"select_labels": [1, 2], "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"select_labels": [1, 2], "merge_channels": True}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"select_labels": [2, 3], "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], 
[0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": True}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMask(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = LabelToMask(**argments)(image) - np.testing.assert_allclose(result, expected_data) + self.assertEqual(type(result), type(image)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, image.device) + assert_allclose(result, expected_data) if __name__ == "__main__": diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index f046390c19..b8f0d3c171 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -12,46 +12,60 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMaskd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, + 
{"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMaskd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = LabelToMaskd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TESTS) + def test_value(self, argments, input_data, expected_data): + result = LabelToMaskd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + assert_allclose(r, expected_data) if __name__ == "__main__": diff --git a/tests/test_lambda.py b/tests/test_lambda.py index e71eb3e5b0..738c81130d 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -11,30 +11,30 @@ import unittest -import numpy as np - from monai.transforms.utility.array import Lambda -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestLambda(NumpyImageTestCase2D): def test_lambda_identity(self): - img = self.imt + for p in TEST_NDARRAYS: + img = p(self.imt) - def identity_func(x): - return x + def identity_func(x): + return x - lambd = Lambda(func=identity_func) - self.assertTrue(np.allclose(identity_func(img), lambd(img))) + lambd = Lambda(func=identity_func) + assert_allclose(identity_func(img), lambd(img)) def test_lambda_slicing(self): - img = self.imt + for p in TEST_NDARRAYS: + img = p(self.imt) - def slice_func(x): - return x[:, :, :6, ::-2] + def slice_func(x): + return x[:, :, :6, ::2] - lambd = Lambda(func=slice_func) - self.assertTrue(np.allclose(slice_func(img), lambd(img))) + lambd = Lambda(func=slice_func) + assert_allclose(slice_func(img), lambd(img)) if __name__ == "__main__": diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index ca28af778b..05ba0ff6bc 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -11,37 +11,36 @@ import unittest -import numpy as np - from monai.transforms.utility.dictionary import Lambdad -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestLambdad(NumpyImageTestCase2D): def test_lambdad_identity(self): - img = self.imt - data = {"img": img, "prop": 1.0} + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {"img": img, "prop": 1.0} - def noise_func(x): - return x + 1.0 + def noise_func(x): + return x + 1.0 - expected = {"img": noise_func(data["img"]), "prop": 1.0} - ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) - self.assertTrue(np.allclose(expected["img"], ret["img"])) - self.assertTrue(np.allclose(expected["prop"], ret["prop"])) + expected = {"img": noise_func(data["img"]), "prop": 1.0} + ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + assert_allclose(expected["img"], ret["img"]) + assert_allclose(expected["prop"], ret["prop"]) def test_lambdad_slicing(self): - img = self.imt - data = {} - data["img"] = img + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {"img": img} - def slice_func(x): - return x[:, :, :6, ::-2] + def slice_func(x): + return x[:, :, :6, ::2] - lambd = Lambdad(keys=data.keys(), func=slice_func) - expected = {} - expected["img"] = slice_func(data["img"]) - self.assertTrue(np.allclose(expected["img"], 
lambd(data)["img"])) + lambd = Lambdad(keys=data.keys(), func=slice_func) + expected = {} + expected["img"] = slice_func(data["img"]) + assert_allclose(expected["img"], lambd(data)["img"]) if __name__ == "__main__": diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 01ea489320..15ff7e94d6 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -12,34 +12,32 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SqueezeDim +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"dim": None}, np.random.rand(1, 2, 1, 3), (2, 3)] +TESTS, TESTS_FAIL = [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"dim": None}, p(np.random.rand(1, 2, 1, 3)), (2, 3)]) + TESTS.append([{"dim": 2}, p(np.random.rand(1, 2, 1, 8, 16)), (1, 2, 8, 16)]) + TESTS.append([{"dim": -1}, p(np.random.rand(1, 1, 16, 8, 1)), (1, 1, 16, 8)]) + TESTS.append([{}, p(np.random.rand(1, 2, 1, 3)), (2, 1, 3)]) -TEST_CASE_2 = [{"dim": 2}, np.random.rand(1, 2, 1, 8, 16), (1, 2, 8, 16)] - -TEST_CASE_3 = [{"dim": -1}, np.random.rand(1, 1, 16, 8, 1), (1, 1, 16, 8)] - -TEST_CASE_4 = [{}, np.random.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_4_PT = [{}, torch.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_5 = [ValueError, {"dim": -2}, np.random.rand(1, 1, 16, 8, 1)] - -TEST_CASE_6 = [TypeError, {"dim": 0.5}, np.random.rand(1, 1, 16, 8, 1)] + TESTS_FAIL.append([ValueError, {"dim": -2}, p(np.random.rand(1, 1, 16, 8, 1))]) + TESTS_FAIL.append([TypeError, {"dim": 0.5}, p(np.random.rand(1, 1, 16, 8, 1))]) class TestSqueezeDim(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): + result = SqueezeDim(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): + with self.assertRaises(exception): SqueezeDim(**input_param)(test_data) diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py index dcbd9212c7..35e7cd5d74 100644 --- a/tests/test_squeezedimd.py +++ b/tests/test_squeezedimd.py @@ -12,62 +12,78 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import SqueezeDimd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "dim": None}, - {"img": np.random.rand(1, 2, 1, 3), "seg": np.random.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 3), -] +TESTS, TESTS_FAIL = [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": None}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 3), + ] + ) -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "dim": 2}, - {"img": np.random.rand(1, 2, 1, 8, 16), "seg": np.random.randint(0, 2, size=[1, 2, 1, 8, 16])}, - (1, 2, 8, 16), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": 2}, + {"img": p(np.random.rand(1, 2, 1, 8, 16)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 8, 16]))}, + (1, 2, 8, 16), + ] + ) -TEST_CASE_3 = [ - {"keys": ["img", "seg"], "dim": -1}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, - (1, 1, 16, 8), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": -1}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 
16, 8, 1]))}, + (1, 1, 16, 8), + ] + ) -TEST_CASE_4 = [ - {"keys": ["img", "seg"]}, - {"img": np.random.rand(1, 2, 1, 3), "seg": np.random.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 1, 3), -] + TESTS.append( + [ + {"keys": ["img", "seg"]}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 1, 3), + ] + ) -TEST_CASE_4_PT = [ - {"keys": ["img", "seg"], "dim": 0}, - {"img": torch.rand(1, 2, 1, 3), "seg": torch.randint(0, 2, size=[1, 2, 1, 3])}, - (2, 1, 3), -] + TESTS.append( + [ + {"keys": ["img", "seg"], "dim": 0}, + {"img": p(np.random.rand(1, 2, 1, 3)), "seg": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))}, + (2, 1, 3), + ] + ) -TEST_CASE_5 = [ - ValueError, - {"keys": ["img", "seg"], "dim": -2}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, -] + TESTS_FAIL.append( + [ + ValueError, + {"keys": ["img", "seg"], "dim": -2}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))}, + ] + ) -TEST_CASE_6 = [ - TypeError, - {"keys": ["img", "seg"], "dim": 0.5}, - {"img": np.random.rand(1, 1, 16, 8, 1), "seg": np.random.randint(0, 2, size=[1, 1, 16, 8, 1])}, -] + TESTS_FAIL.append( + [ + TypeError, + {"keys": ["img", "seg"], "dim": 0.5}, + {"img": p(np.random.rand(1, 1, 16, 8, 1)), "seg": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))}, + ] + ) class TestSqueezeDim(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): result = SqueezeDimd(**input_param)(test_data) self.assertTupleEqual(result["img"].shape, expected_shape) self.assertTupleEqual(result["seg"].shape, expected_shape) - @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): with self.assertRaises(exception): SqueezeDimd(**input_param)(test_data) diff --git a/tests/utils.py b/tests/utils.py index 22720849f1..1375cd2d72 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,6 +36,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq nib, _ = optional_import("nibabel") @@ -142,8 +143,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.min_version)) - self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver) + self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( From 2cd4b99586175dfd2390b7dbb6ea5b055b2de780 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 14:55:51 +0000 Subject: [PATCH 14/17] cumpy->cupy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_to_cupy.py | 2 +- tests/test_to_cupyd.py | 2 +- tests/test_to_numpy.py | 2 +- tests/test_to_numpyd.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index a9460bc825..8b00e12539 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -24,7 +24,7 @@ class 
TestToCupy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py index 2f3c42dd1f..6f40bafe1c 100644 --- a/tests/test_to_cupyd.py +++ b/tests/test_to_cupyd.py @@ -24,7 +24,7 @@ class TestToCupyd(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index fd49a3d473..b48727c01d 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -24,7 +24,7 @@ class TestToNumpy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py index adfab65904..5acaef39c7 100644 --- a/tests/test_to_numpyd.py +++ b/tests/test_to_numpyd.py @@ -24,7 +24,7 @@ class TestToNumpyd(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) From 85c2be5209c14877edc871615e411dfe42efc121 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 17:05:57 +0100 Subject: [PATCH 15/17] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f1fa72f597..f38a94302e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,15 +32,7 @@ map_classes_to_indices, ) from monai.transforms.utils_pytorch_numpy_unification import moveaxis -from monai.utils import ( - convert_to_numpy, - convert_to_tensor, - ensure_tuple, - issequenceiterable, - look_up_option, - min_version, - optional_import, -) +from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type From c548b1d2e4a4d3db25ce3c14f0ea214293922027 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Aug 2021 16:43:22 +0100 Subject: [PATCH 16/17] fixes unit tests Signed-off-by: Wenqi Li --- monai/transforms/utility/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index a09d7ff641..68794282b1 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -731,7 +731,9 @@ def __call__( data = where(in1d(img, select_labels), True, False).reshape(img.shape) # pre pytorch 1.8.0, need to use 1/0 instead of True/False else: - data = where(in1d(img, select_labels), 1, 0).reshape(img.shape) + data = where( + in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) + ).reshape(img.shape) if merge_channels or self.merge_channels: if isinstance(img, 
np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): From c6e7235d4351f80ab2ba7135d28f50e1df56faaf Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Aug 2021 18:17:17 +0100 Subject: [PATCH 17/17] style fixes Signed-off-by: Wenqi Li --- monai/transforms/utility/array.py | 3 +- .../utils_pytorch_numpy_unification.py | 33 ++++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 68794282b1..918763405f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -739,8 +739,7 @@ def __call__( if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): return data.any(0)[None] # pre pytorch 1.8.0 compatibility - else: - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore return data diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index b2179b584e..2eebe3eda3 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -25,27 +25,28 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: if isinstance(x, torch.Tensor): if hasattr(torch, "moveaxis"): return torch.moveaxis(x, src, dst) - # moveaxis only available in pytorch as of 1.8.0 - else: - # get original indices - indices = list(range(x.ndim)) - # make src and dst positive - if src < 0: - src = len(indices) + src - if dst < 0: - dst = len(indices) + dst - # remove desired index and insert it in new position - indices.pop(src) - indices.insert(dst, src) - return x.permute(indices) - elif isinstance(x, np.ndarray): + return _moveaxis_with_permute(x, src, dst) # type: ignore + if isinstance(x, np.ndarray): return np.moveaxis(x, src, dst) raise RuntimeError() +def _moveaxis_with_permute(x, src, dst): + # get original indices + indices = list(range(x.ndim)) + # make src and dst positive + if src < 0: + src = len(indices) + src + if dst < 0: + dst = len(indices) + dst + # remove desired index and insert it in new position + indices.pop(src) + indices.insert(dst, src) + return x.permute(indices) + + def in1d(x, y): """`np.in1d` with equivalent implementation for torch.""" if isinstance(x, np.ndarray): return np.in1d(x, y) - else: - return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)
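
Reviewer note (not part of any patch above): a quick usage sketch of the unified helpers introduced in monai/transforms/utils_pytorch_numpy_unification.py, showing that the same call works for both backends. The example arrays, printed shapes, and values are illustrative assumptions only; the import path and function names come from the final hunk.

    import numpy as np
    import torch

    from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis

    # moveaxis accepts either backend; on torch < 1.8.0 it falls back to
    # the permute-based helper instead of torch.moveaxis.
    print(moveaxis(np.zeros((1, 2, 3)), 0, -1).shape)    # (2, 3, 1)
    print(moveaxis(torch.zeros(1, 2, 3), 0, -1).shape)   # torch.Size([2, 3, 1])

    # in1d mirrors np.in1d: a flat boolean mask of membership in `y`,
    # kept on the input tensor's device in the torch branch.
    print(in1d(np.array([0, 1, 2, 3]), [1, 3]))      # [False  True False  True]
    print(in1d(torch.tensor([0, 1, 2, 3]), [1, 3]))  # tensor([False,  True, False,  True])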