diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 48f0657ab0..1c9b31c120 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -16,7 +16,7 @@ """ from collections.abc import Iterable -from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -35,7 +35,7 @@ ShiftIntensity, ThresholdIntensity, ) -from monai.utils import dtype_torch_to_numpy, ensure_tuple_size +from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size __all__ = [ "RandGaussianNoised", @@ -110,27 +110,29 @@ def __init__( ) -> None: super().__init__(keys) self.prob = prob - self.mean = ensure_tuple_size(mean, len(self.keys)) + self.mean = ensure_tuple_rep(mean, len(self.keys)) self.std = std self._do_transform = False - self._noise: Optional[np.ndarray] = None + self._noise: List[np.ndarray] = [] def randomize(self, im_shape: Sequence[int]) -> None: self._do_transform = self.R.random() < self.prob - self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) + self._noise.clear() + for m in self.mean: + self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) image_shape = d[self.keys[0]].shape # image shape from the first data key self.randomize(image_shape) - if self._noise is None: + if len(self._noise) != len(self.keys): raise AssertionError if not self._do_transform: return d - for key in self.keys: + for noise, key in zip(self._noise, self.keys): dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype - d[key] = d[key] + self._noise.astype(dtype) + d[key] = d[key] + noise.astype(dtype) return d diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 63da23c8dc..442a85ca77 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -17,38 +17,36 @@ from monai.transforms import RandGaussianNoised from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D -TEST_CASE_0 = ["test_zero_mean", ["img"], 0, 0.1] -TEST_CASE_1 = ["test_non_zero_mean", ["img"], 1, 0.5] +TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1] +TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5] TEST_CASES = [TEST_CASE_0, TEST_CASE_1] seed = 0 -# Test with numpy + +def test_numpy_or_torch(keys, mean, std, imt): + gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) + gaussian_fn.set_random_state(seed) + noised = gaussian_fn({k: imt for k in keys}) + np.random.seed(seed) + np.random.random() + for k in keys: + expected = imt + np.random.normal(mean, np.random.uniform(0, std), size=imt.shape) + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) +# Test with numpy class TestRandGaussianNoisedNumpy(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES) def test_correct_results(self, _, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) - gaussian_fn.set_random_state(seed) - noised = gaussian_fn({"img": self.imt}) - np.random.seed(seed) - np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) - np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5) + test_numpy_or_torch(keys, mean, std, self.imt) # Test with torch class TestRandGaussianNoisedTorch(TorchImageTestCase2D): @parameterized.expand(TEST_CASES) def test_correct_results(self, _, keys, mean, std): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) - gaussian_fn.set_random_state(seed) - noised = gaussian_fn({"img": self.imt}) - np.random.seed(seed) - np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) - np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5) + test_numpy_or_torch(keys, mean, std, self.imt) if __name__ == "__main__":