diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 0f9133037a..6bcd4df9ef 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -155,36 +155,42 @@ class AsDiscrete(Transform): backend = [TransformBackends.TORCH] - @deprecated_arg("n_classes", since="0.6") - @deprecated_arg("num_classes", since="0.7") - @deprecated_arg("logit_thresh", since="0.7") - @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, argmax: bool = False, to_onehot: Optional[int] = None, threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, - num_classes: Optional[int] = None, - logit_thresh: float = 0.5, - threshold_values: bool = False, + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: float = 0.5, # deprecated + threshold_values: Optional[bool] = False, # deprecated ) -> None: self.argmax = argmax - if isinstance(to_onehot, bool): - raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + if isinstance(to_onehot, bool): # for backward compatibility + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None self.to_onehot = to_onehot - if isinstance(threshold, bool): - raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + if isinstance(threshold, bool): # for backward compatibility + warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + threshold = logit_thresh if threshold else None self.threshold = threshold self.rounding = rounding - @deprecated_arg("n_classes", since="0.6") - @deprecated_arg("num_classes", since="0.7") - @deprecated_arg("logit_thresh", since="0.7") - @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __call__( self, img: NdarrayOrTensor, @@ -192,10 +198,10 @@ def __call__( to_onehot: Optional[int] = None, threshold: Optional[float] = None, rounding: Optional[str] = None, - n_classes: Optional[int] = None, - num_classes: Optional[int] = None, - logit_thresh: Optional[float] = None, - threshold_values: Optional[bool] = None, + n_classes: Optional[int] = None, # deprecated + num_classes: Optional[int] = None, # deprecated + logit_thresh: Optional[float] = None, # deprecated + threshold_values: Optional[bool] = None, # deprecated ) -> NdarrayOrTensor: """ Args: @@ -220,9 +226,11 @@ def __call__( """ if isinstance(to_onehot, bool): - raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + to_onehot = num_classes if to_onehot else None if isinstance(threshold, bool): - raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.") + threshold = logit_thresh if threshold else None img_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 8f97114a69..2f639e0a95 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -128,10 +128,12 @@ class AsDiscreted(MapTransform): backend = AsDiscrete.backend - @deprecated_arg("n_classes", since="0.6") - @deprecated_arg("num_classes", since="0.7") - @deprecated_arg("logit_thresh", since="0.7") - @deprecated_arg(name="threshold_values", new_name="threshold", since="0.7") + @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.") + @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.") + @deprecated_arg( + name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead." + ) def __init__( self, keys: KeysCollection, @@ -140,10 +142,10 @@ def __init__( threshold: Union[Sequence[Optional[float]], Optional[float]] = None, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, - n_classes: Optional[Union[Sequence[int], int]] = None, - num_classes: Optional[Union[Sequence[int], int]] = None, - logit_thresh: Union[Sequence[float], float] = 0.5, - threshold_values: Union[Sequence[bool], bool] = False, + n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated + logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated + threshold_values: Union[Sequence[bool], bool] = False, # deprecated ) -> None: """ Args: @@ -172,7 +174,17 @@ def __init__( super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) + + if True in self.to_onehot or False in self.to_onehot: # backward compatibility + warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") + num_classes = ensure_tuple_rep(num_classes, len(self.keys)) + self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes)) + self.threshold = ensure_tuple_rep(threshold, len(self.keys)) + if True in self.threshold or False in self.threshold: # backward compatibility + warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.") + logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) + self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh)) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) self.converter = AsDiscrete()