Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class TestTimeAugmentation:
.. code-block:: python

transform = RandAffined(keys, ...)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

tt_aug = TestTimeAugmentation(
transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
Expand Down
108 changes: 68 additions & 40 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,94 +122,122 @@ class AsDiscrete(Transform):
Args:
argmax: whether to execute argmax function on input data before transform.
Defaults to ``False``.
to_onehot: whether to convert input data into the one-hot format.
Defaults to ``False``.
num_classes: the number of classes to convert to One-Hot format.
to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
Defaults to ``None``.
threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold.
Defaults to ``None``.
threshold_values: whether threshold the float value to int number 0 or 1.
Defaults to ``False``.
logit_thresh: the threshold value for thresholding operation..
Defaults to ``0.5``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

Example:

>>> transform = AsDiscrete(argmax=True)
>>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
# [[[1.0, 1.0]]]

>>> transform = AsDiscrete(threshold=0.6)
>>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]])))
# [[[0.0, 0.0], [1.0, 1.0]]]

>>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5)
>>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
# [[[0.0, 0.0]], [[1.0, 1.0]]]

.. deprecated:: 0.6.0
``n_classes`` is deprecated, use ``num_classes`` instead.
``n_classes`` is deprecated, use ``to_onehot`` instead.

.. deprecated:: 0.7.0
``num_classes`` is deprecated, use ``to_onehot`` instead.
``logit_thresh`` is deprecated, use ``threshold`` instead.
``threshold_values`` is deprecated, use ``threshold`` instead.

"""

backend = [TransformBackends.TORCH]

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
def __init__(
self,
argmax: bool = False,
to_onehot: bool = False,
num_classes: Optional[int] = None,
threshold_values: bool = False,
logit_thresh: float = 0.5,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: float = 0.5,
threshold_values: bool = False,
) -> None:
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
self.argmax = argmax
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
self.to_onehot = to_onehot
self.num_classes = num_classes
self.threshold_values = threshold_values
self.logit_thresh = logit_thresh

if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
self.threshold = threshold

self.rounding = rounding

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
def __call__(
self,
img: NdarrayOrTensor,
argmax: Optional[bool] = None,
to_onehot: Optional[bool] = None,
num_classes: Optional[int] = None,
threshold_values: Optional[bool] = None,
logit_thresh: Optional[float] = None,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: Optional[float] = None,
threshold_values: Optional[bool] = None,
) -> NdarrayOrTensor:
"""
Args:
img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
will automatically add it.
argmax: whether to execute argmax function on input data before transform.
Defaults to ``self.argmax``.
to_onehot: whether to convert input data into the one-hot format.
to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
Defaults to ``self.to_onehot``.
num_classes: the number of classes to convert to One-Hot format.
Defaults to ``self.num_classes``.
threshold_values: whether threshold the float value to int number 0 or 1.
Defaults to ``self.threshold_values``.
logit_thresh: the threshold value for thresholding operation..
Defaults to ``self.logit_thresh``.
threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value.
Defaults to ``self.threshold``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

.. deprecated:: 0.6.0
``n_classes`` is deprecated, use ``num_classes`` instead.
``n_classes`` is deprecated, use ``to_onehot`` instead.

.. deprecated:: 0.7.0
``num_classes`` is deprecated, use ``to_onehot`` instead.
``logit_thresh`` is deprecated, use ``threshold`` instead.
``threshold_values`` is deprecated, use ``threshold`` instead.

"""
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")

img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
if argmax or self.argmax:
img_t = torch.argmax(img_t, dim=0, keepdim=True)

if to_onehot or self.to_onehot:
_nclasses = self.num_classes if num_classes is None else num_classes
if not isinstance(_nclasses, int):
raise AssertionError("One of self.num_classes or num_classes must be an integer")
img_t = one_hot(img_t, num_classes=_nclasses, dim=0)
to_onehot = to_onehot or self.to_onehot
if to_onehot is not None:
if not isinstance(to_onehot, int):
raise AssertionError("the number of classes for One-Hot must be an integer.")
img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

if threshold_values or self.threshold_values:
img_t = img_t >= (self.logit_thresh if logit_thresh is None else logit_thresh)
threshold = threshold or self.threshold
if threshold is not None:
img_t = img_t >= threshold

rounding = self.rounding if rounding is None else rounding
if rounding is not None:
Expand Down
46 changes: 23 additions & 23 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,59 +129,59 @@ class AsDiscreted(MapTransform):
backend = AsDiscrete.backend

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
def __init__(
self,
keys: KeysCollection,
argmax: Union[Sequence[bool], bool] = False,
to_onehot: Union[Sequence[bool], bool] = False,
num_classes: Optional[Union[Sequence[int], int]] = None,
threshold_values: Union[Sequence[bool], bool] = False,
logit_thresh: Union[Sequence[float], float] = 0.5,
to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None,
threshold: Union[Sequence[Optional[float]], Optional[float]] = None,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
n_classes: Optional[int] = None,
n_classes: Optional[Union[Sequence[int], int]] = None,
num_classes: Optional[Union[Sequence[int], int]] = None,
logit_thresh: Union[Sequence[float], float] = 0.5,
threshold_values: Union[Sequence[bool], bool] = False,
) -> None:
"""
Args:
keys: keys of the corresponding items to model output and label.
See also: :py:class:`monai.transforms.compose.MapTransform`
argmax: whether to execute argmax function on input data before transform.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
to_onehot: whether to convert input data into the one-hot format. Defaults to False.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
num_classes: the number of classes to convert to One-Hot format. it also can be a
sequence of int, each element corresponds to a key in ``keys``.
threshold_values: whether threshold the float value to int number 0 or 1, default is False.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
logit_thresh: the threshold value for thresholding operation, default is 0.5.
it also can be a sequence of float, each element corresponds to a key in ``keys``.
to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold value.
defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"]. it also can be a sequence of str or None,
each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.

