From 0369a8321146a94dd5b673c4495296d7a5efd92f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 23 Nov 2021 23:16:40 +0800 Subject: [PATCH 1/2] [DLMED] fix threshold issue Signed-off-by: Nic Ma --- monai/transforms/post/dictionary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 2f639e0a95..89385dcbc7 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -175,13 +175,13 @@ def __init__( self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - if True in self.to_onehot or False in self.to_onehot: # backward compatibility + if any([isinstance(i, bool) for i in self.to_onehot]): # backward compatibility warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") num_classes = ensure_tuple_rep(num_classes, len(self.keys)) self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes)) self.threshold = ensure_tuple_rep(threshold, len(self.keys)) - if True in self.threshold or False in self.threshold: # backward compatibility + if any([isinstance(i, bool) for i in self.threshold]): # backward compatibility warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh)) From 07cc7490c5c6ea5c64eba7d554eb453b2207c913 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 23 Nov 2021 23:48:42 +0800 Subject: [PATCH 2/2] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/post/dictionary.py | 32 ++++++++++++++++++----------- tests/test_as_discreted.py | 16 +++++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 89385dcbc7..9d6173ac80 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -173,18 +173,26 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) - self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - - if any([isinstance(i, bool) for i in self.to_onehot]): # backward compatibility - warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") - num_classes = ensure_tuple_rep(num_classes, len(self.keys)) - self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes)) - - self.threshold = ensure_tuple_rep(threshold, len(self.keys)) - if any([isinstance(i, bool) for i in self.threshold]): # backward compatibility - warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") - logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) - self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh)) + to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys)) + num_classes = ensure_tuple_rep(num_classes, len(self.keys)) + self.to_onehot = [] + for flag, val in zip(to_onehot_, num_classes): + if isinstance(flag, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + self.to_onehot.append(val if flag else None) + else: + self.to_onehot.append(flag) + + threshold_ = ensure_tuple_rep(threshold, len(self.keys)) + logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.threshold = [] + for flag, val in zip(threshold_, logit_thresh): + if isinstance(flag, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + self.threshold.append(val if flag else None) + else: + self.threshold.append(flag) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 8532f84d33..ae9d578f78 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -54,6 +54,22 @@ ] ) + # test compatible with previous versions + TEST_CASES.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": None, + "threshold": [True, None], + "logit_thresh": 0.6, + }, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES)