Skip to content
Merged
25 changes: 19 additions & 6 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ class NormalizeIntensity(Transform):
subtrahend: the amount to subtract by (usually the mean).
divisor: the amount to divide by (usually the standard deviation).
nonzero: whether only normalize non-zero values.
channel_wise: if using calculated mean and std, calculate on each channel separately
or calculate on the entire image directly.
channel_wise: if True, calculate on each channel separately, otherwise, calculate on
the entire image directly. default to False.
dtype: output data type, if None, same as input image. defaults to float32.
"""

Expand Down Expand Up @@ -919,6 +919,8 @@ class ScaleIntensityRangePercentiles(Transform):
b_max: intensity target range max.
clip: whether to perform clip after scaling.
relative: whether to scale to the corresponding percentiles of [b_min, b_max].
channel_wise: if True, compute intensity percentile and normalize every channel separately.
default to False.
dtype: output data type, if None, same as input image. defaults to float32.
"""

Expand All @@ -932,6 +934,7 @@ def __init__(
b_max: Optional[float],
clip: bool = False,
relative: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
if lower < 0.0 or lower > 100.0:
Expand All @@ -944,12 +947,10 @@ def __init__(
self.b_max = b_max
self.clip = clip
self.relative = relative
self.channel_wise = channel_wise
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
a_min: float = percentile(img, self.lower) # type: ignore
a_max: float = percentile(img, self.upper) # type: ignore
b_min = self.b_min
Expand All @@ -967,6 +968,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img = scalar(img)
return img

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
if self.channel_wise:
for i, d in enumerate(img):
img[i] = self._normalize(img=d) # type: ignore
else:
img = self._normalize(img=img)

return img


class MaskIntensity(Transform):
"""
Expand Down
9 changes: 6 additions & 3 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,8 @@ class NormalizeIntensityd(MapTransform):
subtrahend: the amount to subtract by (usually the mean)
divisor: the amount to divide by (usually the standard deviation)
nonzero: whether only normalize non-zero values.
channel_wise: if using calculated mean and std, calculate on each channel separately
or calculate on the entire image directly.
channel_wise: if True, calculate on each channel separately, otherwise, calculate on
the entire image directly. default to False.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand Down Expand Up @@ -844,6 +844,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
b_max: intensity target range max.
clip: whether to perform clip after scaling.
relative: whether to scale to the corresponding percentiles of [b_min, b_max]
channel_wise: if True, compute intensity percentile and normalize every channel separately.
default to False.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""
Expand All @@ -859,11 +861,12 @@ def __init__(
b_max: Optional[float],
clip: bool = False,
relative: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, dtype)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
11 changes: 8 additions & 3 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,21 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
return result


def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.

Pytorch uses `quantile`, but this functionality is only available from v1.7.
For earlier methods, we calculate it ourselves. This doesn't do interpolation,
so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
For more details, please refer to:
https://pytorch.org/docs/stable/generated/torch.quantile.html.
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

Args:
x: input data
q: percentile to compute (should in range 0 <= q <= 100)
dim: the dim along which the percentiles are computed. default is to compute the percentile
along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.

Returns:
Resulting value (scalar)
Expand All @@ -102,11 +107,11 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q)
result = np.percentile(x, q, axis=dim)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"):
result = torch.quantile(x, q / 100.0)
result = torch.quantile(x, q / 100.0, dim=dim)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
Expand Down
24 changes: 22 additions & 2 deletions tests/test_scale_intensity_range_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
def test_scaling(self):
img = self.imt
img = self.imt[0]
lower = 10
upper = 99
b_min = 0
Expand All @@ -34,7 +34,7 @@ def test_scaling(self):
assert_allclose(result, p(expected), rtol=1e-4)

def test_relative_scaling(self):
img = self.imt
img = self.imt[0]
lower = 10
upper = 99
b_min = 100
Expand Down Expand Up @@ -65,6 +65,26 @@ def test_invalid_instantiation(self):
self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255)
self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255)

def test_channel_wise(self):
img = self.imt[0]
lower = 10
upper = 99
b_min = 0
b_max = 255
scaler = ScaleIntensityRangePercentiles(
lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8
)
expected = []
for c in img:
a_min = np.percentile(c, lower)
a_max = np.percentile(c, upper)
expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min)
expected = np.stack(expected).astype(np.uint8)

for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(result, p(expected), rtol=1e-4)


if __name__ == "__main__":
unittest.main()
25 changes: 22 additions & 3 deletions tests/test_scale_intensity_range_percentilesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ def test_scaling(self):

a_min = np.percentile(img, lower)
a_max = np.percentile(img, upper)
expected = (img - a_min) / (a_max - a_min)
expected = (expected * (b_max - b_min)) + b_min
expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)

for p in TEST_NDARRAYS:
data = {"img": p(img)}
scaler = ScaleIntensityRangePercentilesd(
keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max
keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8
)
assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4)

Expand Down Expand Up @@ -75,6 +74,26 @@ def test_invalid_instantiation(self):
s = ScaleIntensityRangePercentilesd(keys=["img"], lower=30, upper=90, b_min=None, b_max=20, relative=True)
s(self.imt)

def test_channel_wise(self):
img = self.imt
lower = 10
upper = 99
b_min = 0
b_max = 255
scaler = ScaleIntensityRangePercentilesd(
keys="img", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8
)
expected = []
for c in img:
a_min = np.percentile(c, lower)
a_max = np.percentile(c, upper)
expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8))
expected = np.stack(expected)

for p in TEST_NDARRAYS:
data = {"img": p(img)}
assert_allclose(scaler(data)["img"], p(expected), rtol=1e-4)


if __name__ == "__main__":
unittest.main()
14 changes: 13 additions & 1 deletion tests/test_utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from monai.transforms.utils_pytorch_numpy_unification import percentile
from monai.utils import set_determinism
from tests.utils import TEST_NDARRAYS, assert_allclose
from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose


class TestPytorchNumpyUnification(unittest.TestCase):
Expand All @@ -42,6 +42,18 @@ def test_fails(self):
with self.assertRaises(ValueError):
percentile(arr, q)

@SkipIfBeforePyTorchVersion((1, 7))
def test_dim(self):
q = np.random.randint(0, 100, size=50)
results = []
for p in TEST_NDARRAYS:
arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32))
results.append(percentile(arr, q, dim=1))
# pre torch 1.7, no `quantile`. Our own method doesn't interpolate,
# so we can only be accurate to 0.5
atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
assert_allclose(results[0], results[-1], type_test=False, atol=atol)


if __name__ == "__main__":
unittest.main()