From beab17dd107496d002da48afe32623623f1269aa Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:04:12 +0100 Subject: [PATCH 1/9] backends -> backend Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/intensity/array.py | 10 +-- monai/transforms/utility/array.py | 2 +- monai/transforms/utils.py | 95 +++++++++++++++++--------- tests/test_print_transform_backends.py | 8 ++- 5 files changed, 74 insertions(+), 42 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 4203724a7d..5267af4048 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -500,6 +500,7 @@ get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, + get_transform_backends, img_bounds, in_bounds, is_empty, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 46c512c96c..b36c7adf96 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -131,7 +131,7 @@ class RandRicianNoise(RandomizableTransform): uniformly from 0 to std. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, @@ -197,7 +197,7 @@ class ShiftIntensity(Transform): offset: offset value to shift the intensity of image. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offset: float) -> None: self.offset = offset @@ -219,7 +219,7 @@ class RandShiftIntensity(RandomizableTransform): Randomly shift intensity with randomly picked offset. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None: """ @@ -273,7 +273,7 @@ class StdShiftIntensity(Transform): dtype: output data type, defaults to float32. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32 @@ -318,7 +318,7 @@ class RandStdShiftIntensity(RandomizableTransform): by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3ef6413090..d56bca0d8d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -291,7 +291,7 @@ class CastToType(Transform): specified PyTorch data type. """ - backends = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, dtype=np.float32) -> None: """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e3e61b6c97..6281319973 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,6 +10,7 @@ # limitations under the License. import itertools +from monai.utils.enums import TransformBackends import random import warnings from contextlib import contextmanager @@ -81,6 +82,7 @@ "zero_margins", "equalize_hist", "get_number_image_type_conversions", + "get_transform_backends", "print_transform_backends", ] @@ -1157,23 +1159,17 @@ def _get_data(obj, key): num_conversions += 1 return num_conversions +def get_transform_backends() -> dict[str, List]: + """Get the backends of all MONAI transforms. -def print_transform_backends(): - """Prints a list of backends of all MONAI transforms.""" - - class Colours: - red = "91" - green = "92" - yellow = "93" - - def print_colour(t, colour): - print(f"\033[{colour}m{t}\033[00m") - - tr_total = 0 - tr_t_or_np = 0 - tr_t = 0 - tr_np = 0 - tr_uncategorised = 0 + Returns: + Dictionary, where each key is a transform, and its + corresponding values are a boolean list, stating + whether that transform supports (1) `torch.Tensor`, + and (2) `np.ndarray` as input without needing to + convert. + """ + backends = {} unique_transforms = [] for n, obj in getmembers(monai.transforms): # skip aliases @@ -1194,21 +1190,52 @@ def print_colour(t, colour): "InverteD", ]: continue - tr_total += 1 - if obj.backend == ["torch", "numpy"]: - tr_t_or_np += 1 - print_colour(f"TorchOrNumpy: {n}", Colours.green) - elif obj.backend == ["torch"]: - tr_t += 1 - print_colour(f"Torch: {n}", Colours.green) - elif obj.backend == ["numpy"]: - tr_np += 1 - print_colour(f"Numpy: {n}", Colours.yellow) - else: - tr_uncategorised += 1 - print_colour(f"Uncategorised: {n}", Colours.red) - print("Total number of transforms:", tr_total) - print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green) - print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) - print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) - print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) + + backends[n] = [ + TransformBackends.TORCH in obj.backend, + TransformBackends.NUMPY in obj.backend, + ] + return backends + + +def print_transform_backends(): + """Prints a list of backends of all MONAI transforms.""" + class Colors: + none = "" + red = "91" + green = "92" + yellow = "93" + + def print_color(t, color): + print(f"\033[{color}m{t}\033[00m") + + def print_table_column(name, torch, numpy, color=Colors.none): + print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color) + + backends = get_transform_backends() + n_total = len(backends) + n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0 + print_table_column("Transform", "Torch?", "Numpy?") + for k, v in backends.items(): + if all(v): + color = Colors.green + n_t_or_np += 1 + elif v[0]: + color = Colors.green + n_t += 1 + elif v[1]: + color = Colors.yellow + n_np += 1 + else: + color = Colors.red + n_uncategorized += 1 + print_table_column(k, *v, color) + + print("Total number of transforms:", n_total) + print_color(f"Number transforms allowing both torch and numpy: {n_t_or_np}", Colors.green) + print_color(f"Number of TorchTransform: {n_t}", Colors.green) + print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow) + print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) + +if __name__ == "__main__": + print_transform_backends() diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py index 09828f0a27..4164687f01 100644 --- a/tests/test_print_transform_backends.py +++ b/tests/test_print_transform_backends.py @@ -11,13 +11,17 @@ import unittest -from monai.transforms.utils import print_transform_backends +from monai.transforms.utils import get_transform_backends, print_transform_backends class TestPrintTransformBackends(unittest.TestCase): def test_get_number_of_conversions(self): + tr_t_or_np, *_ = get_transform_backends() + self.assertGreater(len(tr_t_or_np), 0) print_transform_backends() if __name__ == "__main__": - unittest.main() + # unittest.main() + a = TestPrintTransformBackends() + a.test_get_number_of_conversions() From c3192bad9b6585676a06966d57930790f89c4988 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:10:42 +0100 Subject: [PATCH 2/9] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6281319973..8bb4a38310 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,7 +10,6 @@ # limitations under the License. import itertools -from monai.utils.enums import TransformBackends import random import warnings from contextlib import contextmanager @@ -39,6 +38,7 @@ min_version, optional_import, ) +from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -1159,6 +1159,7 @@ def _get_data(obj, key): num_conversions += 1 return num_conversions + def get_transform_backends() -> dict[str, List]: """Get the backends of all MONAI transforms. @@ -1200,6 +1201,7 @@ def get_transform_backends() -> dict[str, List]: def print_transform_backends(): """Prints a list of backends of all MONAI transforms.""" + class Colors: none = "" red = "91" @@ -1237,5 +1239,6 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow) print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) + if __name__ == "__main__": print_transform_backends() From f790ad7acc151e33aabe88dd8ba43ab87dd057c0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:30:33 +0100 Subject: [PATCH 3/9] code format2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8bb4a38310..30aa5e7b99 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1160,7 +1160,7 @@ def _get_data(obj, key): return num_conversions -def get_transform_backends() -> dict[str, List]: +def get_transform_backends(): """Get the backends of all MONAI transforms. Returns: From 1d2870481031e0dd7df79caf0daf7c3191f3cc13 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:55:22 +0100 Subject: [PATCH 4/9] AddChannel, AsChannelFirst, AsChannelLast, EnsureChannelFirst, Identity, RepeatChannel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 46 +++++++++++++++++++------- monai/transforms/utility/dictionary.py | 26 ++++++++++----- tests/test_add_channeld.py | 17 ++++++---- tests/test_as_channel_first.py | 17 +++++----- tests/test_as_channel_firstd.py | 21 ++++++------ tests/test_as_channel_last.py | 17 +++++----- tests/test_as_channel_lastd.py | 21 ++++++------ tests/test_ensure_channel_first.py | 8 +++-- tests/test_ensure_channel_firstd.py | 9 +++-- tests/test_identity.py | 11 +++--- tests/test_identityd.py | 13 ++++---- tests/test_repeat_channel.py | 8 +++-- tests/test_repeat_channeld.py | 17 ++++++---- 13 files changed, 142 insertions(+), 89 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d56bca0d8d..fc233b50cc 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -82,17 +82,18 @@ class Identity(Transform): """ - Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data. + Do nothing to the data. As the output value is same as input, it can be used as a testing tool to verify the transform chain, Compose or transform adaptor, etc. - """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asanyarray(img) + return img class AsChannelFirst(Transform): @@ -111,16 +112,23 @@ class AsChannelFirst(Transform): channel_dim: which dimension of input image is the channel, default is the last dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = -1) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, 0) + # old versions of pytorch don't have moveaxis + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, 0) + + img, *_ = convert_data_type(img, np.ndarray) + return np.moveaxis(img, self.channel_dim, 0) # type: ignore class AsChannelLast(Transform): @@ -138,16 +146,23 @@ class AsChannelLast(Transform): channel_dim: which dimension of input image is the channel, default is the first dimension. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, channel_dim: int = 0) -> None: if not (isinstance(channel_dim, int) and channel_dim >= -1): raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, -1) + # old versions of pytorch don't have moveaxis + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, -1) + + img, *_ = convert_data_type(img, np.ndarray) + return np.moveaxis(img, self.channel_dim, -1) # type: ignore class AddChannel(Transform): @@ -164,7 +179,9 @@ class AddChannel(Transform): transforms. """ - def __call__(self, img: NdarrayTensor): + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -179,6 +196,8 @@ class EnsureChannelFirst(Transform): Convert the data to `channel_first` based on the `original_channel_dim` information. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, strict_check: bool = True): """ Args: @@ -186,7 +205,7 @@ def __init__(self, strict_check: bool = True): """ self.strict_check = strict_check - def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None): + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ Apply the transform to `img`. """ @@ -220,16 +239,19 @@ class RepeatChannel(Transform): repeats: the number of repetitions for each element. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, repeats: int) -> None: if repeats <= 0: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - return np.repeat(img, self.repeats, 0) + repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeeat_fn(img, self.repeats, 0) # type: ignore class RemoveRepeatedChannel(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a53e4f3235..41c2a1b9b9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -169,6 +169,8 @@ class Identityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Identity`. """ + backend = Identity.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -180,9 +182,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -194,6 +194,8 @@ class AsChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`. """ + backend = AsChannelFirst.backend + def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: """ Args: @@ -205,7 +207,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -217,6 +219,8 @@ class AsChannelLastd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`. """ + backend = AsChannelLast.backend + def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: """ Args: @@ -228,7 +232,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -240,6 +244,8 @@ class AddChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`. """ + backend = AddChannel.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Args: @@ -250,7 +256,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -262,6 +268,8 @@ class EnsureChannelFirstd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. """ + backend = EnsureChannelFirst.backend + def __init__( self, keys: KeysCollection, @@ -289,7 +297,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"]) @@ -301,6 +309,8 @@ class RepeatChanneld(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. """ + backend = RepeatChannel.backend + def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None: """ Args: @@ -312,7 +322,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py index ca4af37271..8bdd89a4ae 100644 --- a/tests/test_add_channeld.py +++ b/tests/test_add_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import AddChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"]}, - {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[0, 1], [1, 2]])}, - (1, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"]}, + {"img": p(np.array([[0, 1], [1, 2]])), "seg": p(np.array([[0, 1], [1, 2]]))}, + (1, 2, 2), + ] + ) class TestAddChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = AddChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index e7d9866ae1..bc9158f277 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelFirst +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirst(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index e70c2e1b47..68d33434c1 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelFirstd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 6ec6c8d6e6..55a7a08676 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelLast +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLast(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelLast(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 2ef4dd4da1..350f639f3f 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelLastd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLastd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelLastd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index 6b9def1cea..23126d326f 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -21,6 +21,7 @@ from monai.data import ITKReader from monai.transforms import EnsureChannelFirst, LoadImage +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] @@ -61,9 +62,10 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result, header = LoadImage(**input_param)(filenames) - result = EnsureChannelFirst()(result, header) - self.assertEqual(result.shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(p(result), header) + self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 59eb32c576..b4cde02a8f 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -19,6 +19,7 @@ from PIL import Image from monai.transforms import EnsureChannelFirstd, LoadImaged +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] @@ -43,9 +44,11 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadImaged(**input_param)({"img": filenames}) - result = EnsureChannelFirstd(**input_param)(result) - self.assertEqual(result["img"].shape[0], len(filenames)) + for p in TEST_NDARRAYS: + result = LoadImaged(**input_param)({"img": filenames}) + result["img"] = p(result["img"]) + result = EnsureChannelFirstd(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) def test_load_png(self): spatial_size = (256, 256, 3) diff --git a/tests/test_identity.py b/tests/test_identity.py index 2dff2bb13d..172860668c 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -11,17 +11,16 @@ import unittest -import numpy as np - from monai.transforms.utility.array import Identity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentity(NumpyImageTestCase2D): def test_identity(self): - img = self.imt - identity = Identity() - self.assertTrue(np.allclose(img, identity(img))) + for p in TEST_NDARRAYS: + img = p(self.imt) + identity = Identity() + assert_allclose(img, identity(img)) if __name__ == "__main__": diff --git a/tests/test_identityd.py b/tests/test_identityd.py index 8796f28da8..665b7d5d1c 100644 --- a/tests/test_identityd.py +++ b/tests/test_identityd.py @@ -12,16 +12,17 @@ import unittest from monai.transforms.utility.dictionary import Identityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestIdentityd(NumpyImageTestCase2D): def test_identityd(self): - img = self.imt - data = {} - data["img"] = img - identity = Identityd(keys=data.keys()) - self.assertEqual(data, identity(data)) + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {} + data["img"] = img + identity = Identityd(keys=data.keys()) + assert_allclose(img, identity(data)["img"]) if __name__ == "__main__": diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 643ebc64de..e246dd1212 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -11,16 +11,18 @@ import unittest -import numpy as np from parameterized import parameterized from monai.transforms import RepeatChannel +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"repeats": 3}, np.array([[[0, 1], [1, 2]]]), (3, 2, 2)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"repeats": 3}, p([[[0, 1], [1, 2]]]), (3, 2, 2)]) class TestRepeatChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py index 7bd58bd1fe..3b73962bb9 100644 --- a/tests/test_repeat_channeld.py +++ b/tests/test_repeat_channeld.py @@ -15,16 +15,21 @@ from parameterized import parameterized from monai.transforms import RepeatChanneld +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img"], "repeats": 3}, - {"img": np.array([[[0, 1], [1, 2]]]), "seg": np.array([[[0, 1], [1, 2]]])}, - (3, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "repeats": 3}, + {"img": p(np.array([[[0, 1], [1, 2]]])), "seg": p(np.array([[[0, 1], [1, 2]]]))}, + (3, 2, 2), + ] + ) class TestRepeatChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChanneld(**input_param)(input_data) self.assertEqual(result["img"].shape, expected_shape) From be7eac206da7329ddffcdc585ed8b68499d63858 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 15:59:31 +0100 Subject: [PATCH 5/9] moveaxis backwards compatible Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 ++ monai/transforms/utility/array.py | 18 +++------ .../utils_pytorch_numpy_unification.py | 37 +++++++++++++++++++ tests/test_as_channel_first.py | 10 ++++- 4 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 monai/transforms/utils_pytorch_numpy_unification.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5267af4048..b648b3331b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -517,3 +517,7 @@ weighted_patch_samples, zero_margins, ) + +from .utils_pytorch_numpy_unification import ( + moveaxis, +) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fc233b50cc..58d1bec34d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,6 +31,9 @@ map_binary_to_indices, map_classes_to_indices, ) +from monai.transforms.utils_pytorch_numpy_unification import ( + moveaxis, +) from monai.utils import ( convert_to_numpy, convert_to_tensor, @@ -123,12 +126,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - # old versions of pytorch don't have moveaxis - if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): - return torch.moveaxis(img, self.channel_dim, 0) - - img, *_ = convert_data_type(img, np.ndarray) - return np.moveaxis(img, self.channel_dim, 0) # type: ignore + return moveaxis(img, self.channel_dim, 0) class AsChannelLast(Transform): @@ -157,13 +155,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - # old versions of pytorch don't have moveaxis - if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): - return torch.moveaxis(img, self.channel_dim, -1) - - img, *_ = convert_data_type(img, np.ndarray) - return np.moveaxis(img, self.channel_dim, -1) # type: ignore - + return moveaxis(img, self.channel_dim, -1) class AddChannel(Transform): """ diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py new file mode 100644 index 0000000000..6eab76678c --- /dev/null +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +from monai.config.type_definitions import NdarrayOrTensor + + +__all__ = [ + "moveaxis", +] + + +def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: + if isinstance(x, torch.Tensor): + if hasattr(torch, "moveaxis"): + return torch.moveaxis(x, src, dst) + # moveaxis only available in pytorch as of 1.8.0 + else: + # get original indices, remove desired index and insert it in new position + indices = list(range(x.ndim)) + indices.pop(src) + indices.insert(dst, src) + return x.permute(indices) + elif isinstance(x, np.ndarray): + return np.moveaxis(x, src, dst) + raise RuntimeError() + diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index bc9158f277..ffa66c5ec5 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -13,9 +13,10 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import AsChannelFirst -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] for p in TEST_NDARRAYS: @@ -26,10 +27,15 @@ class TestAsChannelFirst(unittest.TestCase): @parameterized.expand(TESTS) - def test_shape(self, in_type, input_param, expected_shape): + def test_value(self, in_type, input_param, expected_shape): test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) + if isinstance(test_data, torch.Tensor): + test_data = test_data.cpu().numpy() + expected = np.moveaxis(test_data, input_param["channel_dim"], 0) + assert_allclose(expected, result) + if __name__ == "__main__": From 0d27527cbae154cb9c37dc9f06889e38c2f75a24 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 25 Aug 2021 16:39:24 +0100 Subject: [PATCH 6/9] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 5 +---- monai/transforms/utility/array.py | 5 ++--- monai/transforms/utils_pytorch_numpy_unification.py | 4 +--- tests/test_as_channel_first.py | 3 +-- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b648b3331b..28379aadd8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -517,7 +517,4 @@ weighted_patch_samples, zero_margins, ) - -from .utils_pytorch_numpy_unification import ( - moveaxis, -) +from .utils_pytorch_numpy_unification import moveaxis diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 58d1bec34d..580c6c8b3c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,9 +31,7 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import ( - moveaxis, -) +from monai.transforms.utils_pytorch_numpy_unification import moveaxis from monai.utils import ( convert_to_numpy, convert_to_tensor, @@ -157,6 +155,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ return moveaxis(img, self.channel_dim, -1) + class AddChannel(Transform): """ Adds a 1-length channel dimension to the input image. diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 6eab76678c..dd9e54fa45 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -9,12 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import numpy as np +import torch from monai.config.type_definitions import NdarrayOrTensor - __all__ = [ "moveaxis", ] @@ -34,4 +33,3 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: elif isinstance(x, np.ndarray): return np.moveaxis(x, src, dst) raise RuntimeError() - diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index ffa66c5ec5..0d1b1c7d3a 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -12,8 +12,8 @@ import unittest import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import AsChannelFirst from tests.utils import TEST_NDARRAYS, assert_allclose @@ -37,6 +37,5 @@ def test_value(self, in_type, input_param, expected_shape): assert_allclose(expected, result) - if __name__ == "__main__": unittest.main() From 754a6843ff6a99ffe9e411e68caf285fb5f034eb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 26 Aug 2021 12:25:46 +0100 Subject: [PATCH 7/9] trigger ci/cd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> From e38e7a3fc048b7c27d7b93c45c3b40b150062e9e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:42:55 +0100 Subject: [PATCH 8/9] permute requires positive indices Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index dd9e54fa45..bb16369cf0 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -25,8 +25,14 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: return torch.moveaxis(x, src, dst) # moveaxis only available in pytorch as of 1.8.0 else: - # get original indices, remove desired index and insert it in new position + # get original indices indices = list(range(x.ndim)) + # make src and dst positive + if src < 0: + src = len(indices) - src + if dst < 0: + dst = len(indices) - dst + # remove desired index and insert it in new position indices.pop(src) indices.insert(dst, src) return x.permute(indices) From b1e476d267f4a8e22f7fa4cff0fc2b1fbc93662e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 27 Aug 2021 10:41:09 +0100 Subject: [PATCH 9/9] correct permute Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index bb16369cf0..e6dc151596 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -29,9 +29,9 @@ def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor: indices = list(range(x.ndim)) # make src and dst positive if src < 0: - src = len(indices) - src + src = len(indices) + src if dst < 0: - dst = len(indices) - dst + dst = len(indices) + dst # remove desired index and insert it in new position indices.pop(src) indices.insert(dst, src)