From 79c4350b11cdc2047663b9c987b9a6fe0106b127 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 7 Dec 2021 11:17:54 +0000 Subject: [PATCH 1/2] fixes deprecated args Signed-off-by: Wenqi Li --- monai/utils/deprecate_utils.py | 10 ++++++-- tests/test_deprecated.py | 45 +++++++++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 4ae5991d9f..891ce900c4 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -144,6 +144,8 @@ def deprecated_arg( msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead. version_val: (used for testing) version to compare since and removed against, default is MONAI version. new_name: name of position or keyword argument to replace the deprecated argument. + if it is specified and the signature of the decorated function has a `kwargs`, the value to the + deprecated argument `name` will be removed. Returns: Decorated callable which warns or raises exception when deprecated argument used. @@ -197,9 +199,13 @@ def _wrapper(*args, **kwargs): # multiple values for new_name using both args and kwargs kwargs.pop(new_name, None) binding = sig.bind(*args, **kwargs).arguments - positional_found = name in binding - kw_found = "kwargs" in binding and name in binding["kwargs"] + kw_found = False + for k, param in sig.parameters.items(): + if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]: + kw_found = True + # if the deprecated arg is found in the **kwargs, it should be removed + kwargs.pop(name, None) if positional_found or kw_found: if is_removed: diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 9c7fe4f632..59e888240b 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -29,7 +29,7 @@ def test_warning(self): def foo2(): pass - print(foo2()) + foo2() # should not raise any warnings def test_warning_milestone(self): """Test deprecated decorator with `since` and `removed` set for a milestone version""" @@ -172,7 +172,7 @@ def test_arg_warn2(self): """Test deprecated_arg decorator with just `since` set.""" @deprecated_arg("b", since=self.prev_version, version_val=self.test_version) - def afoo2(a, **kwargs): + def afoo2(a, **kw): pass afoo2(1) # ok when no b provided @@ -235,6 +235,19 @@ def afoo4(a, b=None): self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + def test_arg_except3_unknown(self): + """ + Test deprecated_arg decorator raises exception with `removed` set in the past. + with unknown version and kwargs + """ + + @deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155") + def afoo4(a, b=None, **kwargs): + pass + + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) + self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3)) + def test_replacement_arg(self): """ Test deprecated arg being replaced. @@ -245,10 +258,36 @@ def afoo4(a, b=None): return a self.assertEqual(afoo4(b=2), 2) - # self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2)) self.assertEqual(afoo4(1, b=2), 1) # new name is in use self.assertEqual(afoo4(a=1, b=2), 1) # prefers the new arg + def test_replacement_arg1(self): + """ + Test deprecated arg being replaced with kwargs. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, *args, **kwargs): + return a + + self.assertEqual(afoo4(b=2), 2) + self.assertEqual(afoo4(1, b=2, c=3), 1) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), 1) # prefers the new arg + + def test_replacement_arg2(self): + """ + Test deprecated arg (with a default value) being replaced. + """ + + @deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version) + def afoo4(a, b=None, **kwargs): + return a, kwargs + + self.assertEqual(afoo4(b=2, c=3), (2, {"c": 3})) + self.assertEqual(afoo4(1, b=2, c=3), (1, {"c": 3})) # new name is in use + self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg + self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg + if __name__ == "__main__": unittest.main() From 216e6a065492618c2bb6bfb01a5541e9ebd1f99c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 7 Dec 2021 15:42:29 +0000 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/utils/deprecate_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 891ce900c4..4d92370c1e 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import sys import warnings from functools import wraps from types import FunctionType @@ -62,7 +63,7 @@ def deprecated( # if version_val.startswith("0+"): # # version unknown, set version_val to a large value (assuming the latest version) - # version_val = "100" + # version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) @@ -153,7 +154,7 @@ def deprecated_arg( if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit(): # version unknown, set version_val to a large value (assuming the latest version) - version_val = "100" + version_val = f"{sys.maxsize}" if since is not None and removed is not None and not version_leq(since, removed): raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)