.. deprecated:: 0.6.0
``n_classes`` is deprecated, use ``num_classes`` instead.
``n_classes`` is deprecated, use ``to_onehot`` instead.

.. deprecated:: 0.7.0
``num_classes`` is deprecated, use ``to_onehot`` instead.
``logit_thresh`` is deprecated, use ``threshold`` instead.
``threshold_values`` is deprecated, use ``threshold`` instead.

"""
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))
self.num_classes = ensure_tuple_rep(num_classes, len(self.keys))
self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys))
self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
self.threshold = ensure_tuple_rep(threshold, len(self.keys))
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
self.converter = AsDiscrete()

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator(
d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding
for key, argmax, to_onehot, threshold, rounding in self.key_iterator(
d, self.argmax, self.to_onehot, self.threshold, self.rounding
):
d[key] = self.converter(d[key], argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding)
d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding)
return d


Expand Down
11 changes: 2 additions & 9 deletions monai/transforms/utils_create_transform_ims.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,15 +640,8 @@ def create_transform_im(
create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data)
create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data)
create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data)
create_transform_im(
AsDiscrete, dict(num_classes=2, threshold_values=True, logit_thresh=10), data, is_post=True, colorbar=True
)
create_transform_im(
AsDiscreted,
dict(keys=CommonKeys.LABEL, num_classes=2, threshold_values=True, logit_thresh=10),
data,
is_post=True,
)
create_transform_im(AsDiscrete, dict(to_onehot=2, threshold=10), data, is_post=True, colorbar=True)
create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=2, threshold=10), data, is_post=True)
create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True)
create_transform_im(
LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True
Expand Down
8 changes: 4 additions & 4 deletions tests/test_as_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
for p in TEST_NDARRAYS:
TEST_CASES.append(
[
{"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5},
{"argmax": True, "to_onehot": None, "threshold": 0.5},
p([[[0.0, 1.0]], [[2.0, 3.0]]]),
p([[[1.0, 1.0]]]),
(1, 1, 2),
Expand All @@ -29,7 +29,7 @@

TEST_CASES.append(
[
{"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5},
{"argmax": True, "to_onehot": 2, "threshold": 0.5},
p([[[0.0, 1.0]], [[2.0, 3.0]]]),
p([[[0.0, 0.0]], [[1.0, 1.0]]]),
(2, 1, 2),
Expand All @@ -38,14 +38,14 @@

TEST_CASES.append(
[
{"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6},
{"argmax": False, "to_onehot": None, "threshold": 0.6},
p([[[0.0, 1.0], [2.0, 3.0]]]),
p([[[0.0, 1.0], [1.0, 1.0]]]),
(1, 2, 2),
]
)

TEST_CASES.append([{"argmax": False, "to_onehot": True, "num_classes": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)])
TEST_CASES.append([{"argmax": False, "to_onehot": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)])

TEST_CASES.append(
[{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)]
Expand Down
27 changes: 3 additions & 24 deletions tests/test_as_discreted.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,7 @@
for p in TEST_NDARRAYS:
TEST_CASES.append(
[
{
"keys": ["pred", "label"],
"argmax": [True, False],
"to_onehot": True,
"num_classes": 2,
"threshold_values": False,
"logit_thresh": 0.5,
},
{"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2, "threshold": 0.5},
{"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0, 1]]])},
{"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": p([[[1.0, 0.0]], [[0.0, 1.0]]])},
(2, 1, 2),
Expand All @@ -36,14 +29,7 @@

TEST_CASES.append(
[
{
"keys": ["pred", "label"],
"argmax": False,
"to_onehot": False,
"num_classes": None,
"threshold_values": [True, False],
"logit_thresh": 0.6,
},
{"keys": ["pred", "label"], "argmax": False, "to_onehot": None, "threshold": [0.6, None]},
{"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])},
{"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])},
(1, 2, 2),
Expand All @@ -52,14 +38,7 @@

TEST_CASES.append(
[
{
"keys": ["pred"],
"argmax": True,
"to_onehot": True,
"num_classes": 2,
"threshold_values": False,
"logit_thresh": 0.5,
},
{"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5},
{"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])},
{"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])},
(2, 1, 2),
Expand Down
Loading