diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 1c3ee288a1..6a25c62c49 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -662,6 +662,13 @@ Utility :members: :special-members: __call__ +`IntensityStats` +"""""""""""""""" + .. autoclass:: IntensityStats + :members: + :special-members: __call__ + + Dictionary Transforms --------------------- @@ -911,6 +918,7 @@ Intensity (Dict) :members: :special-members: __call__ + IO (Dict) ^^^^^^^^^ @@ -1265,6 +1273,13 @@ Utility (Dict) :members: :special-members: __call__ +`IntensityStatsd` +""""""""""""""""" +.. autoclass:: IntensityStatsd + :members: + :special-members: __call__ + + Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 20e29d5aa9..cf9198dbf5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -328,6 +328,7 @@ EnsureType, FgBgToIndices, Identity, + IntensityStats, LabelToMask, Lambda, MapLabelValue, @@ -390,6 +391,9 @@ Identityd, IdentityD, IdentityDict, + IntensityStatsd, + IntensityStatsD, + IntensityStatsDict, LabelToMaskd, LabelToMaskD, LabelToMaskDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 4533f333ce..14b3e54459 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -187,11 +187,13 @@ class ShiftIntensity(Transform): def __init__(self, offset: float) -> None: self.offset = offset - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: np.ndarray, offset: Optional[float] = None) -> np.ndarray: """ Apply the transform to `img`. """ - return np.asarray((img + self.offset), dtype=img.dtype) + + offset = self.offset if offset is None else offset + return np.asarray((img + offset), dtype=img.dtype) class RandShiftIntensity(RandomizableTransform): @@ -214,20 +216,26 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) self._offset = self.offsets[0] + self._shfiter = ShiftIntensity(self._offset) def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: np.ndarray, factor: Optional[float] = None) -> np.ndarray: """ Apply the transform to `img`. + + Args: + img: input image to shift intensity. + factor: a factor to multiply the random offset, then shift. + can be some image specific value at runtime, like: max(img), etc. + """ self.randomize() if not self._do_transform: return img - shifter = ShiftIntensity(self._offset) - return shifter(img) + return self._shfiter(img, self._offset if factor is None else self._offset * factor) class StdShiftIntensity(Transform): @@ -1457,7 +1465,7 @@ def __init__( self.intensity_range = intensity_range self.channel_wise = channel_wise self.as_tensor_output = as_tensor_output - self.sampled_k_intensity: List[float] = [] + self.sampled_k_intensity: List = [] self.sampled_locs: List[Tuple] = [] if intensity_range is not None: @@ -1523,7 +1531,7 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float if isinstance(intensity_range[0], Sequence): self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range] else: - self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore + self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index c24f7b67ca..e43aa1e2b3 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -42,7 +42,7 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform -from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple +from monai.utils import dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple __all__ = [ "RandGaussianNoised", @@ -232,21 +232,53 @@ class ShiftIntensityd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`. """ - def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + offset: float, + factor_key: Optional[str] = None, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` offset: offset value to shift the intensity of image. + factor_key: if not None, use it as the key to extract a value from the corresponding + meta data dictionary of `key` at runtime, and multiply the `offset` to shift intensity. + Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values + and store in the meta data. + it also can be a sequence of strings, map to `keys`. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + used to extract the factor value is `factor_key` is not None. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + used to extract the factor value is `factor_key` is not None. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) + self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.shifter = ShiftIntensity(offset) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.shifter(d[key]) + for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( + d, self.factor_key, self.meta_keys, self.meta_key_postfix + ): + meta_key = meta_key or f"{key}_{meta_key_postfix}" + factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None + offset = None if factor is None else self.shifter.offset * factor + d[key] = self.shifter(d[key], offset=offset) return d @@ -259,6 +291,9 @@ def __init__( self, keys: KeysCollection, offsets: Union[Tuple[float, float], float], + factor_key: Optional[str] = None, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = "meta_dict", prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: @@ -268,6 +303,20 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` offsets: offset range to randomly shift. if single number, offset value is picked from (-offsets, offsets). + factor_key: if not None, use it as the key to extract a value from the corresponding + meta data dictionary of `key` at runtime, and multiply the random `offset` to shift intensity. + Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values + and store in the meta data. + it also can be a sequence of strings, map to `keys`. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + used to extract the factor value is `factor_key` is not None. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + used to extract the factor value is `factor_key` is not None. prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) allow_missing_keys: don't raise exception if key is missing. @@ -282,19 +331,29 @@ def __init__( raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) self._offset = self.offsets[0] + self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.shifter = ShiftIntensity(self._offset) def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, np.ndarray]: d = dict(data) self.randomize() if not self._do_transform: return d - shifter = ShiftIntensity(self._offset) - for key in self.key_iterator(d): - d[key] = shifter(d[key]) + for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( + d, self.factor_key, self.meta_keys, self.meta_key_postfix + ): + meta_key = meta_key or f"{key}_{meta_key_postfix}" + factor: Optional[float] = d[meta_key].get(factor_key) if meta_key in d else None + offset = self._offset if factor is None else self._offset * factor + d[key] = self.shifter(d[key], offset=offset) return d diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 4e0141652f..3de2408abd 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -17,7 +17,7 @@ import sys import time import warnings -from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -32,7 +32,7 @@ map_binary_to_indices, map_classes_to_indices, ) -from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import +from monai.utils import ensure_tuple, issequenceiterable, look_up_option, min_version, optional_import PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -66,6 +66,7 @@ "AddExtremePointsChannel", "TorchVision", "MapLabelValue", + "IntensityStats", ] @@ -938,3 +939,80 @@ def __call__(self, img: np.ndarray): np.place(out_flat, img_flat == o, t) return out_flat.reshape(img.shape) + + +class IntensityStats(Transform): + """ + Compute statistics for the intensity values of input image and store into the meta data dictionary. + For example: if `ops=[lambda x: np.mean(x), "max"]` and `key_prefix="orig"`, may generate below stats: + `{"orig_custom_0": 1.5, "orig_max": 3.0}`. + + Args: + ops: expected operations to compute statistics for the intensity. + if a string, will map to the predefined operations, supported: ["mean", "median", "max", "min", "std"] + mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`. + if a callable function, will execute the function on input image. + key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the + meta data dictionary. if some `ops` are callable functions, will use "{key_prefix}_custom_{index}" + as the key, where index counts from 0. + channel_wise: whether to compute statistics for every channel of input image separately. + if True, return a list of values for every operation, default to False. + + """ + + def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None: + self.ops = ensure_tuple(ops) + self.key_prefix = key_prefix + self.channel_wise = channel_wise + + def __call__( + self, + img: np.ndarray, + meta_data: Optional[Dict] = None, + mask: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, Dict]: + """ + Compute statistics for the intensity of input image. + + Args: + img: input image to compute intensity stats. + meta_data: meta data dictionary to store the statistics data, if None, will create an empty dictionary. + mask: if not None, mask the image to extract only the interested area to compute statistics. + mask must have the same shape as input `img`. + + """ + if meta_data is None: + meta_data = {} + + img_: np.ndarray = img + if mask is not None: + if mask.shape != img.shape or mask.dtype != bool: + raise TypeError("mask must be bool array with the same shape as input `img`.") + img_ = img[mask] + + supported_ops = { + "mean": lambda x: np.nanmean(x), + "median": lambda x: np.nanmedian(x), + "max": lambda x: np.nanmax(x), + "min": lambda x: np.nanmin(x), + "std": lambda x: np.nanstd(x), + } + + def _compute(op: Callable, data: np.ndarray): + if self.channel_wise: + return [op(c) for c in data] + else: + return op(data) + + custom_index = 0 + for o in self.ops: + if isinstance(o, str): + o = look_up_option(o, supported_ops.keys()) + meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_) + elif callable(o): + meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_) + custom_index += 1 + else: + raise ValueError("ops must be key string for predefined operations or callable function.") + + return img, meta_data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 75be9685c4..fb9963601d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -39,6 +39,7 @@ EnsureType, FgBgToIndices, Identity, + IntensityStats, LabelToMask, Lambda, MapLabelValue, @@ -101,6 +102,9 @@ "IdentityD", "IdentityDict", "Identityd", + "IntensityStatsd", + "IntensityStatsD", + "IntensityStatsDict", "LabelToMaskD", "LabelToMaskDict", "LabelToMaskd", @@ -1282,6 +1286,74 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class IntensityStatsd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.IntensityStats`. + Compute statistics for the intensity values of input image and store into the meta data dictionary. + For example: if `ops=[lambda x: np.mean(x), "max"]` and `key_prefix="orig"`, may generate below stats: + `{"orig_custom_0": 1.5, "orig_max": 3.0}`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + ops: expected operations to compute statistics for the intensity. + if a string, will map to the predefined operations, supported: ["mean", "median", "max", "min", "std"] + mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`. + if a callable function, will execute the function on input image. + key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the + meta data dictionary. if some `ops` are callable functions, will use "{key_prefix}_custom_{index}" + as the key, where index counts from 0. + mask_keys: if not None, specify the mask array for the image to extract only the interested area to compute + statistics, mask must have the same shape as the image. + it should be a sequence of strings or None, map to the `keys`. + channel_wise: whether to compute statistics for every channel of input image separately. + if True, return a list of values for every operation, default to False. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + used to store the computed statistics to the meta dict. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + used to store the computed statistics to the meta dict. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + ops: Sequence[Union[str, Callable]], + key_prefix: str, + mask_keys: Optional[KeysCollection] = None, + channel_wise: bool = False, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.stats = IntensityStats(ops=ops, key_prefix=key_prefix, channel_wise=channel_wise) + self.mask_keys = ensure_tuple_rep(None, len(self.keys)) if mask_keys is None else ensure_tuple(mask_keys) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + + def __call__(self, data) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key, mask_key, meta_key, meta_key_postfix in self.key_iterator( + d, self.mask_keys, self.meta_keys, self.meta_key_postfix + ): + meta_key = meta_key or f"{key}_{meta_key_postfix}" + d[key], d[meta_key] = self.stats( + img=d[key], + meta_data=d.get(meta_key), + mask=d.get(mask_key) if mask_key is not None else None, + ) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -1316,3 +1388,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued +IntensityStatsD = IntensityStatsDict = IntensityStatsd diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py new file mode 100644 index 0000000000..059271e442 --- /dev/null +++ b/tests/test_intensity_stats.py @@ -0,0 +1,72 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import IntensityStats + +TEST_CASE_1 = [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + np.array([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, +] + +TEST_CASE_2 = [ + {"ops": "std", "key_prefix": "orig"}, + np.array([[[0.0, 1.0], [2.0, 3.0]]]), + None, + {"orig_std": 1.118034}, +] + +TEST_CASE_3 = [ + {"ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, + np.array([[[0.0, 1.0], [2.0, 3.0]]]), + None, + {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, +] + +TEST_CASE_4 = [ + {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, + np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), + {"affine": None}, + {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, +] + +TEST_CASE_5 = [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + np.array([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, +] + + +class TestIntensityStats(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_value(self, input_param, img, meta_dict, expected): + _, meta_dict = IntensityStats(**input_param)(img, meta_dict) + for k, v in expected.items(): + self.assertTrue(k in meta_dict) + np.testing.assert_allclose(v, meta_dict[k], atol=1e-3) + + def test_mask(self): + img = np.array([[[0.0, 1.0], [2.0, 3.0]]]) + mask = np.array([[[1, 0], [1, 0]]], dtype=bool) + img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) + np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) + np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py new file mode 100644 index 0000000000..8c8bc8795a --- /dev/null +++ b/tests/test_intensity_statsd.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import torch.multiprocessing as mp +from parameterized import parameterized + +from monai.data import DataLoader, Dataset +from monai.transforms import IntensityStatsd + +TEST_CASE_1 = [ + {"keys": "img", "ops": ["max", "mean"], "key_prefix": "orig", "meta_keys": "test_meta"}, + {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]]), "test_meta": {"affine": None}}, + "test_meta", + {"orig_max": 3.0, "orig_mean": 1.5}, +] + +TEST_CASE_2 = [ + {"keys": "img", "ops": "std", "key_prefix": "orig"}, + {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, + "img_meta_dict", + {"orig_std": 1.118034}, +] + +TEST_CASE_3 = [ + {"keys": "img", "ops": [lambda x: np.mean(x), "max", lambda x: np.min(x)], "key_prefix": "orig"}, + {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, + "img_meta_dict", + {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, +] + +TEST_CASE_4 = [ + {"keys": "img", "ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True, "meta_key_postfix": "meta"}, + {"img": np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), "img_meta": {"affine": None}}, + "img_meta", + {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, +] + + +class TestIntensityStatsd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_value(self, input_param, data, meta_key, expected): + meta = IntensityStatsd(**input_param)(data)[meta_key] + for k, v in expected.items(): + self.assertTrue(k in meta) + np.testing.assert_allclose(v, meta[k], atol=1e-3) + + def test_dataloader(self): + dataset = Dataset( + data=[{"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]])}], + transform=IntensityStatsd(keys="img", ops=["max", "mean"], key_prefix="orig"), + ) + # set num workers = 0 for mac / win + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2) + orig_method = mp.get_start_method() + mp.set_start_method("spawn", force=True) + + for d in dataloader: + meta = d["img_meta_dict"] + np.testing.assert_allclose(meta["orig_max"], [3.0, 3.0], atol=1e-3) + np.testing.assert_allclose(meta["orig_mean"], [1.5, 1.5], atol=1e-3) + # restore the mp method + mp.set_start_method(orig_method, force=True) + + def test_mask(self): + data = {"img": np.array([[[0.0, 1.0], [2.0, 3.0]]]), "img_mask": np.array([[[1, 0], [1, 0]]], dtype=bool)} + stats = IntensityStatsd(keys="img", ops=["max", "mean"], mask_keys="img_mask", key_prefix="orig") + meta = stats(data)["img_meta_dict"] + np.testing.assert_allclose(meta["orig_max"], 2.0, atol=1e-3) + np.testing.assert_allclose(meta["orig_mean"], 1.0, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py index ba54510bc3..4c4dd87dfe 100644 --- a/tests/test_rand_shift_intensity.py +++ b/tests/test_rand_shift_intensity.py @@ -21,7 +21,7 @@ class TestRandShiftIntensity(NumpyImageTestCase2D): def test_value(self): shifter = RandShiftIntensity(offsets=1.0, prob=1.0) shifter.set_random_state(seed=0) - result = shifter(self.imt) + result = shifter(self.imt, factor=1.0) np.random.seed(0) expected = self.imt + np.random.uniform(low=-1.0, high=1.0) np.testing.assert_allclose(result, expected) diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py index 0c6f25e7b5..71cfd8fc50 100644 --- a/tests/test_rand_shift_intensityd.py +++ b/tests/test_rand_shift_intensityd.py @@ -13,7 +13,7 @@ import numpy as np -from monai.transforms import RandShiftIntensityd +from monai.transforms import IntensityStatsd, RandShiftIntensityd from tests.utils import NumpyImageTestCase2D @@ -27,6 +27,17 @@ def test_value(self): expected = self.imt + np.random.uniform(low=-1.0, high=1.0) np.testing.assert_allclose(result[key], expected) + def test_factor(self): + key = "img" + stats = IntensityStatsd(keys=key, ops="max", key_prefix="orig") + shifter = RandShiftIntensityd(keys=[key], offsets=1.0, factor_key=["orig_max"], prob=1.0) + data = {key: self.imt, key + "_meta_dict": {"affine": None}} + shifter.set_random_state(seed=0) + result = shifter(stats(data)) + np.random.seed(0) + expected = self.imt + np.random.uniform(low=-1.0, high=1.0) * np.nanmax(self.imt) + np.testing.assert_allclose(result[key], expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py index 752cf4b8d2..71cfffc9c5 100644 --- a/tests/test_shift_intensityd.py +++ b/tests/test_shift_intensityd.py @@ -13,7 +13,7 @@ import numpy as np -from monai.transforms import ShiftIntensityd +from monai.transforms import IntensityStatsd, ShiftIntensityd from tests.utils import NumpyImageTestCase2D @@ -25,6 +25,16 @@ def test_value(self): expected = self.imt + 1.0 np.testing.assert_allclose(result[key], expected) + def test_factor(self): + key = "img" + stats = IntensityStatsd(keys=key, ops="max", key_prefix="orig") + shifter = ShiftIntensityd(keys=[key], offset=1.0, factor_key=["orig_max"]) + data = {key: self.imt, key + "_meta_dict": {"affine": None}} + + result = shifter(stats(data)) + expected = self.imt + 1.0 * np.nanmax(self.imt) + np.testing.assert_allclose(result[key], expected) + if __name__ == "__main__": unittest.main()