diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 9575a412b4..8a32b9e0b8 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -524,4 +524,4 @@
     weighted_patch_samples,
     zero_margins,
 )
-from .utils_pytorch_numpy_unification import in1d, moveaxis, where
+from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index b6fa1f72b7..6c45f0d52b 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -28,7 +28,7 @@
 from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
 from monai.transforms.transform import RandomizableTransform, Transform
 from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
-from monai.transforms.utils_pytorch_numpy_unification import where
+from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
 from monai.utils import (
     PT_BEFORE_1_7,
     InvalidPyTorchVersionError,
@@ -688,6 +688,8 @@ class ScaleIntensityRange(Transform):
         clip: whether to perform clip after scaling.
     """
 
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
     def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None:
         self.a_min = a_min
         self.a_max = a_max
@@ -695,7 +697,7 @@ def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip:
         self.b_max = b_max
         self.clip = clip
 
-    def __call__(self, img: np.ndarray):
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.
         """
@@ -706,7 +708,7 @@ def __call__(self, img: np.ndarray):
         img = (img - self.a_min) / (self.a_max - self.a_min)
         img = img * (self.b_max - self.b_min) + self.b_min
         if self.clip:
-            img = np.asarray(np.clip(img, self.b_min, self.b_max))
+            img = clip(img, self.b_min, self.b_max)
         return img
 
 
@@ -835,6 +837,8 @@ class ScaleIntensityRangePercentiles(Transform):
         relative: whether to scale to the corresponding percentiles of [b_min, b_max].
     """
 
+    backend = ScaleIntensityRange.backend
+
     def __init__(
         self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False
     ) -> None:
@@ -849,12 +853,12 @@ def __init__(
         self.clip = clip
         self.relative = relative
 
-    def __call__(self, img: np.ndarray):
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         Apply the transform to `img`.
         """
-        a_min = np.percentile(img, self.lower)
-        a_max = np.percentile(img, self.upper)
+        a_min: float = percentile(img, self.lower)  # type: ignore
+        a_max: float = percentile(img, self.upper)  # type: ignore
         b_min = self.b_min
         b_max = self.b_max
 
@@ -866,7 +870,7 @@ def __call__(self, img: np.ndarray):
 
         img = scalar(img)
         if self.clip:
-            img = np.asarray(np.clip(img, self.b_min, self.b_max))
+            img = clip(img, self.b_min, self.b_max)
         return img
 
 
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index cccf3e2a90..07a6045870 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -699,6 +699,8 @@ class ScaleIntensityRanged(MapTransform):
         allow_missing_keys: don't raise exception if key is missing.
     """
 
+    backend = ScaleIntensityRange.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -712,7 +714,7 @@ def __init__(
         super().__init__(keys, allow_missing_keys)
         self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip)
 
-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
             d[key] = self.scaler(d[key])
@@ -816,6 +818,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
         allow_missing_keys: don't raise exception if key is missing.
     """
 
+    backend = ScaleIntensityRangePercentiles.backend
+
     def __init__(
         self,
         keys: KeysCollection,
@@ -830,7 +834,7 @@ def __init__(
         super().__init__(keys, allow_missing_keys)
         self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative)
 
-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
             d[key] = self.scaler(d[key])
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index 70ecb2848d..0fb8e34ef0 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Union
+
 import numpy as np
 import torch
 
@@ -17,6 +19,8 @@
 __all__ = [
     "moveaxis",
     "in1d",
+    "clip",
+    "percentile",
     "where",
 ]
 
@@ -53,6 +57,56 @@ def in1d(x, y):
     return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)
 
 
+def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
+    """`np.clip` with equivalent implementation for torch."""
+    result: NdarrayOrTensor
+    if isinstance(a, np.ndarray):
+        result = np.clip(a, a_min, a_max)
+    else:
+        result = torch.clip(a, a_min, a_max)
+    return result
+
+
+def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
+    """`np.percentile` with equivalent implementation for torch.
+
+    PyTorch uses `quantile`, but this functionality is only available from v1.7.
+    For earlier versions, we calculate it ourselves. This doesn't interpolate,
+    so it is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
+
+    Args:
+        x: input data
+        q: percentile to compute (should be in the range 0 <= q <= 100)
+
+    Returns:
+        Resulting value (scalar)
+    """
+    if np.isscalar(q):
+        if not 0 <= q <= 100:
+            raise ValueError
+    else:
+        if any(q < 0) or any(q > 100):
+            raise ValueError
+    result: Union[NdarrayOrTensor, float, int]
+    if isinstance(x, np.ndarray):
+        result = np.percentile(x, q)
+    else:
+        q = torch.tensor(q, device=x.device)
+        if hasattr(torch, "quantile"):
+            result = torch.quantile(x, q / 100.0)
+        else:
+            # Note that ``kthvalue()`` works one-based, i.e., the first sorted value
+            # corresponds to k=1, not k=0. Thus, we need the `1 +`.
+            k = 1 + (0.01 * q * (x.numel() - 1)).round().int()
+            if k.numel() > 1:
+                r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k]
+                result = torch.tensor(r, device=x.device)
+            else:
+                result = x.view(-1).kthvalue(int(k)).values.item()
+
+    return result
+
+
 def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
     """
     Note that `torch.where` may convert y.dtype to x.dtype.
diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py
index cba07d9157..d64f09ae82 100644
--- a/tests/test_scale_intensity_range.py
+++ b/tests/test_scale_intensity_range.py
@@ -11,19 +11,18 @@
 
 import unittest
 
-import numpy as np
-
 from monai.transforms import ScaleIntensityRange
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
 
 
 class IntensityScaleIntensityRange(NumpyImageTestCase2D):
     def test_image_scale_intensity_range(self):
         scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80)
-        scaled = scaler(self.imt)
-        expected = (self.imt - 20) / 88
-        expected = expected * 30 + 50
-        self.assertTrue(np.allclose(scaled, expected))
+        for p in TEST_NDARRAYS:
+            scaled = scaler(p(self.imt))
+            expected = (self.imt - 20) / 88
+            expected = expected * 30 + 50
+            assert_allclose(scaled, expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py
index 015162c8de..5cd19581b3 100644
--- a/tests/test_scale_intensity_range_percentiles.py
+++ b/tests/test_scale_intensity_range_percentiles.py
@@ -14,7 +14,7 @@
 import numpy as np
 
 from monai.transforms.intensity.array import ScaleIntensityRangePercentiles
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
 
 
 class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
@@ -30,7 +30,9 @@ def test_scaling(self):
         expected = (img - a_min) / (a_max - a_min)
         expected = (expected * (b_max - b_min)) + b_min
         scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max)
-        self.assertTrue(np.allclose(expected, scaler(img)))
+        for p in TEST_NDARRAYS:
+            result = scaler(p(img))
+            assert_allclose(expected, result)
 
     def test_relative_scaling(self):
         img = self.imt
@@ -47,7 +49,9 @@ def test_relative_scaling(self):
         expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)
         expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min
 
-        self.assertTrue(np.allclose(expected_img, scaler(img)))
+        for p in TEST_NDARRAYS:
+            result = scaler(p(img))
+            assert_allclose(expected_img, result)
 
     def test_invalid_instantiation(self):
         self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255)
diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py
index a8cac414e8..b4d8cbf65a 100644
--- a/tests/test_scale_intensity_ranged.py
+++ b/tests/test_scale_intensity_ranged.py
@@ -11,20 +11,19 @@
 
 import unittest
 
-import numpy as np
-
 from monai.transforms import ScaleIntensityRanged
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
 
 
 class IntensityScaleIntensityRanged(NumpyImageTestCase2D):
     def test_image_scale_intensity_ranged(self):
         key = "img"
         scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)
-        scaled = scaler({key: self.imt})
-        expected = (self.imt - 20) / 88
-        expected = expected * 30 + 50
-        self.assertTrue(np.allclose(scaled[key], expected))
+        for p in TEST_NDARRAYS:
+            scaled = scaler({key: p(self.imt)})
+            expected = (self.imt - 20) / 88
+            expected = expected * 30 + 50
+            assert_allclose(scaled[key], expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py
new file mode 100644
index 0000000000..f05235187c
--- /dev/null
+++ b/tests/test_utils_pytorch_numpy_unification.py
@@ -0,0 +1,46 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+from monai.transforms.utils_pytorch_numpy_unification import percentile
+from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism
+
+
+class TestPytorchNumpyUnification(unittest.TestCase):
+    def setUp(self) -> None:
+        set_determinism(0)
+
+    def test_percentile(self):
+        for size in (1, 100):
+            q = np.random.randint(0, 100, size=size)
+            results = []
+            for p in TEST_NDARRAYS:
+                arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
+                results.append(percentile(arr, q))
+            # pre torch 1.7, no `quantile`. Our own method doesn't interpolate,
+            # so we can only be accurate to 0.5
+            atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
+            assert_allclose(results[0], results[-1], atol=atol)
+
+    def test_fails(self):
+        for p in TEST_NDARRAYS:
+            for q in (-1, 101):
+                arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
+                with self.assertRaises(ValueError):
+                    percentile(arr, q)
+
+
+if __name__ == "__main__":
+    unittest.main()
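
As an end-to-end illustration of what this patch enables, the rough sketch below (not part of the patch; the array values are made up, and it assumes a PyTorch recent enough to provide `torch.clip`/`torch.quantile`, i.e. v1.7+) exercises the unified helpers and the refactored transform with both backends, showing that each returns the same type it was given:

    import numpy as np
    import torch

    from monai.transforms import ScaleIntensityRange
    from monai.transforms.utils_pytorch_numpy_unification import clip, percentile

    x_np = np.arange(101, dtype=np.float32)
    x_pt = torch.arange(101, dtype=torch.float32)

    # the helpers mirror np.clip/np.percentile and preserve the input type
    assert np.allclose(clip(x_np, 10, 90), clip(x_pt, 10, 90).numpy())
    assert float(percentile(x_np, 50)) == float(percentile(x_pt, 50)) == 50.0

    img_np = np.linspace(20, 108, num=12, dtype=np.float32).reshape(1, 3, 4)
    img_pt = torch.as_tensor(img_np)

    scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=0.0, b_max=1.0, clip=True)
    out_np = scaler(img_np)  # numpy in, numpy out
    out_pt = scaler(img_pt)  # torch in, torch out

    assert isinstance(out_np, np.ndarray) and isinstance(out_pt, torch.Tensor)
    assert np.allclose(out_np, out_pt.numpy())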