Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from collections.abc import Iterable
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -35,7 +35,7 @@
ShiftIntensity,
ThresholdIntensity,
)
from monai.utils import dtype_torch_to_numpy, ensure_tuple_size
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size

__all__ = [
"RandGaussianNoised",
Expand Down Expand Up @@ -110,27 +110,29 @@ def __init__(
) -> None:
super().__init__(keys)
self.prob = prob
self.mean = ensure_tuple_size(mean, len(self.keys))
self.mean = ensure_tuple_rep(mean, len(self.keys))
self.std = std
self._do_transform = False
self._noise: Optional[np.ndarray] = None
self._noise: List[np.ndarray] = []

def randomize(self, im_shape: Sequence[int]) -> None:
self._do_transform = self.R.random() < self.prob
self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape)
self._noise.clear()
for m in self.mean:
self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape))

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)

image_shape = d[self.keys[0]].shape # image shape from the first data key
self.randomize(image_shape)
if self._noise is None:
if len(self._noise) != len(self.keys):
raise AssertionError
if not self._do_transform:
return d
for key in self.keys:
for noise, key in zip(self._noise, self.keys):
dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype
d[key] = d[key] + self._noise.astype(dtype)
d[key] = d[key] + noise.astype(dtype)
return d


Expand Down
32 changes: 15 additions & 17 deletions tests/test_rand_gaussian_noised.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,36 @@
from monai.transforms import RandGaussianNoised
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D

TEST_CASE_0 = ["test_zero_mean", ["img"], 0, 0.1]
TEST_CASE_1 = ["test_non_zero_mean", ["img"], 1, 0.5]
TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1]
TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5]
TEST_CASES = [TEST_CASE_0, TEST_CASE_1]

seed = 0

# Test with numpy

def test_numpy_or_torch(keys, mean, std, imt):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({k: imt for k in keys})
np.random.seed(seed)
np.random.random()
for k in keys:
expected = imt + np.random.normal(mean, np.random.uniform(0, std), size=imt.shape)
np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)


# Test with numpy
class TestRandGaussianNoisedNumpy(NumpyImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({"img": self.imt})
np.random.seed(seed)
np.random.random()
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5)
test_numpy_or_torch(keys, mean, std, self.imt)


# Test with torch
class TestRandGaussianNoisedTorch(TorchImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
gaussian_fn.set_random_state(seed)
noised = gaussian_fn({"img": self.imt})
np.random.seed(seed)
np.random.random()
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5)
test_numpy_or_torch(keys, mean, std, self.imt)


if __name__ == "__main__":
Expand Down