Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
from monai.utils.misc import ensure_tuple

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Expand Down Expand Up @@ -550,8 +551,10 @@ def ckpt_export(
filepath: filepath to export, if filename has no extension it becomes `.ts`.
ckpt_file: filepath of the model checkpoint to load.
meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file to save in TorchScript model and extract network information,
the saved key in the TorchScript model is the config filename without extension, and the saved config
value is always serialized in JSON format no matter the original file format is JSON or YAML.
it can be a single file or a list of files. if `None`, must be provided in `args_file`.
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
weights. if not nested checkpoint, no need to set.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
Expand Down Expand Up @@ -595,12 +598,22 @@ def ckpt_export(
# convert to TorchScript model and save with meta data, config content
net = convert_to_torchscript(model=net)

extra_files: Dict = {}
for i in ensure_tuple(config_file_):
# split the filename and directory
filename = os.path.basename(i)
# remove extension
filename, _ = os.path.splitext(filename)
if filename in extra_files:
raise ValueError(f"filename '{filename}' is given multiple times in config file list.")
extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode()

save_net_with_metadata(
jit_obj=net,
filename_prefix_or_stream=filepath_,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
more_extra_files={"config": json.dumps(parser.get()).encode()},
more_extra_files=extra_files,
)
logger.info(f"exported to TorchScript file: {filepath_}.")
14 changes: 11 additions & 3 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import subprocess
import tempfile
Expand All @@ -17,6 +18,7 @@
from parameterized import parameterized

from monai.bundle import ConfigParser
from monai.data import load_net_with_metadata
from monai.networks import save_state
from tests.utils import skip_if_windows

Expand All @@ -33,7 +35,8 @@ def test_export(self, key_in_ckpt):
config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
def_args_file = os.path.join(tempdir, "def_args.json")
def_args_file = os.path.join(tempdir, "def_args.yaml")

ckpt_file = os.path.join(tempdir, "model.pt")
ts_file = os.path.join(tempdir, "model.ts")

Expand All @@ -44,11 +47,16 @@ def test_export(self, key_in_ckpt):
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
cmd += ["--meta_file", meta_file, "--config_file", config_file, "--ckpt_file", ckpt_file]
cmd += ["--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
subprocess.check_call(cmd)
self.assertTrue(os.path.exists(ts_file))

_, metadata, extra_files = load_net_with_metadata(ts_file, more_extra_files=["inference", "def_args"])
self.assertTrue("schema" in metadata)
self.assertTrue("meta_file" in json.loads(extra_files["def_args"]))
self.assertTrue("network_def" in json.loads(extra_files["inference"]))


if __name__ == "__main__":
unittest.main()