diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 34d75faf63..18e2250084 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -478,17 +478,26 @@ class MaskIntensityd(MapTransform): of input image. if multiple channels, the channel number must match input data. mask_data will be converted to `bool` values by `mask_data > 0` before applying transform to input image. + if None, will extract the mask data from input data based on `mask_key`. + mask_key: the key to extract mask data from input dictionary, only works + when `mask_data` is None. """ - def __init__(self, keys: KeysCollection, mask_data: np.ndarray) -> None: + def __init__( + self, + keys: KeysCollection, + mask_data: Optional[np.ndarray] = None, + mask_key: Optional[str] = None, + ) -> None: super().__init__(keys) self.converter = MaskIntensity(mask_data) + self.mask_key = mask_key if mask_data is None else None def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - d[key] = self.converter(d[key]) + d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) return d diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py index 47f4c0b8a1..0d08952db2 100644 --- a/tests/test_mask_intensityd.py +++ b/tests/test_mask_intensityd.py @@ -34,9 +34,18 @@ np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), ] +TEST_CASE_4 = [ + {"keys": "img", "mask_key": "mask"}, + { + "img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + "mask": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]]), + }, + np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), +] + class TestMaskIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, argments, image, expected_data): result = MaskIntensityd(**argments)(image) np.testing.assert_allclose(result["img"], expected_data)