diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ca4f1ef388..c14f2b242f 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -14,7 +14,7 @@ """ from collections.abc import Iterable -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -24,7 +24,7 @@ from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import Fourier, RandomizableTransform, Transform -from monai.transforms.utils import rescale_array +from monai.transforms.utils import is_positive, rescale_array from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -789,19 +789,23 @@ class MaskIntensity(Transform): """ Mask the intensity values of input image with the specified mask data. Mask data must have the same spatial size as the input image, and all - the intensity values of input image corresponding to `0` in the mask - data will be set to `0`, others will keep the original value. + the intensity values of input image corresponding to the selected values + in the mask data will keep the original value, others will be set to `0`. Args: mask_data: if `mask_data` is single channel, apply to every channel of input image. if multiple channels, the number of channels must - match the input data. `mask_data` will be converted to `bool` values - by `mask_data > 0` before applying transform to input image. + match the input data. the intensity values of input image corresponding + to the selected values in the mask data will keep the original value, + others will be set to `0`. + select_fn: function to select valid values of the `mask_data`, default is + to select `values > 0`. """ - def __init__(self, mask_data: Optional[np.ndarray]) -> None: + def __init__(self, mask_data: Optional[np.ndarray], select_fn: Callable = is_positive) -> None: self.mask_data = mask_data + self.select_fn = select_fn def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray: """ @@ -816,21 +820,18 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel. """ - if self.mask_data is None and mask_data is None: - raise ValueError("Unknown mask_data.") - mask_data_ = np.array([[1]]) - if self.mask_data is not None and mask_data is None: - mask_data_ = self.mask_data > 0 - if mask_data is not None: - mask_data_ = mask_data > 0 - mask_data_ = np.asarray(mask_data_) - if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: + mask_data = self.mask_data if mask_data is None else mask_data + if mask_data is None: + raise ValueError("must provide the mask_data when initializing the transform or at runtime.") + + mask_data = np.asarray(self.select_fn(mask_data)) + if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]: raise ValueError( "When mask_data is not single channel, mask_data channels must match img, " - f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." + f"got img channels={img.shape[0]} mask_data channels={mask_data.shape[0]}." ) - return np.asarray(img * mask_data_) + return np.asarray(img * mask_data) class SavitzkyGolaySmooth(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index e43aa1e2b3..19323e2020 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -16,7 +16,7 @@ """ from collections.abc import Iterable -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -42,6 +42,7 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.transforms.utils import is_positive from monai.utils import dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple __all__ = [ @@ -808,11 +809,14 @@ class MaskIntensityd(MapTransform): See also: :py:class:`monai.transforms.compose.MapTransform` mask_data: if mask data is single channel, apply to every channel of input image. if multiple channels, the channel number must - match input data. mask_data will be converted to `bool` values - by `mask_data > 0` before applying transform to input image. - if None, will extract the mask data from input data based on `mask_key`. + match input data. the intensity values of input image corresponding + to the selected values in the mask data will keep the original value, + others will be set to `0`. if None, will extract the mask data from + input data based on `mask_key`. mask_key: the key to extract mask data from input dictionary, only works when `mask_data` is None. + select_fn: function to select valid values of the `mask_data`, default is + to select `values > 0`. allow_missing_keys: don't raise exception if key is missing. """ @@ -822,10 +826,11 @@ def __init__( keys: KeysCollection, mask_data: Optional[np.ndarray] = None, mask_key: Optional[str] = None, + select_fn: Callable = is_positive, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.converter = MaskIntensity(mask_data) + self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn) self.mask_key = mask_key if mask_data is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 3131abe8bf..da9eda6416 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -34,9 +34,18 @@ np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), ] +TEST_CASE_4 = [ + { + "mask_data": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), + "select_fn": lambda x: np.where((x > 3) & (x < 7), True, False), + }, + np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]), +] + class TestMaskIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index 0d08952db2..c21e26eba6 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -43,9 +43,19 @@ np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), ] +TEST_CASE_5 = [ + { + "keys": "img", + "mask_data": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), + "select_fn": lambda x: np.where((x > 3) & (x < 7), True, False), + }, + {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, + np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]), +] + class TestMaskIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value(self, argments, image, expected_data): result = MaskIntensityd(**argments)(image) np.testing.assert_allclose(result["img"], expected_data)