From 733bf13297bc0c10ac8c5e6a74f0348f2f2436d7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 24 Aug 2021 15:21:23 +0100 Subject: [PATCH 1/3] all close Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_fill_holes.py | 3 +-- tests/test_flip.py | 7 ++----- tests/test_flipd.py | 7 ++----- tests/test_keep_largest_connected_component.py | 4 ++-- tests/test_label_filter.py | 3 +-- tests/test_rand_axis_flip.py | 8 ++------ tests/test_rand_axis_flipd.py | 7 ++----- tests/test_rand_flip.py | 7 ++----- tests/test_rand_flipd.py | 7 ++----- tests/utils.py | 17 +++++++---------- 10 files changed, 23 insertions(+), 47 deletions(-) diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 294bbd8c87..dfc6fb5154 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -278,10 +278,9 @@ def test_correct_results(self, _, args, input_image, expected): converter = FillHoles(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) else: result = converter(clone(input_image)) - assert allclose(result, expected) + self.assertTrue(allclose(result, expected)) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_flip.py b/tests/test_flip.py index bd0162fb8b..f9dbb5368e 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import Flip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - if isinstance(result, torch.Tensor): - result = result.cpu() - self.assertTrue(np.allclose(expected, result)) + self.assertTrue(allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index cec4a99cbf..5266e36955 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import Flipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -39,9 +38,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() - assert np.allclose(expected, result) + self.assertTrue(allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 670dd2d2ee..91a2d79718 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -327,10 +327,10 @@ def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) + else: result = converter(clone(input_image)) - assert allclose(result, expected) + self.assertTrue(allclose(result, expected.cuda())) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index 9165fddc40..87da61acc9 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -108,10 +108,9 @@ def test_correct_results(self, _, args, input_image, expected): converter = LabelFilter(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) else: result = converter(clone(input_image)) - assert allclose(result, expected) + self.assertTrue(allclose(result, expected)) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index bd53fa1fb0..9acd7be74c 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -12,10 +12,9 @@ import unittest import numpy as np -import torch from monai.transforms import RandAxisFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose class TestRandAxisFlip(NumpyImageTestCase2D): @@ -23,13 +22,10 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlip(prob=1.0) result = flip(p(self.imt[0])) - if isinstance(result, torch.Tensor): - result = result.cpu() - expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result)) + self.assertTrue(allclose(np.stack(expected), result)) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 518d78dd29..bba7eb632c 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -12,10 +12,9 @@ import unittest import numpy as np -import torch from monai.transforms import RandAxisFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, allclose class TestRandAxisFlip(NumpyImageTestCase3D): @@ -23,13 +22,11 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlipd(keys="img", prob=1.0) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result)) + self.assertTrue(allclose(np.stack(expected), result)) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index c20c13fec5..1d5cc34316 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - if isinstance(result, torch.Tensor): - result = result.cpu() - self.assertTrue(np.allclose(expected, result)) + self.assertTrue(allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 42c7dfe4b5..743c5a49e8 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] @@ -27,13 +26,11 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) - self.assertTrue(np.allclose(expected, result)) + self.assertTrue(allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index 1148af7551..1b2b7dea9b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,7 @@ from monai.config import NdarrayTensor from monai.config.deviceconfig import USE_COMPILED +from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism from monai.utils.module import version_leq @@ -55,7 +56,7 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def allclose(a: NdarrayTensor, b: NdarrayTensor) -> bool: +def allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, rtol=1.0e-5, atol=1.0e-8, equal_nan=False) -> bool: """ Check if all values of two data objects are close. @@ -63,19 +64,15 @@ def allclose(a: NdarrayTensor, b: NdarrayTensor) -> bool: This method also checks that both data objects are either Pytorch Tensors or numpy arrays. Args: - a (NdarrayTensor): Pytorch Tensor or numpy array for comparison - b (NdarrayTensor): Pytorch Tensor or numpy array to compare against + a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison + b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against Returns: bool: If both data objects are close. """ - if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): - return np.allclose(a, b) - - if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): - return torch.allclose(a, b) - - return False + a = a.cpu() if isinstance(a, torch.Tensor) else a + b = b.cpu() if isinstance(b, torch.Tensor) else b + return np.allclose(a, b, rtol, atol, equal_nan) def test_pretrained_networks(network, input_param, device): From 00e9f22da43cc7e667d929b3fa87968b774887e0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 24 Aug 2021 15:48:55 +0100 Subject: [PATCH 2/3] assert_allclose Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_fill_holes.py | 4 ++-- tests/test_flip.py | 4 ++-- tests/test_flipd.py | 4 ++-- tests/test_keep_largest_connected_component.py | 4 ++-- tests/test_label_filter.py | 4 ++-- tests/test_rand_axis_flip.py | 4 ++-- tests/test_rand_axis_flipd.py | 4 ++-- tests/test_rand_flip.py | 4 ++-- tests/test_rand_flipd.py | 4 ++-- tests/utils.py | 12 +++--------- 10 files changed, 21 insertions(+), 27 deletions(-) diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index dfc6fb5154..6ea83c239b 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import FillHoles -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1_raw = [ [1, 1, 1], @@ -280,7 +280,7 @@ def test_correct_results(self, _, args, input_image, expected): result = converter(clone(input_image).cuda()) else: result = converter(clone(input_image)) - self.assertTrue(allclose(result, expected)) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_flip.py b/tests/test_flip.py index f9dbb5368e..404a3def7d 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import Flip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -39,7 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - self.assertTrue(allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 5266e36955..1676723800 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import Flipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -38,7 +38,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - self.assertTrue(allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 91a2d79718..527d986614 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) @@ -330,7 +330,7 @@ def test_correct_results(self, _, args, input_image, expected): else: result = converter(clone(input_image)) - self.assertTrue(allclose(result, expected.cuda())) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index 87da61acc9..c699fb31fd 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import LabelFilter -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1 = torch.tensor( [ @@ -110,7 +110,7 @@ def test_correct_results(self, _, args, input_image, expected): result = converter(clone(input_image).cuda()) else: result = converter(clone(input_image)) - self.assertTrue(allclose(result, expected)) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 9acd7be74c..c05c3a1e0d 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -14,7 +14,7 @@ import numpy as np from monai.transforms import RandAxisFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandAxisFlip(NumpyImageTestCase2D): @@ -25,7 +25,7 @@ def test_correct_results(self): expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(allclose(np.stack(expected), result)) + assert_allclose(np.stack(expected), result) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index bba7eb632c..7bef0baa63 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -14,7 +14,7 @@ import numpy as np from monai.transforms import RandAxisFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, assert_allclose class TestRandAxisFlip(NumpyImageTestCase3D): @@ -26,7 +26,7 @@ def test_correct_results(self): expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(allclose(np.stack(expected), result)) + assert_allclose(np.stack(expected), result) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index 1d5cc34316..b3c514cb1f 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -39,7 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - self.assertTrue(allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 743c5a49e8..8972024fd8 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import RandFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] @@ -30,7 +30,7 @@ def test_correct_results(self, _, spatial_axis): for channel in self.imt[0]: expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) - self.assertTrue(allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index 1b2b7dea9b..22720849f1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,23 +56,17 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, rtol=1.0e-5, atol=1.0e-8, equal_nan=False) -> bool: +def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): """ - Check if all values of two data objects are close. - - Note: - This method also checks that both data objects are either Pytorch Tensors or numpy arrays. + Assert that all values of two data objects are close. Args: a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against - - Returns: - bool: If both data objects are close. """ a = a.cpu() if isinstance(a, torch.Tensor) else a b = b.cpu() if isinstance(b, torch.Tensor) else b - return np.allclose(a, b, rtol, atol, equal_nan) + np.testing.assert_allclose(a, b, *args, **kwargs) def test_pretrained_networks(network, input_param, device): From 02230b571354aa6e409f9465cdc4bea4cf578626 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 24 Aug 2021 16:39:47 +0100 Subject: [PATCH 3/3] NormalizeIntensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 52 +++++++-- monai/transforms/intensity/dictionary.py | 8 +- tests/test_normalize_intensity.py | 139 +++++++++++++++-------- tests/test_normalize_intensityd.py | 85 +++++++++----- 4 files changed, 189 insertions(+), 95 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 4b7c8d6997..8b2bf32145 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -539,10 +539,12 @@ class NormalizeIntensity(Transform): dtype: output data type, defaults to float32. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, - subtrahend: Union[Sequence, np.ndarray, None] = None, - divisor: Union[Sequence, np.ndarray, None] = None, + subtrahend: Union[Sequence, NdarrayOrTensor, None] = None, + divisor: Union[Sequence, NdarrayOrTensor, None] = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -553,26 +555,51 @@ def __init__( self.channel_wise = channel_wise self.dtype = dtype - def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) - if not np.any(slices): + @staticmethod + def _mean(x): + if isinstance(x, np.ndarray): + return np.mean(x) + x = torch.mean(x.float()) + return x.item() if x.numel() == 1 else x + + @staticmethod + def _std(x): + if isinstance(x, np.ndarray): + return np.std(x) + x = torch.std(x.float(), unbiased=False) + return x.item() if x.numel() == 1 else x + + def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor: + img, *_ = convert_data_type(img, dtype=torch.float32) + + if self.nonzero: + slices = img != 0 + else: + if isinstance(img, np.ndarray): + slices = np.ones_like(img, dtype=bool) + else: + slices = torch.ones_like(img, dtype=torch.bool) + if not slices.any(): return img - _sub = sub if sub is not None else np.mean(img[slices]) - if isinstance(_sub, np.ndarray): + _sub = sub if sub is not None else self._mean(img[slices]) + if isinstance(_sub, (torch.Tensor, np.ndarray)): + _sub, *_ = convert_to_dst_type(_sub, img) _sub = _sub[slices] - _div = div if div is not None else np.std(img[slices]) + _div = div if div is not None else self._std(img[slices]) if np.isscalar(_div): if _div == 0.0: _div = 1.0 - elif isinstance(_div, np.ndarray): + elif isinstance(_div, (torch.Tensor, np.ndarray)): + _div, *_ = convert_to_dst_type(_div, img) _div = _div[slices] _div[_div == 0.0] = 1.0 + img[slices] = (img[slices] - _sub) / _div return img - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ @@ -583,7 +610,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.") for i, d in enumerate(img): - img[i] = self._normalize( + img[i] = self._normalize( # type: ignore d, sub=self.subtrahend[i] if self.subtrahend is not None else None, div=self.divisor[i] if self.divisor is not None else None, @@ -591,7 +618,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img.astype(self.dtype) + out, *_ = convert_data_type(img, dtype=self.dtype) + return out class ThresholdIntensity(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 522007df29..bce45b57d3 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -612,11 +612,13 @@ class NormalizeIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = NormalizeIntensity.backend + def __init__( self, keys: KeysCollection, - subtrahend: Optional[np.ndarray] = None, - divisor: Optional[np.ndarray] = None, + subtrahend: Optional[NdarrayOrTensor] = None, + divisor: Optional[NdarrayOrTensor] = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -625,7 +627,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 9d474faea7..2755eb4c25 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -12,70 +12,111 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose -TEST_CASES = [ - [{"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])], - [ - {"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True}, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), - ], - [{"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])], - [ - {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])]) + for q in TEST_NDARRAYS: + for u in TEST_NDARRAYS: + TESTS.append( + [ + p, + { + "subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])), + "divisor": u(np.array([0.5, 0.5, 0.5, 0.5])), + "nonzero": True, + }, + np.array([0.0, 3.0, 0.0, 4.0]), + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append( + [ + p, + {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * -1.0, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) class TestNormalizeIntensity(NumpyImageTestCase2D): - def test_default(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_default(self, im_type): + im = im_type(self.imt.copy()) normalizer = NormalizeIntensity() - normalized = normalizer(self.imt.copy()) - self.assertTrue(normalized.dtype == np.float32) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - np.testing.assert_allclose(normalized, expected, rtol=1e-3) + assert_allclose(expected, normalized, rtol=1e-3) - @parameterized.expand(TEST_CASES) - def test_nonzero(self, input_param, input_data, expected_data): + @parameterized.expand(TESTS) + def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)) + im = in_type(input_data) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + assert_allclose(expected_data, normalized) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): normalizer = NormalizeIntensity(nonzero=True, channel_wise=True) - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)) + normalized = normalizer(input_data) + self.assertEqual(type(input_data), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data.device, normalized.device) + assert_allclose(expected, normalized) - def test_value_errors(self): - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_value_errors(self, im_type): + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1]) with self.assertRaises(ValueError): normalizer(input_data) diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index 482d1a3f5b..e2cec5407a 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -12,54 +12,77 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose -TEST_CASE_1 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 0.0, 1.0]), -] - -TEST_CASE_2 = [ - { - "keys": ["img"], - "subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), - "divisor": np.array([0.5, 0.5, 0.5, 0.5]), - "nonzero": True, - }, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 0.0, 1.0]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 0.0, 0.0, 0.0])}, - np.array([0.0, 0.0, 0.0, 0.0]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])), + "divisor": q(np.array([0.5, 0.5, 0.5, 0.5])), + "nonzero": True, + }, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, + np.array([0.0, 0.0, 0.0, 0.0]), + ] + ) class TestNormalizeIntensityd(NumpyImageTestCase2D): - def test_image_normalize_intensityd(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_image_normalize_intensityd(self, im_type): key = "img" + im = im_type(self.imt) normalizer = NormalizeIntensityd(keys=[key]) - normalized = normalizer({key: self.imt}) + normalized = normalizer({key: im})[key] expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - np.testing.assert_allclose(normalized[key], expected, rtol=1e-3) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + assert_allclose(normalized, expected, rtol=1e-3) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): + key = "img" normalizer = NormalizeIntensityd(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)["img"]) + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) + assert_allclose(normalized, expected_data) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): key = "img" normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True) - input_data = {key: np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])} + input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))} + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)[key]) + assert_allclose(normalized, expected) if __name__ == "__main__":