diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index c7558eddc3..7b3e7b4fd2 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import ensure_tuple +from monai.utils import ensure_tuple, look_up_option __all__ = [ "Activations", @@ -112,7 +112,8 @@ class AsDiscrete(Transform): - execute `argmax` for input logits values. - threshold input value to 0.0 or 1.0. - - convert input value to One-Hot format + - convert input value to One-Hot format. + - round the value to the closest integer. Args: argmax: whether to execute argmax function on input data before transform. @@ -125,6 +126,8 @@ class AsDiscrete(Transform): Defaults to ``False``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``0.5``. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. """ @@ -135,12 +138,14 @@ def __init__( n_classes: Optional[int] = None, threshold_values: bool = False, logit_thresh: float = 0.5, + rounding: Optional[str] = None, ) -> None: self.argmax = argmax self.to_onehot = to_onehot self.n_classes = n_classes self.threshold_values = threshold_values self.logit_thresh = logit_thresh + self.rounding = rounding def __call__( self, @@ -150,6 +155,7 @@ def __call__( n_classes: Optional[int] = None, threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, + rounding: Optional[str] = None, ) -> torch.Tensor: """ Args: @@ -165,6 +171,8 @@ def __call__( Defaults to ``self.threshold_values``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``self.logit_thresh``. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. """ if argmax or self.argmax: @@ -179,6 +187,11 @@ def __call__( if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) + rounding = self.rounding if rounding is None else rounding + if rounding is not None: + rounding = look_up_option(rounding, ["torchrounding"]) + img = torch.round(img) + return img.float() diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 0d9be131fc..d4e039339b 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -134,6 +134,7 @@ def __init__( n_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, + rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, ) -> None: """ @@ -150,6 +151,9 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. logit_thresh: the threshold value for thresholding operation, default is 0.5. it also can be a sequence of float, each element corresponds to a key in ``keys``. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. it also can be a sequence of str or None, + each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ @@ -159,12 +163,13 @@ def __init__( self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator( - d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.rounding ): d[key] = self.converter( d[key], @@ -173,6 +178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc n_classes, threshold_values, logit_thresh, + rounding, ) return d diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index ea806be139..b87fafd8f3 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -44,9 +44,16 @@ (3,), ] +TEST_CASE_5 = [ + {"rounding": "torchrounding"}, + torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]), + torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]), + (1, 2, 2), +] + class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) torch.testing.assert_allclose(result, out) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index d6a6f3c2a4..ac594f0daa 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -58,9 +58,16 @@ (2, 1, 2), ] +TEST_CASE_4 = [ + {"keys": "pred", "rounding": "torchrounding"}, + {"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])}, + {"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])}, + (1, 2, 2), +] + class TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) torch.testing.assert_allclose(result["pred"], output["pred"])