diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 6dc12a0254..1dd7e40930 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -12,7 +12,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator -from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg +from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 68f2d6e46d..eb182aae47 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -14,13 +14,13 @@ import warnings from functools import wraps from types import FunctionType -from typing import Optional +from typing import Any, Optional from monai.utils.module import version_leq from .. import __version__ -__all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] +__all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"] class DeprecatedError(Exception): @@ -223,3 +223,105 @@ def _wrapper(*args, **kwargs): return _wrapper return _decorator + + +def deprecated_arg_default( + name: str, + old_default: Any, + new_default: Any, + since: Optional[str] = None, + replaced: Optional[str] = None, + msg_suffix: str = "", + version_val: str = __version__, + warning_category=FutureWarning, +): + """ + Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default` + in version `changed`. + + When the decorated definition is called, a `warning_category` is issued if `since` is given, + the default is not explicitly set by the caller and the current version is at or later than that given. + Another warning with the same category is issued if `changed` is given and the current version is at or later. + + The relevant docstring of the deprecating function should also be updated accordingly, + using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. + https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded + + In the current implementation type annotations are not preserved. + + + Args: + name: name of position or keyword argument where the default is deprecated/changed. + old_default: name of the old default. This is only for the warning message, it will not be validated. + new_default: name of the new default. + It is validated that this value is not present as the default before version `replaced`. + This means, that you can also use this if the actual default value is `None` and set later in the function. + You can also set this to any string representation, e.g. `"calculate_default_value()"` + if the default is calculated from another function. + since: version at which the argument default was marked deprecated but not replaced. + replaced: version at which the argument default was/will be replaced. + msg_suffix: message appended to warning/exception detailing reasons for deprecation. + version_val: (used for testing) version to compare since and removed against, default is MONAI version. + warning_category: a warning category class, defaults to `FutureWarning`. + + Returns: + Decorated callable which warns when deprecated default argument is not explicitly specified. + """ + + if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit(): + # version unknown, set version_val to a large value (assuming the latest version) + version_val = f"{sys.maxsize}" + if since is not None and replaced is not None and not version_leq(since, replaced): + raise ValueError(f"since must be less or equal to replaced, got since={since}, replaced={replaced}.") + is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + if is_not_yet_deprecated: + # smaller than `since`, do nothing + return lambda obj: obj + if since is None and replaced is None: + # raise a DeprecatedError directly + is_replaced = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_replaced = replaced is not None and version_leq(replaced, version_val) + + def _decorator(func): + argname = f"{func.__module__} {func.__qualname__}:{name}" + + msg_prefix = f"Default of argument `{name}`" + + if is_replaced: + msg_infix = f"was replaced in version {replaced} from `{old_default}` to `{new_default}`." + elif is_deprecated: + msg_infix = f"has been deprecated since version {since} from `{old_default}` to `{new_default}`." + if replaced is not None: + msg_infix += f" It will be replaced in version {replaced}." + else: + msg_infix = f"has been deprecated from `{old_default}` to `{new_default}`." + + msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip() + + sig = inspect.signature(func) + if name not in sig.parameters: + raise ValueError(f"Argument `{name}` not found in signature of {func.__qualname__}.") + param = sig.parameters[name] + if param.default is inspect.Parameter.empty: + raise ValueError(f"Argument `{name}` has no default value.") + + if param.default == new_default and not is_replaced: + raise ValueError( + f"Argument `{name}` was replaced to the new default value `{new_default}` before the specified version {replaced}." + ) + + @wraps(func) + def _wrapper(*args, **kwargs): + if name not in sig.bind(*args, **kwargs).arguments and is_deprecated: + # arg was not found so the default value is used + warn_deprecated(argname, msg, warning_category) + + return func(*args, **kwargs) + + return _wrapper + + return _decorator diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index c94c300175..286ec4f8a5 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -12,7 +12,7 @@ import unittest import warnings -from monai.utils import DeprecatedError, deprecated, deprecated_arg +from monai.utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default class TestDeprecatedRC(unittest.TestCase): @@ -287,6 +287,159 @@ def afoo4(a, b=None, **kwargs): self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg + def test_deprecated_arg_default_explicit_default(self): + """ + Test deprecated arg default, where the default is explicitly set (no warning). + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b="a"): + return a, b + + with self.assertWarns(FutureWarning) as aw: + self.assertEqual(foo("a", "a"), ("a", "a")) + self.assertEqual(foo("a", "b"), ("a", "b")) + self.assertEqual(foo("a", "c"), ("a", "c")) + warnings.warn("fake warning", FutureWarning) + + self.assertEqual(aw.warning.args[0], "fake warning") + + def test_deprecated_arg_default_version_less_than_since(self): + """ + Test deprecated arg default, where the current version is less than `since` (no warning). + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.test_version, version_val=self.prev_version + ) + def foo(a, b="a"): + return a, b + + with self.assertWarns(FutureWarning) as aw: + self.assertEqual(foo("a"), ("a", "a")) + self.assertEqual(foo("a", "a"), ("a", "a")) + warnings.warn("fake warning", FutureWarning) + + self.assertEqual(aw.warning.args[0], "fake warning") + + def test_deprecated_arg_default_warning_deprecated(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b="a"): + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + def test_deprecated_arg_default_warning_replaced(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.prev_version, + version_val=self.test_version, + ) + def foo(a, b="a"): + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + def test_deprecated_arg_default_warning_with_none_as_placeholder(self): + """ + Test deprecated arg default, where the default is used. + """ + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo(a, b=None): + if b is None: + b = "a" + return a, b + + self.assertWarns(FutureWarning, lambda: foo("a")) + + @deprecated_arg_default( + "b", old_default="a", new_default="b", since=self.prev_version, version_val=self.test_version + ) + def foo2(a, b=None): + if b is None: + b = "b" + return a, b + + self.assertWarns(FutureWarning, lambda: foo2("a")) + + def test_deprecated_arg_default_errors(self): + """ + Test deprecated arg default, where the decorator is wrongly used. + """ + + # since > replaced + def since_grater_than_replaced(): + @deprecated_arg_default( + "b", + old_default="a", + new_default="b", + since=self.test_version, + replaced=self.prev_version, + version_val=self.test_version, + ) + def foo(a, b=None): + return a, b + + self.assertRaises(ValueError, since_grater_than_replaced) + + # argname doesnt exist + def argname_doesnt_exist(): + @deprecated_arg_default( + "other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version + ) + def foo(a, b=None): + return a, b + + self.assertRaises(ValueError, argname_doesnt_exist) + + # argname has no default + def argname_has_no_default(): + @deprecated_arg_default( + "a", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.test_version, + version_val=self.test_version, + ) + def foo(a): + return a + + self.assertRaises(ValueError, argname_has_no_default) + + # new default is used but version < replaced + def argname_was_replaced_before_specified_version(): + @deprecated_arg_default( + "a", + old_default="a", + new_default="b", + since=self.prev_version, + replaced=self.next_version, + version_val=self.test_version, + ) + def foo(a, b="b"): + return a, b + + self.assertRaises(ValueError, argname_was_replaced_before_specified_version) + if __name__ == "__main__": unittest.main()