4 changes: 2 additions & 2 deletions monai/data/synthetic.py
@@ -76,7 +76,7 @@ def create_test_image_2d(
     labels = np.ceil(image).astype(np.int32)

     norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
-    noisyimage = rescale_array(np.maximum(image, norm))
+    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore

     if channel_dim is not None:
         if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)):
@@ -151,7 +151,7 @@ def create_test_image_3d(
     labels = np.ceil(image).astype(np.int32)

     norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
-    noisyimage = rescale_array(np.maximum(image, norm))
+    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore

     if channel_dim is not None:
         if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)):
14 changes: 10 additions & 4 deletions monai/transforms/intensity/array.py
@@ -373,6 +373,8 @@ class ScaleIntensity(Transform):
     If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
     """

+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(
         self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None
     ) -> None:
@@ -387,7 +389,7 @@ def __init__(
         self.maxv = maxv
         self.factor = factor

-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.

@@ -396,9 +398,11 @@ def __call__(self, img: np.ndarray) -> np.ndarray:

         """
         if self.minv is not None and self.maxv is not None:
-            return np.asarray(rescale_array(img, self.minv, self.maxv, img.dtype))
+            return rescale_array(img, self.minv, self.maxv, img.dtype)
         if self.factor is not None:
-            return np.asarray(img * (1 + self.factor), dtype=img.dtype)
+            out = img * (1 + self.factor)
+            out, *_ = convert_data_type(out, dtype=img.dtype)
+            return out
         raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.")


@@ -408,6 +412,8 @@ class RandScaleIntensity(RandomizableTransform):
     is randomly picked.
     """

+    backend = ScaleIntensity.backend
+
     def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
         """
         Args:
@@ -429,7 +435,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
         self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
         super().randomize(None)

-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.
         """
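A minimal usage sketch of what the `monai/transforms/intensity/array.py` changes enable: `ScaleIntensity` and `RandScaleIntensity` accepting either a NumPy array or a PyTorch tensor and returning the same container type. This is illustration only, not part of the diff; the example image and its shape are made up.

# Illustration only -- example data is made up; the transform APIs are those shown in the diff above.
import numpy as np
import torch

from monai.transforms import RandScaleIntensity, ScaleIntensity

img_np = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # channel-first toy image
img_pt = torch.as_tensor(img_np)                          # same data as a torch tensor

# minv/maxv rescale intensities into [0.0, 1.0]; the output type follows the input type.
scaler = ScaleIntensity(minv=0.0, maxv=1.0)
out_np = scaler(img_np)  # numpy.ndarray in, numpy.ndarray out
out_pt = scaler(img_pt)  # torch.Tensor in, torch.Tensor out

# With minv=maxv=None, `factor` scales by v = v * (1 + factor).
factor_scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)

# RandScaleIntensity draws the factor uniformly from (-factors, factors) on each call.
rand_scaler = RandScaleIntensity(factors=0.5, prob=1.0)
rand_scaler.set_random_state(seed=0)
out_rand = rand_scaler(img_pt)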
8 changes: 6 additions & 2 deletions monai/transforms/intensity/dictionary.py
@@ -472,6 +472,8 @@ class ScaleIntensityd(MapTransform):
     If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
     """

+    backend = ScaleIntensity.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -494,7 +496,7 @@ def __init__(
         super().__init__(keys, allow_missing_keys)
         self.scaler = ScaleIntensity(minv, maxv, factor)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
             d[key] = self.scaler(d[key])
@@ -506,6 +508,8 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform):
     Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`.
     """

+    backend = ScaleIntensity.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -539,7 +543,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
         self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
         super().randomize(None)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         self.randomize()
         if not self._do_transform:
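The dictionary-based counterparts from `monai/transforms/intensity/dictionary.py`, again as an illustrative sketch rather than part of the change: only the listed keys are transformed, the result is a new dict, and each value may be a NumPy array or a torch tensor.

# Illustration only -- keys, shapes and values are hypothetical.
import numpy as np
import torch

from monai.transforms import RandScaleIntensityd, ScaleIntensityd

data = {
    "img": torch.rand(1, 32, 32),                          # torch tensor value
    "seg": np.random.rand(1, 32, 32).astype(np.float32),   # numpy array value
}

# Only the listed keys are rescaled; each value keeps its original container type.
scale = ScaleIntensityd(keys=["img", "seg"], minv=0.0, maxv=1.0)
out = scale(data)

rand_scale = RandScaleIntensityd(keys=["img"], factors=0.5, prob=1.0)
rand_scale.set_random_state(seed=0)
out = rand_scale(out)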
12 changes: 8 additions & 4 deletions monai/transforms/utils.py
@@ -22,6 +22,7 @@
 import monai
 import monai.transforms.transform
 from monai.config import DtypeLike, IndexSelection
+from monai.config.type_definitions import NdarrayOrTensor
 from monai.networks.layers import GaussianFilter
 from monai.transforms.compose import Compose, OneOf
 from monai.transforms.transform import MapTransform, Transform
@@ -37,6 +38,7 @@
     min_version,
     optional_import,
 )
+from monai.utils.type_conversion import convert_data_type

 measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
 ndimage, _ = optional_import("scipy.ndimage")
@@ -130,15 +132,17 @@ def zero_margins(img: np.ndarray, margin: int) -> bool:
     return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :])


-def rescale_array(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32):
+def rescale_array(
+    arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32
+) -> NdarrayOrTensor:
     """
     Rescale the values of numpy array `arr` to be from `minv` to `maxv`.
     """
     if dtype is not None:
-        arr = arr.astype(dtype)
+        arr, *_ = convert_data_type(arr, dtype=dtype)

-    mina = np.min(arr)
-    maxa = np.max(arr)
+    mina = arr.min()
+    maxa = arr.max()

     if mina == maxa:
         return arr * minv
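For reference, the rescaling that `rescale_array` implements is a min-max normalisation followed by an affine map onto `[minv, maxv]`. A rough NumPy-only sketch (ignoring the dtype conversion handled by `convert_data_type` above) looks like this; the helper name is hypothetical and not library code.

import numpy as np

def rescale_array_sketch(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0) -> np.ndarray:
    """Rough NumPy-only sketch of monai.transforms.utils.rescale_array (illustrative only)."""
    mina, maxa = arr.min(), arr.max()
    if mina == maxa:                       # constant input: the real function returns arr * minv
        return arr * minv
    norm = (arr - mina) / (maxa - mina)    # normalise to [0, 1]
    return norm * (maxv - minv) + minv     # then map onto [minv, maxv]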
2 changes: 1 addition & 1 deletion monai/visualize/img2tensorboard.py
@@ -188,7 +188,7 @@ def plot_2d_or_3d_image(
     d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index

     if d.ndim == 2:
-        d = rescale_array(d, 0, 1)
+        d = rescale_array(d, 0, 1)  # type: ignore
         dataformats = "HW"
         writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats)
         return
15 changes: 8 additions & 7 deletions tests/test_rand_scale_intensity.py
@@ -14,17 +14,18 @@
 import numpy as np

 from monai.transforms import RandScaleIntensity
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


 class TestRandScaleIntensity(NumpyImageTestCase2D):
     def test_value(self):
-        scaler = RandScaleIntensity(factors=0.5, prob=1.0)
-        scaler.set_random_state(seed=0)
-        result = scaler(self.imt)
-        np.random.seed(0)
-        expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = RandScaleIntensity(factors=0.5, prob=1.0)
+            scaler.set_random_state(seed=0)
+            result = scaler(p(self.imt))
+            np.random.seed(0)
+            expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32))
+            assert_allclose(result, expected, rtol=1e-7, atol=0)


 if __name__ == "__main__":
17 changes: 9 additions & 8 deletions tests/test_rand_scale_intensityd.py
@@ -14,18 +14,19 @@
 import numpy as np

 from monai.transforms import RandScaleIntensityd
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


 class TestRandScaleIntensityd(NumpyImageTestCase2D):
     def test_value(self):
-        key = "img"
-        scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0)
-        scaler.set_random_state(seed=0)
-        result = scaler({key: self.imt})
-        np.random.seed(0)
-        expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
-        np.testing.assert_allclose(result[key], expected)
+        for p in TEST_NDARRAYS:
+            key = "img"
+            scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0)
+            scaler.set_random_state(seed=0)
+            result = scaler({key: p(self.imt)})
+            np.random.seed(0)
+            expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
+            assert_allclose(result[key], expected)


 if __name__ == "__main__":
26 changes: 14 additions & 12 deletions tests/test_scale_intensity.py
@@ -14,24 +14,26 @@
 import numpy as np

 from monai.transforms import ScaleIntensity
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


 class TestScaleIntensity(NumpyImageTestCase2D):
     def test_range_scale(self):
-        scaler = ScaleIntensity(minv=1.0, maxv=2.0)
-        result = scaler(self.imt)
-        mina = np.min(self.imt)
-        maxa = np.max(self.imt)
-        norm = (self.imt - mina) / (maxa - mina)
-        expected = (norm * (2.0 - 1.0)) + 1.0
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=1.0, maxv=2.0)
+            result = scaler(p(self.imt))
+            mina = self.imt.min()
+            maxa = self.imt.max()
+            norm = (self.imt - mina) / (maxa - mina)
+            expected = p((norm * (2.0 - 1.0)) + 1.0)
+            assert_allclose(result, expected, rtol=1e-7, atol=0)

     def test_factor_scale(self):
-        scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
-        result = scaler(self.imt)
-        expected = (self.imt * (1 + 0.1)).astype(np.float32)
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
+            result = scaler(p(self.imt))
+            expected = p((self.imt * (1 + 0.1)).astype(np.float32))
+            assert_allclose(result, expected, rtol=1e-7, atol=0)


 if __name__ == "__main__":
30 changes: 16 additions & 14 deletions tests/test_scale_intensityd.py
@@ -14,26 +14,28 @@
 import numpy as np

 from monai.transforms import ScaleIntensityd
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


 class TestScaleIntensityd(NumpyImageTestCase2D):
     def test_range_scale(self):
-        key = "img"
-        scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
-        result = scaler({key: self.imt})
-        mina = np.min(self.imt)
-        maxa = np.max(self.imt)
-        norm = (self.imt - mina) / (maxa - mina)
-        expected = (norm * (2.0 - 1.0)) + 1.0
-        np.testing.assert_allclose(result[key], expected)
+        for p in TEST_NDARRAYS:
+            key = "img"
+            scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
+            result = scaler({key: p(self.imt)})
+            mina = np.min(self.imt)
+            maxa = np.max(self.imt)
+            norm = (self.imt - mina) / (maxa - mina)
+            expected = (norm * (2.0 - 1.0)) + 1.0
+            assert_allclose(result[key], expected)

     def test_factor_scale(self):
-        key = "img"
-        scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
-        result = scaler({key: self.imt})
-        expected = (self.imt * (1 + 0.1)).astype(np.float32)
-        np.testing.assert_allclose(result[key], expected)
+        for p in TEST_NDARRAYS:
+            key = "img"
+            scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
+            result = scaler({key: p(self.imt)})
+            expected = (self.imt * (1 + 0.1)).astype(np.float32)
+            assert_allclose(result[key], expected)


 if __name__ == "__main__":