diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index c1e3d5cbe9..c932879b5a 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k - `_mode_` specifies the operating mode when the component is instantiated or the callable is called. it currently supports the following values: - `"default"` (default) -- return the return value of ``_target_(**kwargs)`` - - `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often - useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to - call it with additional arguments later). + - `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a + partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function + that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call + it with additional arguments later. - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall). diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index c6da0a73de..844d5b30bf 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable): - ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``: - ``"default"``: returns ``component(**kwargs)`` - - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)`` - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` Other fields in the config content are input arguments to the python module. diff --git a/monai/utils/enums.py b/monai/utils/enums.py index a0847dd76c..b786e92151 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -411,7 +411,7 @@ class CompInitMode(StrEnum): """ DEFAULT = "default" - PARTIAL = "partial" + CALLABLE = "callable" DEBUG = "debug" diff --git a/monai/utils/module.py b/monai/utils/module.py index db62e1e72b..0dcf22fcd7 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: Args: __path: if a string is provided, it's interpreted as the full path of the target class or function component. - If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned. + If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``. + For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided, + as ``functools.partial(__path, **kwargs)`` for future invoking. + __mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``: - ``"default"``: returns ``component(**kwargs)`` - - ``"partial"``: returns ``functools.partial(component, **kwargs)`` + - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)`` - ``"debug"``: returns ``pdb.runcall(component, **kwargs)`` kwargs: keyword arguments to the callable represented by ``__path``. @@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: return component if m == CompInitMode.DEFAULT: return component(**kwargs) - if m == CompInitMode.PARTIAL: - return partial(component, **kwargs) + if m == CompInitMode.CALLABLE: + return partial(component, **kwargs) if kwargs else component if m == CompInitMode.DEBUG: warnings.warn( f"\n\npdb: instantiating component={component}, mode={m}\n" diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 72f54adf0a..4909ecf6be 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -37,7 +37,7 @@ TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict] # test non-monai modules and excludes TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam] -TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial] +TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial] # test args contains "name" field TEST_CASE_8 = [ {"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25}, diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 41d7aa7a4e..8947158857 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -72,7 +72,6 @@ def case_pdb_inst(sarg=None): class TestClass: - @staticmethod def compute(a, b, func=lambda x, y: x + y): return func(a, b) @@ -127,7 +126,6 @@ def __call__(self, a, b): class TestConfigParser(unittest.TestCase): - def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -183,7 +181,7 @@ def test_function(self, config): parser = ConfigParser(config=config, globals={"TestClass": TestClass}) for id in config: if id in ("compute", "cls_compute"): - parser[f"{id}#_mode_"] = "partial" + parser[f"{id}#_mode_"] = "callable" func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content) if id == "error_func": @@ -279,7 +277,7 @@ def test_lambda_reference(self): def test_non_str_target(self): configs = { - "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"}, + "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"}, "model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2}, } self.assertTrue(callable(ConfigParser(config=configs).fwd))