From 6f2b6c648bc592f4374111b228aa1fb2a66baea8 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Thu, 27 Jul 2023 20:04:37 +0530 Subject: [PATCH 1/2] feat: add channel_wise to ShiftIntensity Signed-off-by: Saurav Maheshkar --- monai/transforms/intensity/array.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f8eadcfb1b..f0afb06bc4 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -227,13 +227,15 @@ class ShiftIntensity(Transform): offset: offset value to shift the intensity of image. safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`. E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`. + channel_wise: if `True`, shift the intensity for each channel of image with random offset """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, offset: float, safe: bool = False) -> None: + def __init__(self, offset: float, safe: bool = False, channel_wise: bool = False) -> None: self.offset = offset self.safe = safe + self.channel_wise = channel_wise def __call__(self, img: NdarrayOrTensor, offset: float | None = None) -> NdarrayOrTensor: """ @@ -242,10 +244,16 @@ def __call__(self, img: NdarrayOrTensor, offset: float | None = None) -> Ndarray img = convert_to_tensor(img, track_meta=get_track_meta()) offset = self.offset if offset is None else offset - out = img + offset - out, *_ = convert_data_type(data=out, dtype=img.dtype, safe=self.safe) - - return out + if self.channel_wise: + for i, d in enumerate(img): + out = d + offset + out, *_ = convert_data_type(data=out, dtype=d.dtype, safe=self.safe) + img[i] = out + return img + else: + out = img + offset + out, *_ = convert_data_type(data=out, dtype=img.dtype, safe=self.safe) + return out class RandShiftIntensity(RandomizableTransform): @@ -255,13 +263,16 @@ class RandShiftIntensity(RandomizableTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1) -> None: + def __init__( + self, offsets: tuple[float, float] | float, safe: bool = False, channel_wise: bool = False, prob: float = 0.1 + ) -> None: """ Args: offsets: offset range to randomly shift. if single number, offset value is picked from (-offsets, offsets). safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`. E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`. + channel_wise: if `True`, shift the intensity for each channel of image with random offset prob: probability of shift. """ RandomizableTransform.__init__(self, prob) @@ -272,7 +283,7 @@ def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, pro else: self.offsets = (min(offsets), max(offsets)) self._offset = self.offsets[0] - self._shifter = ShiftIntensity(self._offset, safe) + self._shifter = ShiftIntensity(self._offset, safe, channel_wise) def randomize(self, data: Any | None = None) -> None: super().randomize(None) From ce343e867934a0e84cab980c66ca913fdb6d4a43 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Thu, 27 Jul 2023 20:06:42 +0530 Subject: [PATCH 2/2] style: add ignore flag Signed-off-by: Saurav Maheshkar --- monai/transforms/intensity/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f0afb06bc4..880ea562a1 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -253,7 +253,7 @@ def __call__(self, img: NdarrayOrTensor, offset: float | None = None) -> Ndarray else: out = img + offset out, *_ = convert_data_type(data=out, dtype=img.dtype, safe=self.safe) - return out + return out # type: ignore class RandShiftIntensity(RandomizableTransform):