Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions monai/utils/deprecate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import inspect
import sys
import warnings
from functools import wraps
from types import FunctionType
Expand Down Expand Up @@ -62,7 +63,7 @@ def deprecated(

# if version_val.startswith("0+"):
# # version unknown, set version_val to a large value (assuming the latest version)
# version_val = "100"
# version_val = f"{sys.maxsize}"
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
Expand Down Expand Up @@ -144,14 +145,16 @@ def deprecated_arg(
msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
new_name: name of position or keyword argument to replace the deprecated argument.
if it is specified and the signature of the decorated function has a `kwargs`, the value to the
deprecated argument `name` will be removed.

Returns:
Decorated callable which warns or raises exception when deprecated argument used.
"""

if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit():
# version unknown, set version_val to a large value (assuming the latest version)
version_val = "100"
version_val = f"{sys.maxsize}"
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
Expand Down Expand Up @@ -197,9 +200,13 @@ def _wrapper(*args, **kwargs):
# multiple values for new_name using both args and kwargs
kwargs.pop(new_name, None)
binding = sig.bind(*args, **kwargs).arguments

positional_found = name in binding
kw_found = "kwargs" in binding and name in binding["kwargs"]
kw_found = False
for k, param in sig.parameters.items():
if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]:
kw_found = True
# if the deprecated arg is found in the **kwargs, it should be removed
kwargs.pop(name, None)

if positional_found or kw_found:
if is_removed:
Expand Down
45 changes: 42 additions & 3 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_warning(self):
def foo2():
pass

print(foo2())
foo2() # should not raise any warnings

def test_warning_milestone(self):
"""Test deprecated decorator with `since` and `removed` set for a milestone version"""
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_arg_warn2(self):
"""Test deprecated_arg decorator with just `since` set."""

@deprecated_arg("b", since=self.prev_version, version_val=self.test_version)
def afoo2(a, **kwargs):
def afoo2(a, **kw):
pass

afoo2(1) # ok when no b provided
Expand Down Expand Up @@ -235,6 +235,19 @@ def afoo4(a, b=None):

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))

def test_arg_except3_unknown(self):
"""
Test deprecated_arg decorator raises exception with `removed` set in the past.
with unknown version and kwargs
"""

@deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155")
def afoo4(a, b=None, **kwargs):
pass

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3))

def test_replacement_arg(self):
"""
Test deprecated arg being replaced.
Expand All @@ -245,10 +258,36 @@ def afoo4(a, b=None):
return a

self.assertEqual(afoo4(b=2), 2)
# self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
self.assertEqual(afoo4(1, b=2), 1) # new name is in use
self.assertEqual(afoo4(a=1, b=2), 1) # prefers the new arg

def test_replacement_arg1(self):
"""
Test deprecated arg being replaced with kwargs.
"""

@deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version)
def afoo4(a, *args, **kwargs):
return a

self.assertEqual(afoo4(b=2), 2)
self.assertEqual(afoo4(1, b=2, c=3), 1) # new name is in use
self.assertEqual(afoo4(a=1, b=2, c=3), 1) # prefers the new arg

def test_replacement_arg2(self):
"""
Test deprecated arg (with a default value) being replaced.
"""

@deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version)
def afoo4(a, b=None, **kwargs):
return a, kwargs

self.assertEqual(afoo4(b=2, c=3), (2, {"c": 3}))
self.assertEqual(afoo4(1, b=2, c=3), (1, {"c": 3})) # new name is in use
self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg
self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg


if __name__ == "__main__":
unittest.main()