diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py
index 20a7829cab..6eec9fd277 100644
--- a/monai/data/synthetic.py
+++ b/monai/data/synthetic.py
@@ -76,7 +76,7 @@ def create_test_image_2d(
     labels = np.ceil(image).astype(np.int32)
 
     norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
-    noisyimage = rescale_array(np.maximum(image, norm))
+    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore
 
     if channel_dim is not None:
         if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)):
@@ -151,7 +151,7 @@ def create_test_image_3d(
     labels = np.ceil(image).astype(np.int32)
 
     norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
-    noisyimage = rescale_array(np.maximum(image, norm))
+    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore
 
     if channel_dim is not None:
         if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)):
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 8b2bf32145..46c512c96c 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -373,6 +373,8 @@ class ScaleIntensity(Transform):
     If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
     """
 
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(
         self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None
     ) -> None:
@@ -387,7 +389,7 @@ def __init__(
         self.maxv = maxv
         self.factor = factor
 
-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.
 
@@ -396,9 +398,11 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
 
         """
         if self.minv is not None and self.maxv is not None:
-            return np.asarray(rescale_array(img, self.minv, self.maxv, img.dtype))
+            return rescale_array(img, self.minv, self.maxv, img.dtype)
         if self.factor is not None:
-            return np.asarray(img * (1 + self.factor), dtype=img.dtype)
+            out = img * (1 + self.factor)
+            out, *_ = convert_data_type(out, dtype=img.dtype)
+            return out
         raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.")
 
 
@@ -408,6 +412,8 @@ class RandScaleIntensity(RandomizableTransform):
     is randomly picked.
     """
 
+    backend = ScaleIntensity.backend
+
     def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
         """
         Args:
@@ -429,7 +435,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
         self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
         super().randomize(None)
 
-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.
         """
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index bce45b57d3..227b6fb434 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -472,6 +472,8 @@ class ScaleIntensityd(MapTransform):
     If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
""" + backend = ScaleIntensity.backend + def __init__( self, keys: KeysCollection, @@ -494,7 +496,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -506,6 +508,8 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform): Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ + backend = ScaleIntensity.backend + def __init__( self, keys: KeysCollection, @@ -539,7 +543,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize() if not self._do_transform: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e81cb7ca17..e3e61b6c97 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,6 +22,7 @@ import monai import monai.transforms.transform from monai.config import DtypeLike, IndexSelection +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform @@ -37,6 +38,7 @@ min_version, optional_import, ) +from monai.utils.type_conversion import convert_data_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) ndimage, _ = optional_import("scipy.ndimage") @@ -130,15 +132,17 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) -def rescale_array(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32): +def rescale_array( + arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32 +) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. 
""" if dtype is not None: - arr = arr.astype(dtype) + arr, *_ = convert_data_type(arr, dtype=dtype) - mina = np.min(arr) - maxa = np.max(arr) + mina = arr.min() + maxa = arr.max() if mina == maxa: return arr * minv diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 4a17607320..ccdbdc2396 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -188,7 +188,7 @@ def plot_2d_or_3d_image( d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index if d.ndim == 2: - d = rescale_array(d, 0, 1) + d = rescale_array(d, 0, 1) # type: ignore dataformats = "HW" writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats) return diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 2126301758..750d88bfad 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -14,17 +14,18 @@ import numpy as np from monai.transforms import RandScaleIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandScaleIntensity(NumpyImageTestCase2D): def test_value(self): - scaler = RandScaleIntensity(factors=0.5, prob=1.0) - scaler.set_random_state(seed=0) - result = scaler(self.imt) - np.random.seed(0) - expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - np.testing.assert_allclose(result, expected) + for p in TEST_NDARRAYS: + scaler = RandScaleIntensity(factors=0.5, prob=1.0) + scaler.set_random_state(seed=0) + result = scaler(p(self.imt)) + np.random.seed(0) + expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) + assert_allclose(result, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index 6e207e3cc2..a8d2e63f65 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -14,18 +14,19 @@ import numpy as np from monai.transforms import RandScaleIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandScaleIntensityd(NumpyImageTestCase2D): def test_value(self): - key = "img" - scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0) - scaler.set_random_state(seed=0) - result = scaler({key: self.imt}) - np.random.seed(0) - expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - np.testing.assert_allclose(result[key], expected) + for p in TEST_NDARRAYS: + key = "img" + scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0) + scaler.set_random_state(seed=0) + result = scaler({key: p(self.imt)}) + np.random.seed(0) + expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) + assert_allclose(result[key], expected) if __name__ == "__main__": diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 61e89191fd..c2485af616 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -14,24 +14,26 @@ import numpy as np from monai.transforms import ScaleIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestScaleIntensity(NumpyImageTestCase2D): def test_range_scale(self): - scaler = ScaleIntensity(minv=1.0, maxv=2.0) - result = scaler(self.imt) 
-        mina = np.min(self.imt)
-        maxa = np.max(self.imt)
-        norm = (self.imt - mina) / (maxa - mina)
-        expected = (norm * (2.0 - 1.0)) + 1.0
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=1.0, maxv=2.0)
+            result = scaler(p(self.imt))
+            mina = self.imt.min()
+            maxa = self.imt.max()
+            norm = (self.imt - mina) / (maxa - mina)
+            expected = p((norm * (2.0 - 1.0)) + 1.0)
+            assert_allclose(result, expected, rtol=1e-7, atol=0)
 
     def test_factor_scale(self):
-        scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
-        result = scaler(self.imt)
-        expected = (self.imt * (1 + 0.1)).astype(np.float32)
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
+            result = scaler(p(self.imt))
+            expected = p((self.imt * (1 + 0.1)).astype(np.float32))
+            assert_allclose(result, expected, rtol=1e-7, atol=0)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py
index 688c99c6af..6e13dbc272 100644
--- a/tests/test_scale_intensityd.py
+++ b/tests/test_scale_intensityd.py
@@ -14,26 +14,28 @@
 import numpy as np
 
 from monai.transforms import ScaleIntensityd
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
 
 
 class TestScaleIntensityd(NumpyImageTestCase2D):
     def test_range_scale(self):
-        key = "img"
-        scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
-        result = scaler({key: self.imt})
-        mina = np.min(self.imt)
-        maxa = np.max(self.imt)
-        norm = (self.imt - mina) / (maxa - mina)
-        expected = (norm * (2.0 - 1.0)) + 1.0
-        np.testing.assert_allclose(result[key], expected)
+        for p in TEST_NDARRAYS:
+            key = "img"
+            scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
+            result = scaler({key: p(self.imt)})
+            mina = np.min(self.imt)
+            maxa = np.max(self.imt)
+            norm = (self.imt - mina) / (maxa - mina)
+            expected = (norm * (2.0 - 1.0)) + 1.0
+            assert_allclose(result[key], expected)
 
     def test_factor_scale(self):
-        key = "img"
-        scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
-        result = scaler({key: self.imt})
-        expected = (self.imt * (1 + 0.1)).astype(np.float32)
-        np.testing.assert_allclose(result[key], expected)
+        for p in TEST_NDARRAYS:
+            key = "img"
+            scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
+            result = scaler({key: p(self.imt)})
+            expected = (self.imt * (1 + 0.1)).astype(np.float32)
+            assert_allclose(result[key], expected)
 
 
 if __name__ == "__main__":
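A minimal usage sketch (not part of the patch above), assuming a channel-first float image: with these changes `ScaleIntensity` accepts either a NumPy array or a torch.Tensor and returns the matching container type, since `rescale_array` and `convert_data_type` now handle both backends.

    import numpy as np
    import torch

    from monai.transforms import ScaleIntensity

    # rescale intensities into [0, 1] regardless of the input container type
    scaler = ScaleIntensity(minv=0.0, maxv=1.0)

    img_np = np.random.rand(1, 32, 32).astype(np.float32)
    img_pt = torch.rand(1, 32, 32)

    out_np = scaler(img_np)  # numpy.ndarray in -> numpy.ndarray out
    out_pt = scaler(img_pt)  # torch.Tensor in -> torch.Tensor out

    print(type(out_np), float(out_np.min()), float(out_np.max()))
    print(type(out_pt), float(out_pt.min()), float(out_pt.max()))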