diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 289a1bcbcb..c66482678a 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -608,8 +608,8 @@ class NormalizeIntensity(Transform):
         subtrahend: the amount to subtract by (usually the mean).
         divisor: the amount to divide by (usually the standard deviation).
         nonzero: whether only normalize non-zero values.
-        channel_wise: if using calculated mean and std, calculate on each channel separately
-            or calculate on the entire image directly.
+        channel_wise: if True, calculate on each channel separately; otherwise, calculate on
+            the entire image directly. defaults to False.
         dtype: output data type, if None, same as input image. defaults to float32.
     """

@@ -919,6 +919,8 @@ class ScaleIntensityRangePercentiles(Transform):
         b_max: intensity target range max.
         clip: whether to perform clip after scaling.
         relative: whether to scale to the corresponding percentiles of [b_min, b_max].
+        channel_wise: if True, compute the intensity percentiles and normalize every channel separately.
+            defaults to False.
         dtype: output data type, if None, same as input image. defaults to float32.
     """

@@ -932,6 +934,7 @@ def __init__(
         b_max: Optional[float],
         clip: bool = False,
         relative: bool = False,
+        channel_wise: bool = False,
         dtype: DtypeLike = np.float32,
     ) -> None:
         if lower < 0.0 or lower > 100.0:
@@ -944,12 +947,10 @@
         self.b_max = b_max
         self.clip = clip
         self.relative = relative
+        self.channel_wise = channel_wise
         self.dtype = dtype

-    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
-        """
-        Apply the transform to `img`.
-        """
+    def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         a_min: float = percentile(img, self.lower)  # type: ignore
         a_max: float = percentile(img, self.upper)  # type: ignore
         b_min = self.b_min
@@ -967,6 +968,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         img = scalar(img)
         return img

+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+        """
+        Apply the transform to `img`.
+        """
+        if self.channel_wise:
+            for i, d in enumerate(img):
+                img[i] = self._normalize(img=d)  # type: ignore
+        else:
+            img = self._normalize(img=img)
+
+        return img
+

 class MaskIntensity(Transform):
     """
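Usage sketch (editor's illustration, not part of the patch): with the new channel_wise flag the lower/upper percentiles are taken per channel of a channel-first array, so each channel is mapped to [b_min, b_max] on its own. The image values below are made up for demonstration.

import numpy as np
from monai.transforms import ScaleIntensityRangePercentiles

# two channels with very different dynamic ranges, channel-first shape (2, 10, 10)
img = np.stack(
    [np.arange(100.0).reshape(10, 10), 1000.0 * np.arange(100.0).reshape(10, 10)]
).astype(np.float32)

global_scaler = ScaleIntensityRangePercentiles(lower=10, upper=99, b_min=0, b_max=255)
per_channel = ScaleIntensityRangePercentiles(lower=10, upper=99, b_min=0, b_max=255, channel_wise=True)

# copies keep the two calls independent, since the channel_wise loop writes back into its input
out_global = global_scaler(img.copy())
out_channel = per_channel(img.copy())

# channel 0 is squashed near b_min when the percentiles come from the whole image,
# but spans roughly the full [0, 255] target range when they are computed channel-wise
print(out_global[0].max(), out_channel[0].max())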
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index fa2de4c7b8..683d75763f 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -655,8 +655,8 @@ class NormalizeIntensityd(MapTransform):
         subtrahend: the amount to subtract by (usually the mean)
         divisor: the amount to divide by (usually the standard deviation)
         nonzero: whether only normalize non-zero values.
-        channel_wise: if using calculated mean and std, calculate on each channel separately
-            or calculate on the entire image directly.
+        channel_wise: if True, calculate on each channel separately; otherwise, calculate on
+            the entire image directly. defaults to False.
         dtype: output data type, if None, same as input image. defaults to float32.
         allow_missing_keys: don't raise exception if key is missing.
     """
@@ -844,6 +844,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
         b_max: intensity target range max.
         clip: whether to perform clip after scaling.
         relative: whether to scale to the corresponding percentiles of [b_min, b_max]
+        channel_wise: if True, compute the intensity percentiles and normalize every channel separately.
+            defaults to False.
         dtype: output data type, if None, same as input image. defaults to float32.
         allow_missing_keys: don't raise exception if key is missing.
     """
@@ -859,11 +861,12 @@ def __init__(
         b_max: Optional[float],
         clip: bool = False,
         relative: bool = False,
+        channel_wise: bool = False,
         dtype: DtypeLike = np.float32,
         allow_missing_keys: bool = False,
     ) -> None:
         super().__init__(keys, allow_missing_keys)
-        self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, dtype)
+        self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype)

     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index 9754981560..354074394e 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -81,16 +81,21 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
     return result


-def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
+def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]:
     """`np.percentile` with equivalent implementation for torch.

     Pytorch uses `quantile`, but this functionality is only available from v1.7.
     For earlier methods, we calculate it ourselves. This doesn't do interpolation,
     so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
+    For more details, please refer to:
+    https://pytorch.org/docs/stable/generated/torch.quantile.html.
+    https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

     Args:
         x: input data
         q: percentile to compute (should in range 0 <= q <= 100)
+        dim: the dimension along which the percentiles are computed. defaults to computing the percentile
+            along a flattened version of the array. only works for a numpy array, or a Tensor with PyTorch >= 1.7.0.

     Returns:
         Resulting value (scalar)
@@ -102,11 +107,11 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
         raise ValueError
     result: Union[NdarrayOrTensor, float, int]
     if isinstance(x, np.ndarray):
-        result = np.percentile(x, q)
+        result = np.percentile(x, q, axis=dim)
     else:
         q = torch.tensor(q, device=x.device)
         if hasattr(torch, "quantile"):
-            result = torch.quantile(x, q / 100.0)
+            result = torch.quantile(x, q / 100.0, dim=dim)
         else:
             # Note that ``kthvalue()`` works one-based, i.e., the first sorted value
             # corresponds to k=1, not k=0. Thus, we need the `1 +`.
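Quick sketch of the new dim argument (illustration only, not from the patch): for a numpy array it forwards to np.percentile(..., axis=dim), and for a torch Tensor on PyTorch >= 1.7 it forwards to torch.quantile(..., dim=dim), so both branches return one value per slice along the remaining axes instead of a single scalar.

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import percentile

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)
x_pt = torch.as_tensor(x_np)

print(percentile(x_np, 50, dim=1))  # [1. 4.] -- one median per row via np.percentile(..., axis=1)
print(percentile(x_pt, 50, dim=1))  # tensor([1., 4.]) -- via torch.quantile(..., dim=1), PyTorch >= 1.7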
diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py
index 3556c7a8b4..3c35ccbb2c 100644
--- a/tests/test_scale_intensity_range_percentiles.py
+++ b/tests/test_scale_intensity_range_percentiles.py
@@ -19,7 +19,7 @@

 class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
     def test_scaling(self):
-        img = self.imt
+        img = self.imt[0]
         lower = 10
         upper = 99
         b_min = 0
@@ -34,7 +34,7 @@ def test_scaling(self):
             assert_allclose(result, p(expected), rtol=1e-4)

     def test_relative_scaling(self):
-        img = self.imt
+        img = self.imt[0]
         lower = 10
         upper = 99
         b_min = 100
@@ -65,6 +65,26 @@ def test_invalid_instantiation(self):
         self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255)
         self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255)

+    def test_channel_wise(self):
+        img = self.imt[0]
+        lower = 10
+        upper = 99
+        b_min = 0
+        b_max = 255
+        scaler = ScaleIntensityRangePercentiles(
+            lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8
+        )
+        expected = []
+        for c in img:
+            a_min = np.percentile(c, lower)
+            a_max = np.percentile(c, upper)
+            expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min)
+        expected = np.stack(expected).astype(np.uint8)
+
+        for p in TEST_NDARRAYS:
+            result = scaler(p(img))
+            assert_allclose(result, p(expected), rtol=1e-4)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py
index 0fcda21feb..d1a626bac8 100644
--- a/tests/test_scale_intensity_range_percentilesd.py
+++ b/tests/test_scale_intensity_range_percentilesd.py
@@ -27,13 +27,12 @@ def test_scaling(self):

         a_min = np.percentile(img, lower)
         a_max = np.percentile(img, upper)
-        expected = (img - a_min) / (a_max - a_min)
-        expected = (expected * (b_max - b_min)) + b_min
+        expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)

         for p in TEST_NDARRAYS:
             data = {"img": p(img)}
             scaler = ScaleIntensityRangePercentilesd(
-                keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max
+                keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8
             )
             assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4)

@@ -75,6 +74,26 @@ def test_invalid_instantiation(self):
             s = ScaleIntensityRangePercentilesd(keys=["img"], lower=30, upper=90, b_min=None, b_max=20, relative=True)
             s(self.imt)

+    def test_channel_wise(self):
+        img = self.imt
+        lower = 10
+        upper = 99
+        b_min = 0
+        b_max = 255
+        scaler = ScaleIntensityRangePercentilesd(
+            keys="img", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8
+        )
+        expected = []
+        for c in img:
+            a_min = np.percentile(c, lower)
+            a_max = np.percentile(c, upper)
+            expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8))
+        expected = np.stack(expected)
+
+        for p in TEST_NDARRAYS:
+            data = {"img": p(img)}
+            assert_allclose(scaler(data)["img"], p(expected), rtol=1e-4)
+

 if __name__ == "__main__":
     unittest.main()
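Dictionary-transform sketch (editor's illustration): the flag is simply forwarded to the array transform, so it drops into a dict-based pipeline unchanged. The key name "img" and the random data are made up for the example.

import numpy as np
from monai.transforms import ScaleIntensityRangePercentilesd

data = {"img": np.random.rand(3, 32, 32).astype(np.float32)}  # channel-first (C, H, W)
scaler = ScaleIntensityRangePercentilesd(
    keys="img", lower=10, upper=99, b_min=0, b_max=255, channel_wise=True
)
out = scaler(data)["img"]
print(out.shape)  # (3, 32, 32): each channel scaled with its own 10th/99th percentiles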
diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py
index c3b1bc259b..b3724130d9 100644
--- a/tests/test_utils_pytorch_numpy_unification.py
+++ b/tests/test_utils_pytorch_numpy_unification.py
@@ -16,7 +16,7 @@

 from monai.transforms.utils_pytorch_numpy_unification import percentile
 from monai.utils import set_determinism
-from tests.utils import TEST_NDARRAYS, assert_allclose
+from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose


 class TestPytorchNumpyUnification(unittest.TestCase):
@@ -42,6 +42,18 @@ def test_fails(self):
             with self.assertRaises(ValueError):
                 percentile(arr, q)

+    @SkipIfBeforePyTorchVersion((1, 7))
+    def test_dim(self):
+        q = np.random.randint(0, 100, size=50)
+        results = []
+        for p in TEST_NDARRAYS:
+            arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32))
+            results.append(percentile(arr, q, dim=1))
+            # pre torch 1.7, no `quantile`. Our own method doesn't interpolate,
+            # so we can only be accurate to 0.5
+            atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
+            assert_allclose(results[0], results[-1], type_test=False, atol=atol)
+

 if __name__ == "__main__":
     unittest.main()
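The 0.5 tolerance in test_dim comes from the pre-1.7 fallback being a nearest-rank percentile with no interpolation. A rough standalone illustration of that idea (this is not MONAI's kthvalue-based fallback, just the same principle): with unit spacing between sorted values, the nearest-rank answer can sit up to half a step away from the interpolated one.

import numpy as np

def nearest_rank_percentile(x: np.ndarray, q: float) -> float:
    # percentile without interpolation: pick the nearest sorted sample
    # (cf. numpy.percentile(..., interpolation="nearest"))
    flat = np.sort(x, axis=None)
    k = int(round(q / 100.0 * (flat.size - 1)))  # index of the nearest sorted sample
    return float(flat[k])

vals = np.arange(6, dtype=np.float32)
print(nearest_rank_percentile(vals, 30), np.percentile(vals, 30))  # 2.0 vs 1.5 -- off by 0.5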