From 6cbda05d3cf039a184b30d184f5f488465d0ea5c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Dec 2020 18:51:59 +0800 Subject: [PATCH 1/2] [DLMED] add astype in NormalizeIntensity transform Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 3 ++- tests/test_normalize_intensity.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 58eb0da2a7..8b33e5748c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -237,6 +237,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ + img_dtype = img.dtype if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != len(img): raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -252,7 +253,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img + return img.astype(img_dtype) class ThresholdIntensity(Transform): diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index a5021c5f26..06768f77b7 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -59,6 +59,7 @@ class TestNormalizeIntensity(NumpyImageTestCase2D): def test_default(self): normalizer = NormalizeIntensity() normalized = normalizer(self.imt) + self.assertTrue(normalized.dtype == np.float32) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) np.testing.assert_allclose(normalized, expected, rtol=1e-6) From bb0c77a516a726674d0781056a95311c663ab9ca Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 11 Dec 2020 20:46:56 +0800 Subject: [PATCH 2/2] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 6 ++++-- monai/transforms/intensity/dictionary.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 8b33e5748c..ac2d1e46fd 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -200,6 +200,7 @@ class NormalizeIntensity(Transform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -208,11 +209,13 @@ def __init__( divisor: Optional[Sequence] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: self.subtrahend = subtrahend self.divisor = divisor self.nonzero = nonzero self.channel_wise = channel_wise + self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) @@ -237,7 +240,6 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ - img_dtype = img.dtype if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != len(img): raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -253,7 +255,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img.astype(img_dtype) + return img.astype(self.dtype) class ThresholdIntensity(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index f0030849d9..64f641ecd1 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -227,6 +227,7 @@ class NormalizeIntensityd(MapTransform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -236,9 +237,10 @@ def __init__( divisor: Optional[np.ndarray] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: super().__init__(keys) - self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise) + self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data)