From 8cab57d76e58ef43f9b99cb6eb972e789424bad8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 10 Dec 2021 20:18:53 +0800 Subject: [PATCH 1/7] [DLMED] add channel-wise Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 25 ++++++++++++++----- monai/transforms/intensity/dictionary.py | 9 ++++--- .../test_scale_intensity_range_percentiles.py | 20 +++++++++++++++ ...test_scale_intensity_range_percentilesd.py | 22 +++++++++++++++- 4 files changed, 66 insertions(+), 10 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 31249d547b..e4cf961746 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -608,8 +608,8 @@ class NormalizeIntensity(Transform): subtrahend: the amount to subtract by (usually the mean). divisor: the amount to divide by (usually the standard deviation). nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. dtype: output data type, if None, same as input image. defaults to float32. """ @@ -919,6 +919,8 @@ class ScaleIntensityRangePercentiles(Transform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max]. + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. dtype: output data type, if None, same as input image. defaults to float32. """ @@ -932,6 +934,7 @@ def __init__( b_max: Optional[float], clip: bool = False, relative: bool = False, + channel_wise: bool = False, dtype: DtypeLike = np.float32, ) -> None: if lower < 0.0 or lower > 100.0: @@ -944,12 +947,10 @@ def __init__( self.b_max = b_max self.clip = clip self.relative = relative + self.channel_wise = channel_wise self.dtype = dtype - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: - """ - Apply the transform to `img`. - """ + def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor: a_min: float = percentile(img, self.lower) # type: ignore a_max: float = percentile(img, self.upper) # type: ignore b_min = self.b_min @@ -967,6 +968,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = scalar(img) return img + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if self.channel_wise: + for i, d in enumerate(img): + img[i] = self._normalize(img=d) # type: ignore + else: + img = self._normalize(img=img) + + return img + class MaskIntensity(Transform): """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index fa2de4c7b8..683d75763f 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -655,8 +655,8 @@ class NormalizeIntensityd(MapTransform): subtrahend: the amount to subtract by (usually the mean) divisor: the amount to divide by (usually the standard deviation) nonzero: whether only normalize non-zero values. - channel_wise: if using calculated mean and std, calculate on each channel separately - or calculate on the entire image directly. + channel_wise: if True, calculate on each channel separately, otherwise, calculate on + the entire image directly. default to False. 
dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -844,6 +844,8 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + channel_wise: if True, compute intensity percentile and normalize every channel separately. + default to False. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -859,11 +861,12 @@ def __init__( b_max: Optional[float], clip: bool = False, relative: bool = False, + channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, dtype) + self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 3556c7a8b4..108a50a3fc 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -65,6 +65,26 @@ def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255) self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255) + def test_channel_wise(self): + img = self.imt + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentiles( + lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min) + expected = np.stack(expected).astype(np.uint8) + + for p in TEST_NDARRAYS: + result = scaler(p(img)) + assert_allclose(result, p(expected), rtol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 0fcda21feb..dfc50c7694 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -33,7 +33,7 @@ def test_scaling(self): for p in TEST_NDARRAYS: data = {"img": p(img)} scaler = ScaleIntensityRangePercentilesd( - keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max + keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8 ) assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4) @@ -75,6 +75,26 @@ def test_invalid_instantiation(self): s = ScaleIntensityRangePercentilesd(keys=["img"], lower=30, upper=90, b_min=None, b_max=20, relative=True) s(self.imt) + def test_channel_wise(self): + img = self.imt + lower = 10 + upper = 99 + b_min = 0 + b_max = 255 + scaler = ScaleIntensityRangePercentilesd( + keys="img", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8 + ) + expected = [] + for c in img: + a_min = np.percentile(c, lower) + a_max = np.percentile(c, upper) + expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min) + expected = 
np.stack(expected).astype(np.uint8) + + for p in TEST_NDARRAYS: + data = {"img": p(img)} + assert_allclose(scaler(data)["img"], p(expected), rtol=1e-4) + if __name__ == "__main__": unittest.main() From 8982ba5397d20bab6ab79cc4cb4c968531edb24f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 11 Dec 2021 00:18:20 +0800 Subject: [PATCH 2/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/utils_pytorch_numpy_unification.py | 13 +++++++++---- tests/test_scale_intensity_range_percentiles.py | 6 +++--- tests/test_utils_pytorch_numpy_unification.py | 11 +++++++++++ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 9754981560..dd8212c62e 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -81,16 +81,21 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor: return result -def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]: +def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]: """`np.percentile` with equivalent implementation for torch. Pytorch uses `quantile`, but this functionality is only available from v1.7. For earlier methods, we calculate it ourselves. This doesn't do interpolation, so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``. + For more details, please refer to: + https://pytorch.org/docs/stable/generated/torch.quantile.html. + https://numpy.org/doc/stable/reference/generated/numpy.percentile.html. Args: x: input data q: percentile to compute (should in range 0 <= q <= 100) + dim: the dim along which the percentiles are computed. default is to compute the percentile + along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0. Returns: Resulting value (scalar) @@ -102,18 +107,18 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]: raise ValueError result: Union[NdarrayOrTensor, float, int] if isinstance(x, np.ndarray): - result = np.percentile(x, q) + result = np.percentile(x, q, axis=dim) else: q = torch.tensor(q, device=x.device) if hasattr(torch, "quantile"): - result = torch.quantile(x, q / 100.0) + result = torch.quantile(x, q / 100.0, dim=dim) else: # Note that ``kthvalue()`` works one-based, i.e., the first sorted value # corresponds to k=1, not k=0. Thus, we need the `1 +`. 
k = 1 + (0.01 * q * (x.numel() - 1)).round().int() if k.numel() > 1: r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k] - result = torch.tensor(r, device=x.device) + result = torch.tensor(r, device=q.device) else: result = x.view(-1).kthvalue(int(k)).values.item() diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 108a50a3fc..3c35ccbb2c 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -19,7 +19,7 @@ class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D): def test_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 0 @@ -34,7 +34,7 @@ def test_scaling(self): assert_allclose(result, p(expected), rtol=1e-4) def test_relative_scaling(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 100 @@ -66,7 +66,7 @@ def test_invalid_instantiation(self): self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255) def test_channel_wise(self): - img = self.imt + img = self.imt[0] lower = 10 upper = 99 b_min = 0 diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index c3b1bc259b..a0c22e9a6d 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -42,6 +42,17 @@ def test_fails(self): with self.assertRaises(ValueError): percentile(arr, q) + def test_dim(self): + q = np.random.randint(0, 100, size=50) + results = [] + for p in TEST_NDARRAYS: + arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32)) + results.append(percentile(arr, q, dim=1)) + # pre torch 1.7, no `quantile`. Our own method doesn't interpolate, + # so we can only be accurate to 0.5 + atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 + assert_allclose(results[0], results[-1], type_test=False, atol=atol) + if __name__ == "__main__": unittest.main() From ee63cd481592fc966a828c826c5d10cb8a24adc1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 11 Dec 2021 00:20:38 +0800 Subject: [PATCH 3/7] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/transforms/utils_pytorch_numpy_unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index dd8212c62e..354074394e 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -118,7 +118,7 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra k = 1 + (0.01 * q * (x.numel() - 1)).round().int() if k.numel() > 1: r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k] - result = torch.tensor(r, device=q.device) + result = torch.tensor(r, device=x.device) else: result = x.view(-1).kthvalue(int(k)).values.item() From a665c611f8317d7ad361c2ac673a2b7d7d172827 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 12 Dec 2021 08:31:41 +0800 Subject: [PATCH 4/7] [DLMED] skip test if before 1.7 Signed-off-by: Nic Ma --- tests/test_utils_pytorch_numpy_unification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index a0c22e9a6d..016f19485a 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -16,7 +16,7 @@ from monai.transforms.utils_pytorch_numpy_unification import 
percentile from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion class TestPytorchNumpyUnification(unittest.TestCase): @@ -42,6 +42,7 @@ def test_fails(self): with self.assertRaises(ValueError): percentile(arr, q) + @SkipIfBeforePyTorchVersion((1, 7)) def test_dim(self): q = np.random.randint(0, 100, size=50) results = [] From ffdf69ee31c2c5eeebb678b12edb0ff25d1479d3 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sun, 12 Dec 2021 00:58:27 +0000 Subject: [PATCH 5/7] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_utils_pytorch_numpy_unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 016f19485a..b3724130d9 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -16,7 +16,7 @@ from monai.transforms.utils_pytorch_numpy_unification import percentile from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose, SkipIfBeforePyTorchVersion +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose class TestPytorchNumpyUnification(unittest.TestCase): From 4c373df0d00415d817984c678ad4df88303fad62 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 12 Dec 2021 10:04:13 +0800 Subject: [PATCH 6/7] [DLMED] fix CI test Signed-off-by: Nic Ma --- tests/test_scale_intensity_range_percentilesd.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index dfc50c7694..07d2caf66c 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -27,8 +27,7 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) for p in TEST_NDARRAYS: data = {"img": p(img)} From c8db7ec1b768df53b044bc033d7213f8c1145f5d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 12 Dec 2021 15:29:30 +0800 Subject: [PATCH 7/7] [DLMED] fix wrong test Signed-off-by: Nic Ma --- tests/test_scale_intensity_range_percentilesd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index 07d2caf66c..d1a626bac8 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -87,8 +87,8 @@ def test_channel_wise(self): for c in img: a_min = np.percentile(c, lower) a_max = np.percentile(c, upper) - expected.append(((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min) - expected = np.stack(expected).astype(np.uint8) + expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)) + expected = np.stack(expected) for p in TEST_NDARRAYS: data = {"img": p(img)}
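
Usage sketch for the new channel-wise scaling introduced in this series (illustrative only; it assumes a MONAI build that already includes the patches above, and the array values are made up). With channel_wise=True, the transform computes the lower/upper intensity percentiles per channel, so each channel is rescaled with its own statistics:

    import numpy as np
    from monai.transforms import ScaleIntensityRangePercentiles

    # channel-first input with two channels on very different intensity scales
    img = np.stack(
        [
            np.arange(12, dtype=np.float32).reshape(3, 4),
            np.arange(12, dtype=np.float32).reshape(3, 4) * 100.0,
        ]
    )

    # percentiles are computed separately for each channel, then each channel
    # is mapped to the [b_min, b_max] target range using its own a_min / a_max
    scaler = ScaleIntensityRangePercentiles(
        lower=10, upper=99, b_min=0.0, b_max=1.0, channel_wise=True
    )
    out = scaler(img)
    print(out.shape)  # (2, 3, 4); both channels now span roughly [0, 1]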
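
The percentile helper in monai.transforms.utils_pytorch_numpy_unification also gains a dim argument in this series. A minimal sketch of its behaviour (assuming the patches are applied; per the updated docstring, dim only works for numpy arrays or for tensors with PyTorch >= 1.7.0):

    import torch
    from monai.transforms.utils_pytorch_numpy_unification import percentile

    x = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)

    # without dim, the percentile is taken over the flattened input (scalar result)
    print(percentile(x, 50))                # tensor(2.5000)

    # with dim=1, percentiles are computed independently along dimension 1
    print(percentile(x, 50, dim=1).shape)   # torch.Size([1, 3])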