Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
Expand Down Expand Up @@ -76,6 +76,11 @@ class ConfigParser:
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.
item_types: list of supported config item types, must be subclass of `ConfigComponent`,
`ConfigExpression`, `ConfigItem`, will check the types in order for every config item.
if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``.
resolver: manage a set of ``ConfigItem`` and resolve the references between them.
if `None`, will create a default `ReferenceResolver` instance.

See also:

Expand All @@ -96,6 +101,8 @@ def __init__(
config: Any = None,
excludes: Optional[Union[Sequence[str], str]] = None,
globals: Optional[Dict[str, Any]] = None,
item_types: Optional[Union[Sequence[Type[ConfigItem]], Type[ConfigItem]]] = None,
resolver: Optional[ReferenceResolver] = None,
):
self.config = None
self.globals: Dict[str, Any] = {}
Expand All @@ -105,9 +112,17 @@ def __init__(
if _globals is not None:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v
self.item_types = (
(ConfigComponent, ConfigExpression, ConfigItem) if item_types is None else ensure_tuple(item_types)
)

self.locator = ComponentLocator(excludes=excludes)
self.ref_resolver = ReferenceResolver()
if resolver is not None:
if not isinstance(resolver, ReferenceResolver):
raise TypeError(f"resolver must be subclass of ReferenceResolver, but got: {type(resolver)}.")
self.ref_resolver = resolver
else:
self.ref_resolver = ReferenceResolver()
if config is None:
config = {self.meta_key: {}}
self.set(config=config)
Expand Down Expand Up @@ -309,12 +324,20 @@ def _do_parse(self, config, id: str = ""):

# copy every config item to make them independent and add them to the resolver
item_conf = deepcopy(config)
if ConfigComponent.is_instantiable(item_conf):
self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator))
elif ConfigExpression.is_expression(item_conf):
self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals))
else:
self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))
for item_type in self.item_types:
if issubclass(item_type, ConfigComponent):
if item_type.is_instantiable(item_conf):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id, locator=self.locator))
continue
if issubclass(item_type, ConfigExpression):
if item_type.is_expression(item_conf):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id, globals=self.globals))
continue
if issubclass(item_type, ConfigItem):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id))
raise TypeError(
f"item type must be subclass of `ConfigComponent`, `ConfigExpression`, `ConfigItem`, got: {item_type}."
)

@classmethod
def load_config_file(cls, filepath: PathLike, **kwargs):
Expand Down
15 changes: 12 additions & 3 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from parameterized import parameterized

from monai.bundle.config_parser import ConfigParser
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.config_parser import ConfigParser, ReferenceResolver
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -59,6 +60,10 @@ def __call__(self, a, b):
return self.compute(a, b)


class TestConfigComponent(ConfigComponent):
pass


TEST_CASE_2 = [
{
"basic_func": "$lambda x, y: x + y",
Expand Down Expand Up @@ -106,7 +111,7 @@ def test_config_content(self):
@parameterized.expand([TEST_CASE_1])
@skipUnless(has_tv, "Requires torchvision >= 0.8.0.")
def test_parse(self, config, expected_ids, output_types):
parser = ConfigParser(config=config, globals={"monai": "monai"})
parser = ConfigParser(config=config, globals={"monai": "monai"}, resolver=ReferenceResolver())
# test lazy instantiation with original config content
parser["transform"]["transforms"][0]["keys"] = "label1"
self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1")
Expand All @@ -122,7 +127,11 @@ def test_parse(self, config, expected_ids, output_types):

@parameterized.expand([TEST_CASE_2])
def test_function(self, config):
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
parser = ConfigParser(
config=config,
globals={"TestClass": TestClass},
item_types=(TestConfigComponent, ConfigExpression, ConfigItem),
)
for id in config:
func = parser.get_parsed_content(id=id)
self.assertTrue(id in parser.ref_resolver.resolved_content)
Expand Down