From 62d5d79176f4553e22cbcb1fdd881e865217b25e Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 22 Sep 2023 15:06:08 +0800
Subject: [PATCH 1/2] fix #6629

Signed-off-by: KumoLiu
---
 monai/transforms/intensity/array.py      | 25 +++++++++++++++++++----
 monai/transforms/intensity/dictionary.py | 14 +++++++++++--
 tests/test_rand_shift_intensity.py       | 14 ++++++++++++++
 tests/test_rand_shift_intensityd.py      | 16 ++++++++++++++++
 4 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index ed59bbc8f3..e776c07f77 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -255,7 +255,9 @@ class RandShiftIntensity(RandomizableTransform):
 
     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
 
-    def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1) -> None:
+    def __init__(
+        self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1, channel_wise: bool = False
+    ) -> None:
         """
         Args:
             offsets: offset range to randomly shift.
@@ -263,6 +265,8 @@ def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, pro
             safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
                 E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
             prob: probability of shift.
+            channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen.
+                Please ensure that the first dimension represents the channel of the image if True.
         """
         RandomizableTransform.__init__(self, prob)
         if isinstance(offsets, (int, float)):
@@ -272,13 +276,17 @@ def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, pro
         else:
             self.offsets = (min(offsets), max(offsets))
         self._offset = self.offsets[0]
+        self.channel_wise = channel_wise
         self._shifter = ShiftIntensity(self._offset, safe)
 
     def randomize(self, data: Any | None = None) -> None:
         super().randomize(None)
         if not self._do_transform:
             return None
-        self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])
+        if self.channel_wise:
+            self._offset = [self.R.uniform(low=self.offsets[0], high=self.offsets[1]) for _ in range(data.shape[0])]  # type: ignore
+        else:
+            self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])
 
     def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize: bool = True) -> NdarrayOrTensor:
         """
@@ -292,12 +300,21 @@ def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize:
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
         if randomize:
-            self.randomize()
+            self.randomize(img)
 
         if not self._do_transform:
             return img
 
-        return self._shifter(img, self._offset if factor is None else self._offset * factor)
+        ret: NdarrayOrTensor
+        if self.channel_wise:
+            out = []
+            for i, d in enumerate(img):
+                out_channel = self._shifter(d, self._offset[i] if factor is None else self._offset[i] * factor)
+                out.append(out_channel)
+            ret = torch.stack(out)  # type: ignore
+        else:
+            ret = self._shifter(img, self._offset if factor is None else self._offset * factor)
+        return ret
 
 
 class StdShiftIntensity(Transform):
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index 32052ad406..058ef87b95 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -373,6 +373,7 @@ def __init__(
         meta_keys: KeysCollection | None = None,
         meta_key_postfix: str = DEFAULT_POST_FIX,
         prob: float = 0.1,
+        channel_wise: bool = False,
         allow_missing_keys: bool = False,
     ) -> None:
         """
@@ -399,6 +400,8 @@ def __init__(
                 used to extract the factor value is `factor_key` is not None.
             prob: probability of shift. (Default 0.1, with 10% probability it returns
                 an array shifted intensity.)
+            channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen.
+                Please ensure that the first dimension represents the channel of the image if True.
             allow_missing_keys: don't raise exception if key is missing.
         """
         MapTransform.__init__(self, keys, allow_missing_keys)
@@ -409,7 +412,7 @@ def __init__(
         if len(self.keys) != len(self.meta_keys):
             raise ValueError("meta_keys should have the same length as keys.")
         self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
-        self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0)
+        self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0, channel_wise=channel_wise)
 
     def set_random_state(
         self, seed: int | None = None, state: np.random.RandomState | None = None
@@ -426,8 +429,15 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
             return d
 
+        # expect all the specified keys to have the same spatial shape and share the same random shift
+        first_key: Hashable = self.first_key(d)
+        if first_key == ():
+            for key in self.key_iterator(d):
+                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
+            return d
+
         # all the keys share the same random shift factor
-        self.shifter.randomize(None)
+        self.shifter.randomize(d[first_key])
         for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(
             d, self.factor_key, self.meta_keys, self.meta_key_postfix
         ):
diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py
index 12b7ccf526..01ac55f7b8 100644
--- a/tests/test_rand_shift_intensity.py
+++ b/tests/test_rand_shift_intensity.py
@@ -33,6 +33,20 @@ def test_value(self, p):
         expected = self.imt + np.random.uniform(low=-1.0, high=1.0)
         assert_allclose(result, expected, type_test="tensor")
 
+    @parameterized.expand([[p] for p in TEST_NDARRAYS])
+    def test_channel_wise(self, p):
+        scaler = RandShiftIntensity(offsets=3.0, channel_wise=True, prob=1.0)
+        scaler.set_random_state(seed=0)
+        im = p(self.imt)
+        result = scaler(im)
+        np.random.seed(0)
+        # simulate the randomize() of transform
+        np.random.random()
+        channel_num = self.imt.shape[0]
+        factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]
+        expected = p(np.stack([np.asarray((self.imt[i]) + factor[i]) for i in range(channel_num)]).astype(np.float32))
+        assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py
index 92bc39dd20..7522676eb0 100644
--- a/tests/test_rand_shift_intensityd.py
+++ b/tests/test_rand_shift_intensityd.py
@@ -46,6 +46,22 @@ def test_factor(self):
         expected = self.imt + np.random.uniform(low=-1.0, high=1.0) * np.nanmax(self.imt)
         np.testing.assert_allclose(result[key], expected)
 
+    def test_channel_wise(self):
+        key = "img"
+        for p in TEST_NDARRAYS:
+            scaler = RandShiftIntensityd(keys=[key], offsets=3.0, prob=1.0, channel_wise=True)
+            scaler.set_random_state(seed=0)
+            result = scaler({key: p(self.imt)})
+            np.random.seed(0)
+            # simulate the randomize function of transform
+            np.random.random()
+            channel_num = self.imt.shape[0]
+            factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]
+            expected = p(
+                np.stack([np.asarray((self.imt[i]) + factor[i]) for i in range(channel_num)]).astype(np.float32)
+            )
+            assert_allclose(result[key], p(expected), type_test="tensor")
+
 
 if __name__ == "__main__":
     unittest.main()

From 0e82765fe2bb9657505e35c0bb2aba6e33e8ff6f Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 22 Sep 2023 15:09:19 +0800
Subject: [PATCH 2/2] fix mypy

Signed-off-by: KumoLiu
---
 monai/transforms/intensity/array.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index e776c07f77..f9667402c9 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -309,7 +309,7 @@ def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize:
         if self.channel_wise:
             out = []
             for i, d in enumerate(img):
-                out_channel = self._shifter(d, self._offset[i] if factor is None else self._offset[i] * factor)
+                out_channel = self._shifter(d, self._offset[i] if factor is None else self._offset[i] * factor)  # type: ignore
                 out.append(out_channel)
             ret = torch.stack(out)  # type: ignore
         else:
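
Usage sketch (illustrative only, assuming a MONAI build with the two patches above applied; the array shape, offsets, and seed below are arbitrary example values):

    import numpy as np
    from monai.transforms import RandShiftIntensity, RandShiftIntensityd

    # channel-first image: 3 channels, 4x4 spatial (illustrative data only)
    img = np.arange(48, dtype=np.float32).reshape(3, 4, 4)

    # array transform: with channel_wise=True each channel gets its own offset drawn from (-3.0, 3.0)
    shift = RandShiftIntensity(offsets=3.0, prob=1.0, channel_wise=True)
    shift.set_random_state(seed=0)
    out = shift(img)

    # dictionary transform: all listed keys share the same per-channel offsets
    shiftd = RandShiftIntensityd(keys="img", offsets=3.0, prob=1.0, channel_wise=True)
    shiftd.set_random_state(seed=0)
    out_d = shiftd({"img": img})["img"]

As in the tests above, the first dimension of the input is treated as the channel dimension when channel_wise=True.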