From ddd4647a3f397479c1450cb0fef8b47e3238473a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Nov 2021 14:46:20 +0800 Subject: [PATCH 1/2] [DLMED] fix None value Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index d83443507a..5dca7df3f6 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -237,13 +237,13 @@ def __call__( if argmax or self.argmax: img_t = torch.argmax(img_t, dim=0, keepdim=True) - to_onehot = to_onehot or self.to_onehot + to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: if not isinstance(to_onehot, int): raise AssertionError("the number of classes for One-Hot must be an integer.") img_t = one_hot(img_t, num_classes=to_onehot, dim=0) - threshold = threshold or self.threshold + threshold = self.threshold if threshold is None else threshold if threshold is not None: img_t = img_t >= threshold From edfc4dbed16871705c7c9814767b6f2df898a541 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Nov 2021 15:05:35 +0800 Subject: [PATCH 2/2] [DLMED] add unit tests Signed-off-by: Nic Ma --- tests/test_as_discrete.py | 10 ++++++++++ tests/test_as_discreted.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index e8db2052f0..8cbefbac39 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -45,6 +45,16 @@ ] ) + # test threshold = 0.0 + TEST_CASES.append( + [ + {"argmax": False, "to_onehot": None, "threshold": 0.0}, + p([[[0.0, -1.0], [-2.0, 3.0]]]), + p([[[1.0, 0.0], [0.0, 1.0]]]), + (1, 2, 2), + ] + ) + TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)]) TEST_CASES.append( diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index ae9d578f78..7c7cfdf6e5 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -70,6 +70,16 @@ ] ) + # test threshold = 0.0 + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.0, None]}, + {"pred": p([[[0.0, -1.0], [-2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])}, + {"pred": p([[[1.0, 0.0], [0.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES)