diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 613ad4e44a..ad20534a1f 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -23,6 +23,7 @@ from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import +from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates if TYPE_CHECKING: import yaml @@ -400,9 +401,9 @@ def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict: raise ValueError(f'unknown file input: "{filepath}"') with open(_filepath) as f: if _filepath.lower().endswith(cls.suffixes[0]): - return json.load(f, **kwargs) # type: ignore[no-any-return] + return json.load(f, object_pairs_hook=check_key_duplicates, **kwargs) # type: ignore[no-any-return] if _filepath.lower().endswith(cls.suffixes[1:]): - return yaml.safe_load(f, **kwargs) # type: ignore[no-any-return] + return yaml.load(f, CheckKeyDuplicatesYamlLoader, **kwargs) # type: ignore[no-any-return] raise ValueError(f"only support JSON or YAML config file so far, got name {_filepath}.") @classmethod diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 0b2df36a5b..924c00c3ec 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -24,13 +24,18 @@ from collections.abc import Callable, Iterable, Sequence from distutils.util import strtobool from pathlib import Path -from typing import Any, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload import numpy as np import torch from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike -from monai.utils.module import version_leq +from monai.utils.module import optional_import, version_leq + +if TYPE_CHECKING: + from yaml import SafeLoader +else: + SafeLoader, _ = optional_import("yaml", name="SafeLoader", as_type="base") __all__ = [ "zip_with", @@ -679,3 +684,42 @@ def pprint_edges(val: Any, n_lines: int = 20) -> str: hidden_n = len(val_str) - n_lines * 2 val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:] return "".join(val_str) + + +def check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any, Any]: + """ + Checks if there is a duplicated key in the sequence of `ordered_pairs`. + If there is - it will log a warning or raise ValueError + (if configured by environmental var `MONAI_FAIL_ON_DUPLICATE_CONFIG==1`) + + Otherwise, it returns the dict made from this sequence. + + Satisfies a format for an `object_pairs_hook` in `json.load` + + Args: + ordered_pairs: sequence of (key, value) + """ + keys = set() + for k, _ in ordered_pairs: + if k in keys: + if os.environ.get("MONAI_FAIL_ON_DUPLICATE_CONFIG", "0") == "1": + raise ValueError(f"Duplicate key: `{k}`") + else: + warnings.warn(f"Duplicate key: `{k}`") + else: + keys.add(k) + return dict(ordered_pairs) + + +class CheckKeyDuplicatesYamlLoader(SafeLoader): + def construct_mapping(self, node, deep=False): + mapping = set() + for key_node, _ in node.value: + key = self.construct_object(key_node, deep=deep) + if key in mapping: + if os.environ.get("MONAI_FAIL_ON_DUPLICATE_CONFIG", "0") == "1": + raise ValueError(f"Duplicate key: `{key}`") + else: + warnings.warn(f"Duplicate key: `{key}`") + mapping.add(key) + return super().construct_mapping(node, deep) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index c4b50daed1..d45a251f9b 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -14,7 +14,9 @@ import os import tempfile import unittest -from unittest import skipUnless +import warnings +from pathlib import Path +from unittest import mock, skipUnless import numpy as np from parameterized import parameterized @@ -27,6 +29,7 @@ from tests.utils import TimedCall _, has_tv = optional_import("torchvision", "0.8.0", min_version) +_, has_yaml = optional_import("yaml") @TimedCall(seconds=100, force_quit=True) @@ -109,6 +112,18 @@ def __call__(self, a, b): TEST_CASE_5 = [{"training": {"A": 1, "A_B": 2}, "total": "$@training#A + @training#A_B + 1"}, 4] +TEST_CASE_DUPLICATED_KEY_JSON = ["""{"key": {"unique": 1, "duplicate": 0, "duplicate": 4 } }""", "json", 1, [0, 4]] + +TEST_CASE_DUPLICATED_KEY_YAML = [ + """key: + unique: 1 + duplicate: 0 + duplicate: 4""", + "yaml", + 1, + [0, 4], +] + class TestConfigParser(unittest.TestCase): def test_config_content(self): @@ -303,6 +318,36 @@ def test_substring_reference(self, config, expected): parser = ConfigParser(config=config) self.assertEqual(parser.get_parsed_content("total"), expected) + @parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML]) + @mock.patch.dict(os.environ, {"MONAI_FAIL_ON_DUPLICATE_CONFIG": "1"}) + @skipUnless(has_yaml, "Requires pyyaml") + def test_parse_json_raise(self, config_string, extension, _, __): + with tempfile.TemporaryDirectory() as tempdir: + config_path = Path(tempdir) / f"config.{extension}" + config_path.write_text(config_string) + parser = ConfigParser() + + with self.assertRaises(ValueError) as context: + parser.read_config(config_path) + + self.assertTrue("Duplicate key: `duplicate`" in str(context.exception)) + + @parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML]) + @skipUnless(has_yaml, "Requires pyyaml") + def test_parse_json_warn(self, config_string, extension, expected_unique_val, expected_duplicate_vals): + with tempfile.TemporaryDirectory() as tempdir: + config_path = Path(tempdir) / f"config.{extension}" + config_path.write_text(config_string) + parser = ConfigParser() + + with warnings.catch_warnings(record=True) as w: + parser.read_config(config_path) + self.assertEqual(len(w), 1) + self.assertTrue("Duplicate key: `duplicate`" in str(w[-1].message)) + + self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val) + self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals) + if __name__ == "__main__": unittest.main()