diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 56146546e8..3dfe40f917 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -35,6 +35,7 @@ from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata +from functools import partial from monai.networks import ( convert_to_onnx, convert_to_torchscript, @@ -1159,6 +1160,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1173,6 +1175,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that takes the converted model and a filepath as input and + saves the model to the specified location. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1212,14 +1216,12 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, + saver( + jit_obj = net, + filename_prefix_or_stream = filepath, + more_extra_files = extra_files, ) + logger.info(f"exported to file: {filepath}.") @@ -1318,17 +1320,28 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: str | IO[Any], + more_extra_files: None = None, + ) -> None: + + onnx.save(jit_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1449,8 +1462,14 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None)) + _export( convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1620,8 +1639,13 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None)) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_,