diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index f4652fb3e7..00496dba66 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -28,6 +28,7 @@ from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import +from monai.utils.misc import ensure_tuple validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") @@ -550,8 +551,10 @@ def ckpt_export( filepath: filepath to export, if filename has no extension it becomes `.ts`. ckpt_file: filepath of the model checkpoint to load. meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. - config_file: filepath of the config file, if `None`, must be provided in `args_file`. - if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file to save in TorchScript model and extract network information, + the saved key in the TorchScript model is the config filename without extension, and the saved config + value is always serialized in JSON format no matter the original file format is JSON or YAML. + it can be a single file or a list of files. if `None`, must be provided in `args_file`. key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model weights. if not nested checkpoint, no need to set. args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, @@ -595,12 +598,22 @@ def ckpt_export( # convert to TorchScript model and save with meta data, config content net = convert_to_torchscript(model=net) + extra_files: Dict = {} + for i in ensure_tuple(config_file_): + # split the filename and directory + filename = os.path.basename(i) + # remove extension + filename, _ = os.path.splitext(filename) + if filename in extra_files: + raise ValueError(f"filename '{filename}' is given multiple times in config file list.") + extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + save_net_with_metadata( jit_obj=net, filename_prefix_or_stream=filepath_, include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None), - more_extra_files={"config": json.dumps(parser.get()).encode()}, + more_extra_files=extra_files, ) logger.info(f"exported to TorchScript file: {filepath_}.") diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 0f7d0f7d35..36aa7319f0 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import subprocess import tempfile @@ -17,6 +18,7 @@ from parameterized import parameterized from monai.bundle import ConfigParser +from monai.data import load_net_with_metadata from monai.networks import save_state from tests.utils import skip_if_windows @@ -33,7 +35,8 @@ def test_export(self, key_in_ckpt): config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg"} - def_args_file = os.path.join(tempdir, "def_args.json") + def_args_file = os.path.join(tempdir, "def_args.yaml") + ckpt_file = os.path.join(tempdir, "model.pt") ts_file = os.path.join(tempdir, "model.ts") @@ -44,11 +47,16 @@ def test_export(self, key_in_ckpt): save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file] - cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file] - cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] + cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"] + cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file] subprocess.check_call(cmd) self.assertTrue(os.path.exists(ts_file)) + _, metadata, extra_files = load_net_with_metadata(ts_file, more_extra_files=["inference", "def_args"]) + self.assertTrue("schema" in metadata) + self.assertTrue("meta_file" in json.loads(extra_files["def_args"])) + self.assertTrue("network_def" in json.loads(extra_files["inference"])) + if __name__ == "__main__": unittest.main()