diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 58eb0da2a7..ac2d1e46fd 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -200,6 +200,7 @@ class NormalizeIntensity(Transform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -208,11 +209,13 @@ def __init__( divisor: Optional[Sequence] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: self.subtrahend = subtrahend self.divisor = divisor self.nonzero = nonzero self.channel_wise = channel_wise + self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) @@ -252,7 +255,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img + return img.astype(self.dtype) class ThresholdIntensity(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index f0030849d9..64f641ecd1 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -227,6 +227,7 @@ class NormalizeIntensityd(MapTransform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -236,9 +237,10 @@ def __init__( divisor: Optional[np.ndarray] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: super().__init__(keys) - self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise) + self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index a5021c5f26..06768f77b7 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -59,6 +59,7 @@ class TestNormalizeIntensity(NumpyImageTestCase2D): def test_default(self): normalizer = NormalizeIntensity() normalized = normalizer(self.imt) + self.assertTrue(normalized.dtype == np.float32) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) np.testing.assert_allclose(normalized, expected, rtol=1e-6)