diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 2f639e0a95..9d6173ac80 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -173,18 +173,26 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) - self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - - if True in self.to_onehot or False in self.to_onehot: # backward compatibility - warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") - num_classes = ensure_tuple_rep(num_classes, len(self.keys)) - self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes)) - - self.threshold = ensure_tuple_rep(threshold, len(self.keys)) - if True in self.threshold or False in self.threshold: # backward compatibility - warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") - logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) - self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh)) + to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys)) + num_classes = ensure_tuple_rep(num_classes, len(self.keys)) + self.to_onehot = [] + for flag, val in zip(to_onehot_, num_classes): + if isinstance(flag, bool): + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + self.to_onehot.append(val if flag else None) + else: + self.to_onehot.append(flag) + + threshold_ = ensure_tuple_rep(threshold, len(self.keys)) + logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.threshold = [] + for flag, val in zip(threshold_, logit_thresh): + if isinstance(flag, bool): + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + self.threshold.append(val if flag else None) + else: + self.threshold.append(flag) + self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete() diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 8532f84d33..ae9d578f78 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -54,6 +54,22 @@ ] ) + # test compatible with previous versions + TEST_CASES.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": None, + "threshold": [True, None], + "logit_thresh": 0.6, + }, + {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES)