From 5bd4fff4d4c496d0386401932709dae1ee8896e0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Aug 2021 16:36:24 +0800 Subject: [PATCH 1/3] [DLMED] add round_values Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 12 +++++++++++- monai/transforms/post/dictionary.py | 9 +++++++-- tests/test_as_discrete.py | 9 ++++++++- tests/test_as_discreted.py | 9 ++++++++- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index c7558eddc3..cd8d57f816 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -112,7 +112,8 @@ class AsDiscrete(Transform): - execute `argmax` for input logits values. - threshold input value to 0.0 or 1.0. - - convert input value to One-Hot format + - convert input value to One-Hot format. + - round the value to the closest integer. Args: argmax: whether to execute argmax function on input data before transform. @@ -125,6 +126,7 @@ class AsDiscrete(Transform): Defaults to ``False``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``0.5``. + round_values: if true, round the data to the closest integer. """ @@ -135,12 +137,14 @@ def __init__( n_classes: Optional[int] = None, threshold_values: bool = False, logit_thresh: float = 0.5, + round_values: bool = False, ) -> None: self.argmax = argmax self.to_onehot = to_onehot self.n_classes = n_classes self.threshold_values = threshold_values self.logit_thresh = logit_thresh + self.round_values = round_values def __call__( self, @@ -150,6 +154,7 @@ def __call__( n_classes: Optional[int] = None, threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, + round_values: Optional[bool] = None, ) -> torch.Tensor: """ Args: @@ -165,6 +170,7 @@ def __call__( Defaults to ``self.threshold_values``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``self.logit_thresh``. + round_values: if true, round the data to the closest integer. """ if argmax or self.argmax: @@ -179,6 +185,10 @@ def __call__( if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) + round_values = self.round_values if round_values is None else round_values + if round_values: + img = torch.round(img) + return img.float() diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 0d9be131fc..6c5dc9b1af 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -134,6 +134,7 @@ def __init__( n_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, + round_values: Union[Sequence[bool], bool] = False, allow_missing_keys: bool = False, ) -> None: """ @@ -150,6 +151,8 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. logit_thresh: the threshold value for thresholding operation, default is 0.5. it also can be a sequence of float, each element corresponds to a key in ``keys``. + round_values: if true, round the data to the closest integer. + it also can be a sequence of bool, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ @@ -159,12 +162,13 @@ def __init__( self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.round_values = ensure_tuple_rep(round_values, len(self.keys)) self.converter = AsDiscrete() def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator( - d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, round_values in self.key_iterator( + d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.round_values ): d[key] = self.converter( d[key], @@ -173,6 +177,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc n_classes, threshold_values, logit_thresh, + round_values, ) return d diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index ea806be139..6f1dc04207 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -44,9 +44,16 @@ (3,), ] +TEST_CASE_5 = [ + {"round_values": True}, + torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]), + torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]), + (1, 2, 2), +] + class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) torch.testing.assert_allclose(result, out) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index d6a6f3c2a4..4eef3b5567 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -58,9 +58,16 @@ (2, 1, 2), ] +TEST_CASE_4 = [ + {"keys": "pred", "round_values": True}, + {"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])}, + {"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])}, + (1, 2, 2), +] + class TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) torch.testing.assert_allclose(result["pred"], output["pred"]) From 364b3adcdccd08e7b41b49552b610bc3bb40cd9c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Aug 2021 17:54:48 +0800 Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 21 +++++++++++++-------- monai/transforms/post/dictionary.py | 15 ++++++++------- tests/test_as_discrete.py | 2 +- tests/test_as_discreted.py | 2 +- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index cd8d57f816..78557a45b6 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -126,7 +126,8 @@ class AsDiscrete(Transform): Defaults to ``False``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``0.5``. - round_values: if true, round the data to the closest integer. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. """ @@ -137,14 +138,14 @@ def __init__( n_classes: Optional[int] = None, threshold_values: bool = False, logit_thresh: float = 0.5, - round_values: bool = False, + rounding: Optional[str] = None, ) -> None: self.argmax = argmax self.to_onehot = to_onehot self.n_classes = n_classes self.threshold_values = threshold_values self.logit_thresh = logit_thresh - self.round_values = round_values + self.rounding = rounding def __call__( self, @@ -154,7 +155,7 @@ def __call__( n_classes: Optional[int] = None, threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, - round_values: Optional[bool] = None, + rounding: Optional[str] = None, ) -> torch.Tensor: """ Args: @@ -170,7 +171,8 @@ def __call__( Defaults to ``self.threshold_values``. logit_thresh: the threshold value for thresholding operation.. Defaults to ``self.logit_thresh``. - round_values: if true, round the data to the closest integer. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. """ if argmax or self.argmax: @@ -185,9 +187,12 @@ def __call__( if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) - round_values = self.round_values if round_values is None else round_values - if round_values: - img = torch.round(img) + rounding = self.rounding if rounding is None else rounding + if rounding is not None: + if rounding == "torchrounding": + img = torch.round(img) + else: + raise ValueError(f"unsupported rounding option: {rounding}.") return img.float() diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6c5dc9b1af..d4e039339b 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -134,7 +134,7 @@ def __init__( n_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, - round_values: Union[Sequence[bool], bool] = False, + rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, ) -> None: """ @@ -151,8 +151,9 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. logit_thresh: the threshold value for thresholding operation, default is 0.5. it also can be a sequence of float, each element corresponds to a key in ``keys``. - round_values: if true, round the data to the closest integer. - it also can be a sequence of bool, each element corresponds to a key in ``keys``. + rounding: if not None, round the data according to the specified option, + available options: ["torchrounding"]. it also can be a sequence of str or None, + each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ @@ -162,13 +163,13 @@ def __init__( self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) - self.round_values = ensure_tuple_rep(round_values, len(self.keys)) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, round_values in self.key_iterator( - d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.round_values + for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.rounding ): d[key] = self.converter( d[key], @@ -177,7 +178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc n_classes, threshold_values, logit_thresh, - round_values, + rounding, ) return d diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 6f1dc04207..b87fafd8f3 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -45,7 +45,7 @@ ] TEST_CASE_5 = [ - {"round_values": True}, + {"rounding": "torchrounding"}, torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]), torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2), diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 4eef3b5567..ac594f0daa 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -59,7 +59,7 @@ ] TEST_CASE_4 = [ - {"keys": "pred", "round_values": True}, + {"keys": "pred", "rounding": "torchrounding"}, {"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])}, {"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])}, (1, 2, 2), From 6f938a8117a319ef286c47eb9b3e60800920bc3e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Aug 2021 20:25:48 +0800 Subject: [PATCH 3/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 78557a45b6..7b3e7b4fd2 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import ensure_tuple +from monai.utils import ensure_tuple, look_up_option __all__ = [ "Activations", @@ -189,10 +189,8 @@ def __call__( rounding = self.rounding if rounding is None else rounding if rounding is not None: - if rounding == "torchrounding": - img = torch.round(img) - else: - raise ValueError(f"unsupported rounding option: {rounding}.") + rounding = look_up_option(rounding, ["torchrounding"]) + img = torch.round(img) return img.float()