diff --git a/.gitignore b/.gitignore
index fafc7c1cef..542e08e3b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,6 +130,7 @@ temp/
tests/testing_data/MedNIST*
tests/testing_data/*Hippocampus*
tests/testing_data/*.tiff
+tests/testing_data/schema.json
# clang format tool
.clang-format-bin/
diff --git a/docs/requirements.txt b/docs/requirements.txt
index c2d0f22dcb..f9749e9e36 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -27,3 +27,4 @@ imagecodecs; platform_system == "Linux"
tifffile; platform_system == "Linux"
pyyaml
fire
+jsonschema
diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst
index 22260d822f..0283cb909e 100644
--- a/docs/source/bundle.rst
+++ b/docs/source/bundle.rst
@@ -36,3 +36,4 @@ Model Bundle
`Scripts`
---------
.. autofunction:: run
+.. autofunction:: verify_metadata
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 29cf1eab66..6ea442fa7f 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -190,9 +190,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are
```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, respectively.
- `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/environment-dev.yml b/environment-dev.yml
index 4491f87ceb..a361262930 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -44,6 +44,7 @@ dependencies:
- tensorboardX
- pyyaml
- fire
+ - jsonschema
- pip
- pip:
# pip for itk as conda-forge version only up to v5.1
diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index b411406e84..6f84800208 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
-from .scripts import run
+from .scripts import run, verify_metadata
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py
index 7a87030bec..45cd89bfdd 100644
--- a/monai/bundle/__main__.py
+++ b/monai/bundle/__main__.py
@@ -10,7 +10,7 @@
# limitations under the License.
-from monai.bundle.scripts import run
+from monai.bundle.scripts import run, verify_metadata
if __name__ == "__main__":
from monai.utils import optional_import
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index ebfd3e54ac..1f3165dee3 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -10,18 +10,28 @@
# limitations under the License.
import pprint
+import re
from typing import Dict, Optional, Sequence, Union
+from monai.apps.utils import download_url, get_logger
from monai.bundle.config_parser import ConfigParser
+from monai.config import PathLike
+from monai.utils import check_parent_dir, optional_import
+validate, _ = optional_import("jsonschema", name="validate")
+ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
-def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> Dict:
+logger = get_logger(module_name=__name__)
+
+
+def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict:
"""
Update the `args` with the input `kwargs`.
For dict data, recursively update the content based on the keys.
Args:
args: source args to update.
+ ignore_none: whether to ignore input args with a `None` value, defaults to `True`.
kwargs: destination args to update.
"""
@@ -32,14 +42,26 @@ def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> D
# recursively update the default args with new args
for k, v in kwargs.items():
- args_[k] = _update_default_args(args_[k], **v) if isinstance(v, dict) and isinstance(args_.get(k), dict) else v
+ if ignore_none and v is None:
+ continue
+ if isinstance(v, dict) and isinstance(args_.get(k), dict):
+ args_[k] = _update_args(args_[k], ignore_none, **v)
+ else:
+ args_[k] = v
return args_
+def _log_input_summary(tag: str, args: Dict):
+ logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---")
+ for name, val in args.items():
+ logger.info(f"> {name}: {pprint.pformat(val)}")
+ logger.info("---\n\n")
+
+
def run(
+ runner_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
config_file: Optional[Union[str, Sequence[str]]] = None,
- target_id: Optional[str] = None,
args_file: Optional[str] = None,
**override,
):
@@ -51,58 +73,46 @@ def run(
.. code-block:: bash
# Execute this module as a CLI entry:
- python -m monai.bundle run --meta_file --config_file --target_id trainer
+ python -m monai.bundle run trainer --meta_file --config_file
# Override config values at runtime by specifying the component id and its new value:
- python -m monai.bundle run --net#input_chns 1 ...
+ python -m monai.bundle run trainer --net#input_chns 1 ...
# Override config values with another config file `/path/to/another.json`:
- python -m monai.bundle run --net %/path/to/another.json ...
+ python -m monai.bundle run evaluator --net %/path/to/another.json ...
# Override config values with part content of another config file:
- python -m monai.bundle run --net %/data/other.json#net_arg ...
+ python -m monai.bundle run trainer --net %/data/other.json#net_arg ...
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime:
python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file
Args:
+ runner_id: ID name of the runner component or workflow; it must have a `run` method.
meta_file: filepath of the metadata file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
- target_id: ID name of the target component or workflow, it must have a `run` method.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
- `target_id` and override pairs. so that the command line inputs can be simplified.
+ `runner_id` and override pairs, so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
"""
- k_v = zip(["meta_file", "config_file", "target_id"], [meta_file, config_file, target_id])
- for k, v in k_v:
- if v is not None:
- override[k] = v
-
- full_kv = zip(
- ("meta_file", "config_file", "target_id", "args_file", "override"),
- (meta_file, config_file, target_id, args_file, override),
- )
- print("\n--- input summary of monai.bundle.scripts.run ---")
- for name, val in full_kv:
- print(f"> {name}: {pprint.pformat(val)}")
- print("---\n\n")
- _args = _update_default_args(args=args_file, **override)
+ _args = _update_args(args=args_file, runner_id=runner_id, meta_file=meta_file, config_file=config_file, **override)
for k in ("meta_file", "config_file"):
if k not in _args:
raise ValueError(f"{k} is required for 'monai.bundle run'.\n{run.__doc__}")
+ _log_input_summary(tag="run", args=_args)
parser = ConfigParser()
parser.read_config(f=_args.pop("config_file"))
parser.read_meta(f=_args.pop("meta_file"))
- id = _args.pop("target_id", "")
+ id = _args.pop("runner_id", "")
- # the rest key-values in the args are to override config content
+ # the remaining key-value pairs in `_args` override config content
for k, v in _args.items():
parser[k] = v
@@ -110,3 +120,55 @@ def run(
if not hasattr(workflow, "run"):
raise ValueError(f"The parsed workflow {type(workflow)} does not have a `run` method.\n{run.__doc__}")
workflow.run()
+
+
+def verify_metadata(
+ meta_file: Optional[Union[str, Sequence[str]]] = None,
+ filepath: Optional[PathLike] = None,
+ create_dir: Optional[bool] = None,
+ hash_val: Optional[str] = None,
+ args_file: Optional[str] = None,
+ **kwargs,
+):
+ """
+ Verify the provided `metadata` file based on the predefined `schema`.
+ `metadata` content must contain a `schema` field with the URL of the schema file to download.
+ The schema format follows the standard at http://json-schema.org/.
+
+ Args:
+ meta_file: filepath of the metadata file to verify; if `None`, it must be provided in `args_file`.
+ if it is a list of file paths, their contents will be merged.
+ filepath: file path to store the downloaded schema.
+ create_dir: whether to create the directory if it does not exist, defaults to `True`.
+ hash_val: if not None, the expected hash value to verify the downloaded schema file against.
+ args_file: a JSON or YAML file to provide default values for all the args in this function,
+ so that the command line inputs can be simplified.
+ kwargs: other arguments for `jsonschema.validate()`. For more details, see
+ https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.validate.
+
+ """
+
+ _args = _update_args(
+ args=args_file, meta_file=meta_file, filepath=filepath, create_dir=create_dir, hash_val=hash_val, **kwargs
+ )
+ _log_input_summary(tag="verify_metadata", args=_args)
+
+ filepath_ = _args.pop("filepath")
+ create_dir_ = _args.pop("create_dir", True)
+ check_parent_dir(path=filepath_, create_dir=create_dir_)
+
+ metadata = ConfigParser.load_config_files(files=_args.pop("meta_file"))
+ url = metadata.get("schema")
+ if url is None:
+ raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.")
+ download_url(url=url, filepath=filepath_, hash_val=_args.pop("hash_val", None), hash_type="md5", progress=True)
+ schema = ConfigParser.load_config_file(filepath=filepath_)
+
+ try:
+ # the remaining key-value pairs in `_args` are passed to the `validate` API
+ validate(instance=metadata, schema=schema, **_args)
+ except ValidationError as e:
+ # as the error message is very long, only extract the key information
+ logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.")
+ return
+ logger.info("metadata is verified with no error.")
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 636ea15c8d..b8c462a8b7 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -41,6 +41,7 @@
from .misc import (
MAX_SEED,
ImageMetaKey,
+ check_parent_dir,
copy_to_device,
ensure_tuple,
ensure_tuple_rep,
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index 1c79562f07..36ba7722b8 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -50,6 +50,7 @@
"is_module_ver_at_least",
"has_option",
"sample_slices",
+ "check_parent_dir",
"save_obj",
]
@@ -400,6 +401,25 @@ def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True,
return data[tuple(slices)]
+def check_parent_dir(path: PathLike, create_dir: bool = True):
+ """
+ Utility to check whether the parent directory of the `path` exists.
+
+ Args:
+ path: input path whose parent directory is checked.
+ create_dir: if True, create the parent directory when it does not exist;
+ otherwise, raise an exception.
+
+ """
+ path = Path(path)
+ path_dir = path.parent
+ if not path_dir.exists():
+ if create_dir:
+ path_dir.mkdir(parents=True)
+ else:
+ raise ValueError(f"the directory of specified path does not exist: `{path_dir}`.")
+
+
def save_obj(
obj, path: PathLike, create_dir: bool = True, atomic: bool = True, func: Optional[Callable] = None, **kwargs
):
@@ -421,12 +441,7 @@ def save_obj(
"""
path = Path(path)
- path_dir = path.parent
- if not path_dir.exists():
- if create_dir:
- path_dir.mkdir(parents=True)
- else:
- raise ValueError(f"the directory of specified path is not existing: {path_dir}.")
+ check_parent_dir(path=path, create_dir=create_dir)
if path.exists():
# remove the existing file
os.remove(path)
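A short sketch of the extracted `check_parent_dir` helper under a temporary directory (paths are illustrative; the helper is re-exported from `monai.utils` per the `__init__.py` hunk above):

```python
# Illustrative usage of the new helper; paths are created under a temp dir.
import os
import tempfile

from monai.utils import check_parent_dir

root = tempfile.mkdtemp()
target = os.path.join(root, "new_dir", "model.pt")

check_parent_dir(path=target, create_dir=True)  # creates root/new_dir
assert os.path.isdir(os.path.dirname(target))

try:
    check_parent_dir(os.path.join(root, "missing", "f.txt"), create_dir=False)
except ValueError as e:
    print(e)  # raised because the parent does not exist and create_dir=False
```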
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2b3786d1f3..4d2829f930 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -46,3 +46,4 @@ tensorboardX
types-PyYAML
pyyaml
fire
+jsonschema
diff --git a/setup.cfg b/setup.cfg
index aa5eae07a9..a7d597d6bd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -52,6 +52,7 @@ all =
tensorboardX
pyyaml
fire
+ jsonschema
nibabel =
nibabel
skimage =
@@ -98,6 +99,8 @@ pyyaml =
pyyaml
fire =
fire
+jsonschema =
+ jsonschema
[flake8]
select = B,C,E,F,N,P,T4,W,B9
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 8f01ee1826..bb47403090 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -160,6 +160,7 @@ def run_testsuit():
"test_prepare_batch_default_dist",
"test_parallel_execution_dist",
"test_bundle_run",
+ "test_bundle_verify_metadata",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_bundle_run.py b/tests/test_bundle_run.py
index 75002d3631..b0ce353240 100644
--- a/tests/test_bundle_run.py
+++ b/tests/test_bundle_run.py
@@ -64,15 +64,14 @@ def test_shape(self, config_file, expected_shape):
else:
override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
# test with `monai.bundle` as CLI entry directly
- cmd = "-m monai.bundle run --target_id evaluator"
- cmd += f" --postprocessing#transforms#2#output_postfix seg {override}"
+ cmd = f"-m monai.bundle run evaluator --postprocessing#transforms#2#output_postfix seg {override}"
la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
ret = subprocess.check_call(la + ["--args_file", def_args_file])
self.assertEqual(ret, 0)
self.assertTupleEqual(saver(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
# here test the script with `google fire` tool as CLI
- cmd = "-m fire monai.bundle.scripts run --target_id evaluator"
+ cmd = "-m fire monai.bundle.scripts run --runner_id evaluator"
cmd += f" --evaluator#amp False {override}"
la = [f"{sys.executable}"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
ret = subprocess.check_call(la)
diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py
new file mode 100644
index 0000000000..7e2bd02209
--- /dev/null
+++ b/tests/test_bundle_verify_metadata.py
@@ -0,0 +1,69 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+from parameterized import parameterized
+
+from monai.bundle import ConfigParser
+from tests.utils import skip_if_windows
+
+TEST_CASE_1 = [
+ os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"),
+ os.path.join(os.path.dirname(__file__), "testing_data", "schema.json"),
+]
+
+
+@skip_if_windows
+class TestVerifyMetaData(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1])
+ def test_verify(self, meta_file, schema_file):
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+ with tempfile.TemporaryDirectory() as tempdir:
+ def_args = {"meta_file": "will be replaced by `meta_file` arg"}
+ def_args_file = os.path.join(tempdir, "def_args.json")
+ ConfigParser.export_config_file(config=def_args, filepath=def_args_file)
+
+ hash_val = "b11acc946148c0186924f8234562b947"
+
+ cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
+ cmd += ["--filepath", schema_file, "--hash_val", hash_val, "--args_file", def_args_file]
+ ret = subprocess.check_call(cmd)
+ self.assertEqual(ret, 0)
+
+ def test_verify_error(self):
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+ with tempfile.TemporaryDirectory() as tempdir:
+ filepath = os.path.join(tempdir, "schema.json")
+ metafile = os.path.join(tempdir, "metadata.json")
+ with open(metafile, "w") as f:
+ json.dump(
+ {
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/"
+ "download/0.8.1/meta_schema_202203130950.json",
+ "wrong_meta": "wrong content",
+ },
+ f,
+ )
+
+ cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", metafile, "--filepath", filepath]
+ ret = subprocess.check_call(cmd)
+ self.assertEqual(ret, 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json
new file mode 100644
index 0000000000..97bc218f5e
--- /dev/null
+++ b/tests/testing_data/metadata.json
@@ -0,0 +1,77 @@
+{
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203130950.json",
+ "version": "0.1.0",
+ "changelog": {
+ "0.1.0": "complete the model package",
+ "0.0.1": "initialize the model package structure"
+ },
+ "monai_version": "0.8.0",
+ "pytorch_version": "1.10.0",
+ "numpy_version": "1.21.2",
+ "optional_packages_version": {
+ "nibabel": "3.2.1"
+ },
+ "task": "Decathlon spleen segmentation",
+ "description": "A pre-trained model for volumetric (3D) segmentation of the spleen from CT image",
+ "authorship": "MONAI team",
+ "copyright": "Copyright (c) MONAI Consortium",
+ "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
+ "data_type": "dicom",
+ "dataset_dir": "/workspace/data/Task09_Spleen",
+ "image_classes": "single channel data, intensity scaled to [0, 1]",
+ "label_classes": "single channel data, 1 is spleen, 0 is everything else",
+ "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
+ "eval_metrics": {
+ "mean_dice": 0.96
+ },
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
+ "references": [
+ "Xia, Yingda, et al. '3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training. arXiv preprint arXiv:1811.12506 (2018). https://arxiv.org/abs/1811.12506.",
+ "Kerfoot E., Clough J., Oksuz I., Lee J., King A.P., Schnabel J.A. (2019) Left-Ventricle Quantification Using Residual U-Net. In: Pop M. et al. (eds) Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. STACOM 2018. Lecture Notes in Computer Science, vol 11395. Springer, Cham. https://doi.org/10.1007/978-3-030-12029-0_40"
+ ],
+ "network_data_format": {
+ "inputs": {
+ "image": {
+ "type": "image",
+ "format": "magnitude",
+ "num_channels": 1,
+ "spatial_shape": [
+ 160,
+ 160,
+ 160
+ ],
+ "dtype": "float32",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": false,
+ "channel_def": {
+ "0": "image"
+ }
+ }
+ },
+ "outputs": {
+ "pred": {
+ "type": "image",
+ "format": "segmentation",
+ "num_channels": 2,
+ "spatial_shape": [
+ 160,
+ 160,
+ 160
+ ],
+ "dtype": "float32",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": false,
+ "channel_def": {
+ "0": "background",
+ "1": "spleen"
+ }
+ }
+ }
+ }
+}
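The metadata above is what `verify_metadata` ultimately passes to `jsonschema.validate`. A minimal sketch of that core step, assuming the schema has already been downloaded to a local file (paths are illustrative):

```python
# Assumes schema.json was already downloaded, e.g. by verify_metadata.
import json

from jsonschema import validate
from jsonschema.exceptions import ValidationError

with open("tests/testing_data/metadata.json") as f:
    metadata = json.load(f)
with open("tests/testing_data/schema.json") as f:
    schema = json.load(f)

try:
    validate(instance=metadata, schema=schema)
    print("metadata is verified with no error.")
except ValidationError as e:
    print(str(e).splitlines()[0])  # the first line names the failing field
```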