From 53376e0714140d50695511b12397565cf947dd34 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Aug 2021 15:25:38 +0800 Subject: [PATCH 1/3] [DLMED] add Histogram normalize Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 ++++++ monai/transforms/__init__.py | 4 ++ monai/transforms/intensity/array.py | 35 ++++++++++++++++++ monai/transforms/intensity/dictionary.py | 38 +++++++++++++++++++ tests/test_histogram_normalize.py | 47 ++++++++++++++++++++++++ tests/test_histogram_normalized.py | 47 ++++++++++++++++++++++++ 6 files changed, 183 insertions(+) create mode 100644 tests/test_histogram_normalize.py create mode 100644 tests/test_histogram_normalized.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8a880ff151..f97be395d1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -319,6 +319,12 @@ Intensity :members: :special-members: __call__ +`HistogramNormalize` +"""""""""""""""""""" + .. autoclass:: HistogramNormalize + :members: + :special-members: __call__ + IO ^^ @@ -930,6 +936,12 @@ Intensity (Dict) :members: :special-members: __call__ +`HistogramNormalized` +""""""""""""""""""""" + .. autoclass:: HistogramNormalized + :members: + :special-members: __call__ + IO (Dict) ^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 7f2873cc85..bbd5fccff2 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -83,6 +83,7 @@ GaussianSharpen, GaussianSmooth, GibbsNoise, + HistogramNormalize, KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, @@ -120,6 +121,9 @@ GibbsNoised, GibbsNoiseD, GibbsNoiseDict, + HistogramNormalized, + HistogramNormalizeD, + HistogramNormalizeDict, KSpaceSpikeNoised, KSpaceSpikeNoiseD, KSpaceSpikeNoiseDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index c14f2b242f..2b937f1179 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -64,6 +64,7 @@ "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", "RandCoarseDropout", + "HistogramNormalize", ] @@ -1626,3 +1627,37 @@ def __call__(self, img: np.ndarray): img[h] = self.fill_value return img + + +class HistogramNormalize(Transform): + """ + Apply the histogram normalization to input image. + Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83. + + Args: + bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + max: the max value to normalize input image, default to `255`. + dtype: data type of the output, default to `float32`. + + """ + + def __init__(self, bins: int = 256, max: int = 255, dtype: DtypeLike = np.float32) -> None: + self.bins = bins + self.max = max + self.dtype = dtype + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, + """ + orig_shape = img.shape + hist, bins = np.histogram(img.flatten(), self.bins, density=True) + cum = hist.cumsum() + # normalize the cumulative result + cum = self.max * cum / cum[-1] + + # apply linear interpolation + img = np.interp(img.flatten(), bins[:-1], cum) + + return img.reshape(orig_shape).astype(self.dtype) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 19323e2020..859a4cd418 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -28,6 +28,7 @@ GaussianSharpen, GaussianSmooth, GibbsNoise, + HistogramNormalize, KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, @@ -72,6 +73,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", "ShiftIntensityD", @@ -122,6 +124,8 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "HistogramNormalizeD", + "HistogramNormalizeDict", ] @@ -1469,6 +1473,39 @@ def __call__(self, data): return d +class HistogramNormalized(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + max: the max value to normalize input image, default to `255`. + dtype: data type of the output, default to `float32`. + allow_missing_keys: do not raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + bins: int = 256, + max: int = 255, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.transform = HistogramNormalize(bins=bins, max=max, dtype=dtype) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.transform(d[key]) + return d + + RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd @@ -1495,3 +1532,4 @@ def __call__(self, data): KSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd +HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py new file mode 100644 index 0000000000..ce55d57f5a --- /dev/null +++ b/tests/test_histogram_normalize.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import HistogramNormalize + +TEST_CASE_1 = [ + {"bins": 4, "max": 4}, + np.array([0, 1, 2, 3, 4]), + np.array([0.8, 1.6, 2.4, 4.0, 4.0]), +] + +TEST_CASE_2 = [ + {"bins": 4, "max": 4, "dtype": np.uint8}, + np.array([0, 1, 2, 3, 4]), + np.array([0, 1, 2, 4, 4]), +] + +TEST_CASE_3 = [ + {"bins": 256, "max": 255, "dtype": np.uint8}, + np.array([[[100, 200], [150, 250]]]), + np.array([[[63, 191], [127, 255]]]), +] + + +class TestHistogramNormalize(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, argments, image, expected_data): + result = HistogramNormalize(**argments)(image) + np.testing.assert_allclose(result, expected_data) + self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py new file mode 100644 index 0000000000..be1851c783 --- /dev/null +++ b/tests/test_histogram_normalized.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import HistogramNormalized + +TEST_CASE_1 = [ + {"keys": "img", "bins": 4, "max": 4}, + {"img": np.array([0, 1, 2, 3, 4])}, + np.array([0.8, 1.6, 2.4, 4.0, 4.0]), +] + +TEST_CASE_2 = [ + {"keys": "img", "bins": 4, "max": 4, "dtype": np.uint8}, + {"img": np.array([0, 1, 2, 3, 4])}, + np.array([0, 1, 2, 4, 4]), +] + +TEST_CASE_3 = [ + {"keys": "img", "bins": 256, "max": 255, "dtype": np.uint8}, + {"img": np.array([[[100, 200], [150, 250]]])}, + np.array([[[63, 191], [127, 255]]]), +] + + +class TestHistogramNormalized(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, argments, image, expected_data): + result = HistogramNormalized(**argments)(image)["img"] + np.testing.assert_allclose(result, expected_data) + self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + + +if __name__ == "__main__": + unittest.main() From 9670463bdca245c4431b272f4563c8cb96e9c8fb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 11 Aug 2021 12:22:59 +0800 Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 1 + monai/transforms/intensity/array.py | 21 ++++++---------- monai/transforms/intensity/dictionary.py | 8 +++--- monai/transforms/utils.py | 32 ++++++++++++++++++++++++ tests/test_histogram_normalize.py | 10 ++++---- tests/test_histogram_normalized.py | 10 ++++---- 6 files changed, 55 insertions(+), 27 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index bbd5fccff2..390b85a1b8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -471,6 +471,7 @@ create_scale, create_shear, create_translate, + equalize_hist, extreme_points_to_image, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 2b937f1179..cbae934660 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -24,7 +24,7 @@ from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import Fourier, RandomizableTransform, Transform -from monai.transforms.utils import is_positive, rescale_array +from monai.transforms.utils import equalize_hist, is_positive, rescale_array from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -1635,15 +1635,17 @@ class HistogramNormalize(Transform): Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83. Args: - bins: number of the bins to use in histogram, default to `256`. for more details: + num_bins: number of the bins to use in histogram, default to `256`. for more details: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `0`. max: the max value to normalize input image, default to `255`. dtype: data type of the output, default to `float32`. """ - def __init__(self, bins: int = 256, max: int = 255, dtype: DtypeLike = np.float32) -> None: - self.bins = bins + def __init__(self, num_bins: int = 256, min: int = 0, max: int = 255, dtype: DtypeLike = np.float32) -> None: + self.num_bins = num_bins + self.min = min self.max = max self.dtype = dtype @@ -1651,13 +1653,4 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ - orig_shape = img.shape - hist, bins = np.histogram(img.flatten(), self.bins, density=True) - cum = hist.cumsum() - # normalize the cumulative result - cum = self.max * cum / cum[-1] - - # apply linear interpolation - img = np.interp(img.flatten(), bins[:-1], cum) - - return img.reshape(orig_shape).astype(self.dtype) + return equalize_hist(img=img, num_bins=self.num_bins, min=self.min, max=self.max, dtype=self.dtype) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 859a4cd418..50977d209e 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1480,8 +1480,9 @@ class HistogramNormalized(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - bins: number of the bins to use in histogram, default to `256`. for more details: + num_bins: number of the bins to use in histogram, default to `256`. for more details: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `255`. max: the max value to normalize input image, default to `255`. dtype: data type of the output, default to `float32`. allow_missing_keys: do not raise exception if key is missing. @@ -1491,13 +1492,14 @@ class HistogramNormalized(MapTransform): def __init__( self, keys: KeysCollection, - bins: int = 256, + num_bins: int = 256, + min: int = 0, max: int = 255, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = HistogramNormalize(bins=bins, max=max, dtype=dtype) + self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, dtype=dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 800a779651..9a59a72b21 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -76,6 +76,7 @@ "tensor_to_numpy", "weighted_patch_samples", "zero_margins", + "equalize_hist", ] @@ -1115,3 +1116,34 @@ def tensor_to_numpy(data): return tuple(tensor_to_numpy(i) for i in data) return data + + +def equalize_hist( + img: np.ndarray, + num_bins: int = 256, + min: int = 0, + max: int = 255, + dtype: DtypeLike = np.float32, +) -> np.ndarray: + """ + Utility to equalize input image based on the histogram. + + Args: + img: input image to equalize. + num_bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `0`. + max: the max value to normalize input image, default to `255`. + dtype: data type of the output, default to `float32`. + + """ + orig_shape = img.shape + hist, bins = np.histogram(img.flatten(), num_bins, density=True) + cum = hist.cumsum() + # normalize the cumulative result + cum = rescale_array(arr=cum, minv=min, maxv=max) + + # apply linear interpolation + img = np.interp(img.flatten(), bins[:-1], cum) + + return img.reshape(orig_shape).astype(dtype) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index ce55d57f5a..c6f43df7d0 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -17,21 +17,21 @@ from monai.transforms import HistogramNormalize TEST_CASE_1 = [ - {"bins": 4, "max": 4}, + {"num_bins": 4, "min": 1, "max": 5}, np.array([0, 1, 2, 3, 4]), - np.array([0.8, 1.6, 2.4, 4.0, 4.0]), + np.array([1.0, 2.0, 3.0, 5.0, 5.0]), ] TEST_CASE_2 = [ - {"bins": 4, "max": 4, "dtype": np.uint8}, + {"num_bins": 4, "max": 4, "dtype": np.uint8}, np.array([0, 1, 2, 3, 4]), np.array([0, 1, 2, 4, 4]), ] TEST_CASE_3 = [ - {"bins": 256, "max": 255, "dtype": np.uint8}, + {"num_bins": 256, "max": 255, "dtype": np.uint8}, np.array([[[100, 200], [150, 250]]]), - np.array([[[63, 191], [127, 255]]]), + np.array([[[0, 169], [85, 255]]]), ] diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index be1851c783..a8c920801b 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -17,21 +17,21 @@ from monai.transforms import HistogramNormalized TEST_CASE_1 = [ - {"keys": "img", "bins": 4, "max": 4}, + {"keys": "img", "num_bins": 4, "min": 1, "max": 5}, {"img": np.array([0, 1, 2, 3, 4])}, - np.array([0.8, 1.6, 2.4, 4.0, 4.0]), + np.array([1.0, 2.0, 3.0, 5.0, 5.0]), ] TEST_CASE_2 = [ - {"keys": "img", "bins": 4, "max": 4, "dtype": np.uint8}, + {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, {"img": np.array([0, 1, 2, 3, 4])}, np.array([0, 1, 2, 4, 4]), ] TEST_CASE_3 = [ - {"keys": "img", "bins": 256, "max": 255, "dtype": np.uint8}, + {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, {"img": np.array([[[100, 200], [150, 250]]])}, - np.array([[[63, 191], [127, 255]]]), + np.array([[[0, 169], [85, 255]]]), ] From 640c17f88ea78caa13792b47900c42523fb94697 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Aug 2021 01:41:13 +0800 Subject: [PATCH 3/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 27 ++++++++++++++++++------ monai/transforms/intensity/dictionary.py | 12 +++++++++-- monai/transforms/utils.py | 16 ++++++++++++-- tests/test_histogram_normalize.py | 14 ++++++------ tests/test_histogram_normalized.py | 14 ++++++------ 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index cbae934660..258d896eb6 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1639,18 +1639,33 @@ class HistogramNormalize(Transform): https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `0`. max: the max value to normalize input image, default to `255`. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. + can also provide the mask along with img at runtime. dtype: data type of the output, default to `float32`. """ - def __init__(self, num_bins: int = 256, min: int = 0, max: int = 255, dtype: DtypeLike = np.float32) -> None: + def __init__( + self, + num_bins: int = 256, + min: int = 0, + max: int = 255, + mask: Optional[np.ndarray] = None, + dtype: DtypeLike = np.float32, + ) -> None: self.num_bins = num_bins self.min = min self.max = max + self.mask = mask self.dtype = dtype - def __call__(self, img: np.ndarray) -> np.ndarray: - """ - Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, - """ - return equalize_hist(img=img, num_bins=self.num_bins, min=self.min, max=self.max, dtype=self.dtype) + def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: + return equalize_hist( + img=img, + mask=mask if mask is not None else self.mask, + num_bins=self.num_bins, + min=self.min, + max=self.max, + dtype=self.dtype, + ) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 50977d209e..bc5534b402 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1484,6 +1484,10 @@ class HistogramNormalized(MapTransform): https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `255`. max: the max value to normalize input image, default to `255`. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. + can also provide the mask by `mask_key` at runtime. + mask_key: if mask is None, will try to get the mask with `mask_key`. dtype: data type of the output, default to `float32`. allow_missing_keys: do not raise exception if key is missing. @@ -1495,16 +1499,20 @@ def __init__( num_bins: int = 256, min: int = 0, max: int = 255, + mask: Optional[np.ndarray] = None, + mask_key: Optional[str] = None, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, dtype=dtype) + self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) + self.mask_key = mask_key if mask is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): - d[key] = self.transform(d[key]) + d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) + return d diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9a59a72b21..e996d7c9ea 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -40,6 +40,7 @@ ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") +exposure, has_skimage = optional_import("skimage.exposure") __all__ = [ "allow_missing_keys_mode", @@ -1120,6 +1121,7 @@ def tensor_to_numpy(data): def equalize_hist( img: np.ndarray, + mask: Optional[np.ndarray] = None, num_bins: int = 256, min: int = 0, max: int = 255, @@ -1127,9 +1129,13 @@ def equalize_hist( ) -> np.ndarray: """ Utility to equalize input image based on the histogram. + If `skimage` installed, will leverage `skimage.exposure.histogram`, otherwise, use + `np.histogram` instead. Args: img: input image to equalize. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. num_bins: number of the bins to use in histogram, default to `256`. for more details: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `0`. @@ -1138,12 +1144,18 @@ def equalize_hist( """ orig_shape = img.shape - hist, bins = np.histogram(img.flatten(), num_bins, density=True) + hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img + if has_skimage: + hist, bins = exposure.histogram(hist_img.flatten(), num_bins) + else: + hist, bins = np.histogram(hist_img.flatten(), num_bins) + bins = (bins[:-1] + bins[1:]) / 2 + cum = hist.cumsum() # normalize the cumulative result cum = rescale_array(arr=cum, minv=min, maxv=max) # apply linear interpolation - img = np.interp(img.flatten(), bins[:-1], cum) + img = np.interp(img.flatten(), bins, cum) return img.reshape(orig_shape).astype(dtype) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index c6f43df7d0..b69fb1d927 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -17,21 +17,21 @@ from monai.transforms import HistogramNormalize TEST_CASE_1 = [ - {"num_bins": 4, "min": 1, "max": 5}, - np.array([0, 1, 2, 3, 4]), - np.array([1.0, 2.0, 3.0, 5.0, 5.0]), + {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, + np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), ] TEST_CASE_2 = [ {"num_bins": 4, "max": 4, "dtype": np.uint8}, - np.array([0, 1, 2, 3, 4]), - np.array([0, 1, 2, 4, 4]), + np.array([0.0, 1.0, 2.0, 3.0, 4.0]), + np.array([0, 0, 1, 3, 4]), ] TEST_CASE_3 = [ {"num_bins": 256, "max": 255, "dtype": np.uint8}, - np.array([[[100, 200], [150, 250]]]), - np.array([[[0, 169], [85, 255]]]), + np.array([[[100.0, 200.0], [150.0, 250.0]]]), + np.array([[[0, 170], [70, 255]]]), ] diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index a8c920801b..68647e82fb 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -17,21 +17,21 @@ from monai.transforms import HistogramNormalized TEST_CASE_1 = [ - {"keys": "img", "num_bins": 4, "min": 1, "max": 5}, - {"img": np.array([0, 1, 2, 3, 4])}, - np.array([1.0, 2.0, 3.0, 5.0, 5.0]), + {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, + {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])}, + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), ] TEST_CASE_2 = [ {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, - {"img": np.array([0, 1, 2, 3, 4])}, - np.array([0, 1, 2, 4, 4]), + {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])}, + np.array([0, 0, 1, 3, 4]), ] TEST_CASE_3 = [ {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, - {"img": np.array([[[100, 200], [150, 250]]])}, - np.array([[[0, 169], [85, 255]]]), + {"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])}, + np.array([[[0, 170], [70, 255]]]), ]