Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,47 +155,53 @@ class AsDiscrete(Transform):

backend = [TransformBackends.TORCH]

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __init__(
self,
argmax: bool = False,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: float = 0.5,
threshold_values: bool = False,
n_classes: Optional[int] = None, # deprecated
num_classes: Optional[int] = None, # deprecated
logit_thresh: float = 0.5, # deprecated
threshold_values: Optional[bool] = False, # deprecated
) -> None:
self.argmax = argmax
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
if isinstance(to_onehot, bool): # for backward compatibility
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
to_onehot = num_classes if to_onehot else None
self.to_onehot = to_onehot

if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
if isinstance(threshold, bool): # for backward compatibility
warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
threshold = logit_thresh if threshold else None
self.threshold = threshold

self.rounding = rounding

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __call__(
self,
img: NdarrayOrTensor,
argmax: Optional[bool] = None,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: Optional[float] = None,
threshold_values: Optional[bool] = None,
n_classes: Optional[int] = None, # deprecated
num_classes: Optional[int] = None, # deprecated
logit_thresh: Optional[float] = None, # deprecated
threshold_values: Optional[bool] = None, # deprecated
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -220,9 +226,11 @@ def __call__(

"""
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
to_onehot = num_classes if to_onehot else None
if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
threshold = logit_thresh if threshold else None

img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
Expand Down
28 changes: 20 additions & 8 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ class AsDiscreted(MapTransform):

backend = AsDiscrete.backend

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __init__(
self,
keys: KeysCollection,
Expand All @@ -140,10 +142,10 @@ def __init__(
threshold: Union[Sequence[Optional[float]], Optional[float]] = None,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
n_classes: Optional[Union[Sequence[int], int]] = None,
num_classes: Optional[Union[Sequence[int], int]] = None,
logit_thresh: Union[Sequence[float], float] = 0.5,
threshold_values: Union[Sequence[bool], bool] = False,
n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated
threshold_values: Union[Sequence[bool], bool] = False, # deprecated
) -> None:
"""
Args:
Expand Down Expand Up @@ -172,7 +174,17 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))

if True in self.to_onehot or False in self.to_onehot: # backward compatibility
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
num_classes = ensure_tuple_rep(num_classes, len(self.keys))
self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes))

self.threshold = ensure_tuple_rep(threshold, len(self.keys))
if True in self.threshold or False in self.threshold: # backward compatibility
warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh))
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
self.converter = AsDiscrete()

Expand Down