Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.utils import ensure_tuple
from monai.utils import ensure_tuple, look_up_option

__all__ = [
"Activations",
Expand Down Expand Up @@ -112,7 +112,8 @@ class AsDiscrete(Transform):

- execute `argmax` for input logits values.
- threshold input value to 0.0 or 1.0.
- convert input value to One-Hot format
- convert input value to One-Hot format.
- round the value to the closest integer.

Args:
argmax: whether to execute argmax function on input data before transform.
Expand All @@ -125,6 +126,8 @@ class AsDiscrete(Transform):
Defaults to ``False``.
logit_thresh: the threshold value for thresholding operation..
Defaults to ``0.5``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

"""

Expand All @@ -135,12 +138,14 @@ def __init__(
n_classes: Optional[int] = None,
threshold_values: bool = False,
logit_thresh: float = 0.5,
rounding: Optional[str] = None,
) -> None:
self.argmax = argmax
self.to_onehot = to_onehot
self.n_classes = n_classes
self.threshold_values = threshold_values
self.logit_thresh = logit_thresh
self.rounding = rounding

def __call__(
self,
Expand All @@ -150,6 +155,7 @@ def __call__(
n_classes: Optional[int] = None,
threshold_values: Optional[bool] = None,
logit_thresh: Optional[float] = None,
rounding: Optional[str] = None,
) -> torch.Tensor:
"""
Args:
Expand All @@ -165,6 +171,8 @@ def __call__(
Defaults to ``self.threshold_values``.
logit_thresh: the threshold value for thresholding operation..
Defaults to ``self.logit_thresh``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

"""
if argmax or self.argmax:
Expand All @@ -179,6 +187,11 @@ def __call__(
if threshold_values or self.threshold_values:
img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)

rounding = self.rounding if rounding is None else rounding
if rounding is not None:
rounding = look_up_option(rounding, ["torchrounding"])
img = torch.round(img)

return img.float()


Expand Down
10 changes: 8 additions & 2 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
n_classes: Optional[Union[Sequence[int], int]] = None,
threshold_values: Union[Sequence[bool], bool] = False,
logit_thresh: Union[Sequence[float], float] = 0.5,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Expand All @@ -150,6 +151,9 @@ def __init__(
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
logit_thresh: the threshold value for thresholding operation, default is 0.5.
it also can be a sequence of float, each element corresponds to a key in ``keys``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"]. it also can be a sequence of str or None,
each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand All @@ -159,12 +163,13 @@ def __init__(
self.n_classes = ensure_tuple_rep(n_classes, len(self.keys))
self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys))
self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
self.converter = AsDiscrete()

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator(
d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh
for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, rounding in self.key_iterator(
d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.rounding
):
d[key] = self.converter(
d[key],
Expand All @@ -173,6 +178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
n_classes,
threshold_values,
logit_thresh,
rounding,
)
return d

Expand Down
9 changes: 8 additions & 1 deletion tests/test_as_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,16 @@
(3,),
]

TEST_CASE_5 = [
{"rounding": "torchrounding"},
torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]),
torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]),
(1, 2, 2),
]


class TestAsDiscrete(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value_shape(self, input_param, img, out, expected_shape):
result = AsDiscrete(**input_param)(img)
torch.testing.assert_allclose(result, out)
Expand Down
9 changes: 8 additions & 1 deletion tests/test_as_discreted.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,16 @@
(2, 1, 2),
]

TEST_CASE_4 = [
{"keys": "pred", "rounding": "torchrounding"},
{"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])},
{"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])},
(1, 2, 2),
]


class TestAsDiscreted(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value_shape(self, input_param, test_input, output, expected_shape):
result = AsDiscreted(**input_param)(test_input)
torch.testing.assert_allclose(result["pred"], output["pred"])
Expand Down