From 1dfbb1c68b4955fde74c06f939ba9a772d46b1f6 Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:07:18 +0800 Subject: [PATCH 01/31] Define save_onnx function. Signed-off-by: Han123su --- monai/data/torchscript_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index cabf06ce89..edc3df2d83 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -18,6 +18,7 @@ from typing import IO, Any import torch +import onnx from monai.config import get_config_values from monai.utils import JITMetadataKeys @@ -100,6 +101,20 @@ def save_net_with_metadata( torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) +def save_onnx( + model_obj: onnx.ModelProto, + filepath: str | IO[Any] +) -> None: + """ + Save the ONNX model to the given file or stream. + + Args: + onnx_model: ONNX model to save. + filepath: Filename or file-like stream object to save the ONNX model. + """ + onnx.save(model_obj, filepath) + + def load_net_with_metadata( filename_prefix_or_stream: str | IO[Any], map_location: torch.device | None = None, From 4ce9b2c9a84016787488c79ded01cff728c505be Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:07:55 +0800 Subject: [PATCH 02/31] Add saver_type.py (Saver class). Signed-off-by: Han123su --- monai/data/saver_type.py | 50 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 monai/data/saver_type.py diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py new file mode 100644 index 0000000000..0a387b120a --- /dev/null +++ b/monai/data/saver_type.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.data import save_net_with_metadata +from monai.data.torchscript_utils import save_onnx + + +class OnnxSaver: + def save(self, model_obj, filepath): + save_onnx( + model_obj=model_obj, + filepath=filepath + ) + +class CkptSaver: + def __init__(self, include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None): + self.include_config_vals = include_config_vals + self.append_timestamp = append_timestamp + self.meta_values = meta_values + self.more_extra_files = more_extra_files + + def save(self, model_obj, filepath): + save_net_with_metadata( + model_obj=model_obj, + filepath=filepath, + include_config_vals=self.include_config_vals, + append_timestamp=self.append_timestamp, + meta_values=self.meta_values, + more_extra_files=self.more_extra_files + ) + +class TrtSaver: + def __init__(self, include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None): + self.include_config_vals = include_config_vals + self.append_timestamp = append_timestamp + self.meta_values = meta_values + self.more_extra_files = more_extra_files + + def save(self, model_obj, filepath): + save_net_with_metadata( + model_obj=model_obj, + filepath=filepath, + include_config_vals=self.include_config_vals, + append_timestamp=self.append_timestamp, + meta_values=self.meta_values, + more_extra_files=self.more_extra_files + ) From 1f0fe900f69b65b1a4e1d7a5d1b74f63985bf934 Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:12:54 +0800 Subject: [PATCH 03/31] Modify the parameter name in save_net_with_metadata function. Signed-off-by: Han123su --- monai/data/torchscript_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index edc3df2d83..d5a6b56d46 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -27,15 +27,15 @@ def save_net_with_metadata( - jit_obj: torch.nn.Module, - filename_prefix_or_stream: str | IO[Any], + model_obj: torch.nn.Module, + filepath: str | IO[Any], include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None, ) -> None: """ - Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata + Save the JIT object (script or trace produced object) `model_obj` to the given file or stream with metadata included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will @@ -63,8 +63,8 @@ def save_net_with_metadata( Args: - jit_obj: object to save, should be generated by `script` or `trace`. - filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`. + model_obj: object to save, should be generated by `script` or `trace`. + filepath: filename or file-like stream object, if filename has no extension it becomes `.ts`. include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. @@ -88,17 +88,17 @@ def save_net_with_metadata( if more_extra_files is not None: extra_files.update(more_extra_files) - if isinstance(filename_prefix_or_stream, str): - filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if isinstance(filepath, str): + filename_no_ext, ext = os.path.splitext(filepath) if ext == "": ext = ".ts" if append_timestamp: - filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + filepath = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") else: - filename_prefix_or_stream = filename_no_ext + ext + filepath = filename_no_ext + ext - torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) + torch.jit.save(model_obj, filepath, extra_files) def save_onnx( From b799d1482793aa3b76e215c7a9e942fa9c7b0362 Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:14:52 +0800 Subject: [PATCH 04/31] Import new modules as needed. Signed-off-by: Han123su --- monai/bundle/scripts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 598d938cbd..2de126abd8 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -34,6 +34,8 @@ from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata +from monai.data.torchscript_utils import save_onnx +from monai.data.saver_type import CkptSaver, OnnxSaver, TrtSaver from monai.networks import ( convert_to_onnx, convert_to_torchscript, From 50e9ae5593834b3e2141b891ae2e5960cfd34e13 Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:21:11 +0800 Subject: [PATCH 05/31] Modify export function and Add saver (mainly conversion and saving). Signed-off-by: Han123su --- monai/bundle/scripts.py | 52 +++++++---------------------------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2de126abd8..bd29055eb6 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1063,12 +1063,9 @@ def verify_net_in_out( def _export( converter: Callable, - parser: ConfigParser, - net_id: str, + saver: Callable, + net: str, filepath: str, - ckpt_file: str, - config_file: str, - key_in_ckpt: str, **kwargs: Any, ) -> None: """ @@ -1077,52 +1074,19 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. - parser: a ConfigParser of the bundle to be converted. - net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. + saver: a callable object that takes the converted model and a filepath as input and + saves the model. + net: the network model to be converted. filepath: filepath to export, if filename has no extension, it becomes `.ts`. - ckpt_file: filepath of the model checkpoint to load. - config_file: filepath of the config file to save in the converted model,the saved key in the converted - model is the config filename without extension, and the saved config value is always serialized in - JSON format no matter the original file format is JSON or YAML. it can be a single file or a list - of files. - key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model - weights. if not nested checkpoint, no need to set. kwargs: key arguments for the converter. """ - net = parser.get_parsed_content(net_id) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file) - else: - ckpt = torch.load(ckpt_file) - copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt]) - # Use the given converter to convert a model and save with metadata, config content net = converter(model=net, **kwargs) - extra_files: dict = {} - for i in ensure_tuple(config_file): - # split the filename and directory - filename = os.path.basename(i) - # remove extension - filename, _ = os.path.splitext(filename) - # because all files are stored as JSON their name parts without extension must be unique - if filename in extra_files: - raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") - # the file may be JSON or YAML but will get loaded and dumped out again as JSON - extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() - - # add .json extension to all extra files which are always encoded as JSON - extra_files = {k + ".json": v for k, v in extra_files.items()} - - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, + saver( + model_obj=net, + filepath=filepath, ) logger.info(f"exported to file: {filepath}.") From e1fbacd8c3521f9e950c5290c32ba63a45616f4e Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:23:35 +0800 Subject: [PATCH 06/31] Modify onnx_export and call _export() instead. Signed-off-by: Han123su --- monai/bundle/scripts.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index bd29055eb6..1f459552be 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1195,8 +1195,16 @@ def onnx_export( copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + onnx_saver = OnnxSaver() + + _export( + convert_to_onnx, + onnx_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, + ) def ckpt_export( From b390dbb1239eddbe6fdf19e0d417ffff6e82ca9d Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:25:42 +0800 Subject: [PATCH 07/31] Modify ckpt_export: do specific processing before calling _export. Signed-off-by: Han123su --- monai/bundle/scripts.py | 45 ++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1f459552be..cf307dc0c8 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1323,17 +1323,42 @@ def ckpt_export( inputs_: Sequence[Any] | None = [torch.rand(input_shape_)] if input_shape_ else None - converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - # Use the given converter to convert a model and save with metadata, config content + net = parser.get_parsed_content(net_id_) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) + else: + ckpt = torch.load(ckpt_file_) + copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) + + extra_files: dict = {} + for i in ensure_tuple(config_file_): + # split the filename and directory + filename = os.path.basename(i) + # remove extension + filename, _ = os.path.splitext(filename) + # because all files are stored as JSON their name parts without extension must be unique + if filename in extra_files: + raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") + # the file may be JSON or YAML but will get loaded and dumped out again as JSON + extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + + # add .json extension to all extra files which are always encoded as JSON + extra_files = {k + ".json": v for k, v in extra_files.items()} + + ckpt_saver = CkptSaver( + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files=extra_files, + ) + _export( - convert_to_torchscript, - parser, - net_id=net_id_, - filepath=filepath_, - ckpt_file=ckpt_file_, - config_file=config_file_, - key_in_ckpt=key_in_ckpt_, - **converter_kwargs_, + convert_to_torchscript, + ckpt_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, ) From 4079598a8af06e44c4a53e2dab650dc26364f0de Mon Sep 17 00:00:00 2001 From: Han123su Date: Mon, 10 Jun 2024 23:48:06 +0800 Subject: [PATCH 08/31] Modify trt_export: do specific processing before calling _export. Signed-off-by: Han123su --- monai/bundle/scripts.py | 43 +++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index cf307dc0c8..76501b2091 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1521,15 +1521,42 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + net = parser.get_parsed_content(net_id_) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) + else: + ckpt = torch.load(ckpt_file_) + copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) + + extra_files: dict = {} + for i in ensure_tuple(config_file_): + # split the filename and directory + filename = os.path.basename(i) + # remove extension + filename, _ = os.path.splitext(filename) + # because all files are stored as JSON their name parts without extension must be unique + if filename in extra_files: + raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") + # the file may be JSON or YAML but will get loaded and dumped out again as JSON + extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + + # add .json extension to all extra files which are always encoded as JSON + extra_files = {k + ".json": v for k, v in extra_files.items()} + + trt_saver = TrtSaver( + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files=extra_files, + ) + _export( - convert_to_trt, - parser, - net_id=net_id_, - filepath=filepath_, - ckpt_file=ckpt_file_, - config_file=config_file_, - key_in_ckpt=key_in_ckpt_, - **converter_kwargs_, + convert_to_trt, + trt_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, ) From eb754ae86869ea059a0675beac99fe35154ea35d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:57:15 +0000 Subject: [PATCH 09/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 51 ++++++++++++++++++++-------------------- monai/data/saver_type.py | 1 - 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 76501b2091..4f87f62e89 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -33,9 +33,8 @@ from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike -from monai.data import load_net_with_metadata, save_net_with_metadata -from monai.data.torchscript_utils import save_onnx -from monai.data.saver_type import CkptSaver, OnnxSaver, TrtSaver +from monai.data import load_net_with_metadata +from monai.data.saver_type import CkptSaver, OnnxSaver, TrtSaver from monai.networks import ( convert_to_onnx, convert_to_torchscript, @@ -1195,15 +1194,15 @@ def onnx_export( copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - - onnx_saver = OnnxSaver() + + onnx_saver = OnnxSaver() _export( - convert_to_onnx, - onnx_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, + convert_to_onnx, + onnx_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, ) @@ -1346,19 +1345,19 @@ def ckpt_export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - ckpt_saver = CkptSaver( + ckpt_saver = CkptSaver( include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None), more_extra_files=extra_files, - ) + ) _export( - convert_to_torchscript, - ckpt_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, + convert_to_torchscript, + ckpt_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, ) @@ -1543,20 +1542,20 @@ def trt_export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - - trt_saver = TrtSaver( + + trt_saver = TrtSaver( include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None), more_extra_files=extra_files, - ) - + ) + _export( - convert_to_trt, - trt_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, + convert_to_trt, + trt_saver.save, + net=net, + filepath=filepath_, + **converter_kwargs_, ) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 0a387b120a..728bca34ac 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -3,7 +3,6 @@ from collections.abc import Mapping from typing import Any -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.data import save_net_with_metadata from monai.data.torchscript_utils import save_onnx From d4e20d14b4446cd38bcfaf00163d822dc5013f66 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 03:13:04 +0800 Subject: [PATCH 10/31] Add license header in saver_type.py Signed-off-by: Han123su --- monai/data/saver_type.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 0a387b120a..ca5366033c 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from collections.abc import Mapping From dd4a3d5d927a8d4a909f6726ab8b5d2cc00d9f30 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 03:55:48 +0800 Subject: [PATCH 11/31] Autofix Signed-off-by: Han123su --- monai/bundle/scripts.py | 37 +++++---------------------------- monai/data/saver_type.py | 27 +++++++++++++++++------- monai/data/torchscript_utils.py | 7 ++----- 3 files changed, 26 insertions(+), 45 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4f87f62e89..4775e4e25f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1060,13 +1060,7 @@ def verify_net_in_out( logger.info("data shape of network is verified with no error.") -def _export( - converter: Callable, - saver: Callable, - net: str, - filepath: str, - **kwargs: Any, -) -> None: +def _export(converter: Callable, saver: Callable, net: str, filepath: str, **kwargs: Any) -> None: """ Export a model defined in the parser to a new one specified by the converter. @@ -1083,10 +1077,7 @@ def _export( # Use the given converter to convert a model and save with metadata, config content net = converter(model=net, **kwargs) - saver( - model_obj=net, - filepath=filepath, - ) + saver(model_obj=net, filepath=filepath) logger.info(f"exported to file: {filepath}.") @@ -1197,13 +1188,7 @@ def onnx_export( onnx_saver = OnnxSaver() - _export( - convert_to_onnx, - onnx_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, - ) + _export(convert_to_onnx, onnx_saver.save, net=net, filepath=filepath_, **converter_kwargs_) def ckpt_export( @@ -1352,13 +1337,7 @@ def ckpt_export( more_extra_files=extra_files, ) - _export( - convert_to_torchscript, - ckpt_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, - ) + _export(convert_to_torchscript, ckpt_saver.save, net=net, filepath=filepath_, **converter_kwargs_) def trt_export( @@ -1550,13 +1529,7 @@ def trt_export( more_extra_files=extra_files, ) - _export( - convert_to_trt, - trt_saver.save, - net=net, - filepath=filepath_, - **converter_kwargs_, - ) + _export(convert_to_trt, trt_saver.save, net=net, filepath=filepath_, **converter_kwargs_) def init_bundle( diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index f1771d0035..7431873cc2 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -20,13 +20,17 @@ class OnnxSaver: def save(self, model_obj, filepath): - save_onnx( - model_obj=model_obj, - filepath=filepath - ) + save_onnx(model_obj=model_obj, filepath=filepath) + class CkptSaver: - def __init__(self, include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None): + def __init__( + self, + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Mapping[str, Any] | None = None, + more_extra_files: Mapping[str, bytes] | None = None, + ): self.include_config_vals = include_config_vals self.append_timestamp = append_timestamp self.meta_values = meta_values @@ -39,11 +43,18 @@ def save(self, model_obj, filepath): include_config_vals=self.include_config_vals, append_timestamp=self.append_timestamp, meta_values=self.meta_values, - more_extra_files=self.more_extra_files + more_extra_files=self.more_extra_files, ) + class TrtSaver: - def __init__(self, include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None): + def __init__( + self, + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Mapping[str, Any] | None = None, + more_extra_files: Mapping[str, bytes] | None = None, + ): self.include_config_vals = include_config_vals self.append_timestamp = append_timestamp self.meta_values = meta_values @@ -56,5 +67,5 @@ def save(self, model_obj, filepath): include_config_vals=self.include_config_vals, append_timestamp=self.append_timestamp, meta_values=self.meta_values, - more_extra_files=self.more_extra_files + more_extra_files=self.more_extra_files, ) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d5a6b56d46..a5076b8dc3 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -17,8 +17,8 @@ from collections.abc import Mapping, Sequence from typing import IO, Any -import torch import onnx +import torch from monai.config import get_config_values from monai.utils import JITMetadataKeys @@ -101,10 +101,7 @@ def save_net_with_metadata( torch.jit.save(model_obj, filepath, extra_files) -def save_onnx( - model_obj: onnx.ModelProto, - filepath: str | IO[Any] -) -> None: +def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: """ Save the ONNX model to the given file or stream. From 5a9fedce907373b0129a43ba3cc6ed98e149f1c6 Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:30:49 +0800 Subject: [PATCH 12/31] Add onnx in requirements-min.txt Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com> --- requirements-min.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-min.txt b/requirements-min.txt index ad0bb1ef20..a8fa2fe221 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -3,3 +3,4 @@ setuptools>=50.3.0,<66.0.0,!=60.6.0 coverage>=5.5 parameterized +onnx From f9174326bcbe494f33fa1013f997969fb183cf00 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 16:45:54 +0800 Subject: [PATCH 13/31] Fix local variable 'inputs_' is assigned to but never used Signed-off-by: Han123su --- monai/bundle/scripts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4775e4e25f..a5421eb0fe 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1307,6 +1307,8 @@ def ckpt_export( inputs_: Sequence[Any] | None = [torch.rand(input_shape_)] if input_shape_ else None + converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) + net = parser.get_parsed_content(net_id_) if has_ignite: # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver From ac62f042069ebff19842720f881d22e0f73e47cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 08:51:25 +0000 Subject: [PATCH 14/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a5421eb0fe..7e4bb1a2a3 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1308,7 +1308,7 @@ def ckpt_export( inputs_: Sequence[Any] | None = [torch.rand(input_shape_)] if input_shape_ else None converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - + net = parser.get_parsed_content(net_id_) if has_ignite: # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver From f389a4d98bc776559f5ba261e475ab7b37259007 Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:53:24 +0800 Subject: [PATCH 15/31] Delete onnx in requirements-dev.txt Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com> --- requirements-dev.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index a8ba25966b..133dcb8856 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,6 @@ h5py nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded -onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523 From eb2f7e16877b37da4698cc79ae6ee6a8e488a642 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 18:41:57 +0800 Subject: [PATCH 16/31] Change onnx_save location instead import onnx Signed-off-by: Han123su --- monai/bundle/scripts.py | 10 ++++++++++ requirements-min.txt | 1 - 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a5421eb0fe..ed3bd54f64 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -67,6 +67,16 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 +def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: + """ + Save the ONNX model to the given file or stream. + + Args: + onnx_model: ONNX model to save. + filepath: Filename or file-like stream object to save the ONNX model. + """ + onnx.save(model_obj, filepath) + def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ diff --git a/requirements-min.txt b/requirements-min.txt index a8fa2fe221..ad0bb1ef20 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -3,4 +3,3 @@ setuptools>=50.3.0,<66.0.0,!=60.6.0 coverage>=5.5 parameterized -onnx From d45dbcef76b2c9dcdebf7bf9dfe25cf81301d2f4 Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Tue, 11 Jun 2024 18:48:11 +0800 Subject: [PATCH 17/31] Back to original requirements-dev.txt Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com> --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 133dcb8856..a8ba25966b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,6 +49,7 @@ h5py nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded +onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523 From 2e74d6efc839c15adc4121f737726305720d448e Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 18:50:49 +0800 Subject: [PATCH 18/31] Delete import onnx in torchscript_utils.py Signed-off-by: Han123su --- monai/data/torchscript_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index a5076b8dc3..466d3c85ab 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -17,7 +17,6 @@ from collections.abc import Mapping, Sequence from typing import IO, Any -import onnx import torch from monai.config import get_config_values @@ -101,17 +100,6 @@ def save_net_with_metadata( torch.jit.save(model_obj, filepath, extra_files) -def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: - """ - Save the ONNX model to the given file or stream. - - Args: - onnx_model: ONNX model to save. - filepath: Filename or file-like stream object to save the ONNX model. - """ - onnx.save(model_obj, filepath) - - def load_net_with_metadata( filename_prefix_or_stream: str | IO[Any], map_location: torch.device | None = None, From 8b4766ff849a38d964e2f5bdc599dbe20233c015 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 20:13:11 +0800 Subject: [PATCH 19/31] Fix import Signed-off-by: Han123su --- monai/data/saver_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 7431873cc2..d4eda7cf8a 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -15,7 +15,7 @@ from typing import Any from monai.data import save_net_with_metadata -from monai.data.torchscript_utils import save_onnx +from monai.bundle.scripts import save_onnx class OnnxSaver: From 9c396aa94ad958762ad4f5f42ddcaf10c1a97114 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 20:35:25 +0800 Subject: [PATCH 20/31] Fix cannot import name 'save_onnx' from partially initialized module 'monai.bundle.scripts' Signed-off-by: Han123su --- monai/bundle/scripts.py | 10 ---------- monai/data/saver_type.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index dacd2117f5..7e4bb1a2a3 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -67,16 +67,6 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 -def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: - """ - Save the ONNX model to the given file or stream. - - Args: - onnx_model: ONNX model to save. - filepath: Filename or file-like stream object to save the ONNX model. - """ - onnx.save(model_obj, filepath) - def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index d4eda7cf8a..0675a0340c 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -15,8 +15,19 @@ from typing import Any from monai.data import save_net_with_metadata -from monai.bundle.scripts import save_onnx +from monai.utils.module import optional_import +onnx, _ = optional_import("onnx") + +def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: + """ + Save the ONNX model to the given file or stream. + + Args: + model_obj: ONNX model to save. + filepath: Filename or file-like stream object to save the ONNX model. + """ + onnx.save(model_obj, filepath) class OnnxSaver: def save(self, model_obj, filepath): From c12d65c42b628d8c59d0b68e0a90abb2e4196c08 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 20:39:21 +0800 Subject: [PATCH 21/31] Fix Undefined name Signed-off-by: Han123su --- monai/data/saver_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 0675a0340c..01a561fd8b 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -15,7 +15,7 @@ from typing import Any from monai.data import save_net_with_metadata - +from typing import IO, Any from monai.utils.module import optional_import onnx, _ = optional_import("onnx") From 9babcbc88d8f77b168dd84589c3a9dc70c77e652 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:43:01 +0000 Subject: [PATCH 22/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/saver_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 01a561fd8b..dc796d5adf 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -15,7 +15,7 @@ from typing import Any from monai.data import save_net_with_metadata -from typing import IO, Any +from typing import IO from monai.utils.module import optional_import onnx, _ = optional_import("onnx") From cb1d4d288c0edf33e6ce4df356582d6cae739f77 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 20:47:22 +0800 Subject: [PATCH 23/31] Fix incorrectly sorted and/or formatted. Signed-off-by: Han123su --- monai/data/saver_type.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 01a561fd8b..658ba60ea3 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -12,13 +12,14 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any +from typing import IO, Any from monai.data import save_net_with_metadata -from typing import IO, Any from monai.utils.module import optional_import + onnx, _ = optional_import("onnx") + def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: """ Save the ONNX model to the given file or stream. @@ -29,6 +30,7 @@ def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: """ onnx.save(model_obj, filepath) + class OnnxSaver: def save(self, model_obj, filepath): save_onnx(model_obj=model_obj, filepath=filepath) From 2f96bcfd8d643674e2f69fefd3f04569b4c6e860 Mon Sep 17 00:00:00 2001 From: Han123su Date: Tue, 11 Jun 2024 21:27:57 +0800 Subject: [PATCH 24/31] Fix onnx.ModelProto not defined Signed-off-by: Han123su --- monai/data/saver_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py index 658ba60ea3..aa4a030f10 100644 --- a/monai/data/saver_type.py +++ b/monai/data/saver_type.py @@ -20,7 +20,7 @@ onnx, _ = optional_import("onnx") -def save_onnx(model_obj: onnx.ModelProto, filepath: str | IO[Any]) -> None: +def save_onnx(model_obj: Any, filepath: str | IO[Any]) -> None: """ Save the ONNX model to the given file or stream. From 6459cc97bb966b2b0505d330b1101f137a7da2e3 Mon Sep 17 00:00:00 2001 From: Han123su Date: Thu, 18 Jul 2024 18:36:43 +0800 Subject: [PATCH 25/31] Back to original --- monai/bundle/scripts.py | 146 ++++++++++++++++---------------- monai/data/saver_type.py | 84 ------------------ monai/data/torchscript_utils.py | 22 ++--- 3 files changed, 84 insertions(+), 168 deletions(-) delete mode 100644 monai/data/saver_type.py diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3d46553a9c..6ee031661f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -34,8 +34,7 @@ from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike -from monai.data import load_net_with_metadata -from monai.data.saver_type import CkptSaver, OnnxSaver, TrtSaver +from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import ( convert_to_onnx, convert_to_torchscript, @@ -1158,24 +1157,69 @@ def verify_net_in_out( logger.info("data shape of network is verified with no error.") -def _export(converter: Callable, saver: Callable, net: str, filepath: str, **kwargs: Any) -> None: +def _export( + converter: Callable, + parser: ConfigParser, + net_id: str, + filepath: str, + ckpt_file: str, + config_file: str, + key_in_ckpt: str, + **kwargs: Any, +) -> None: """ Export a model defined in the parser to a new one specified by the converter. Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. - saver: a callable object that takes the converted model and a filepath as input and - saves the model. - net: the network model to be converted. + parser: a ConfigParser of the bundle to be converted. + net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. + ckpt_file: filepath of the model checkpoint to load. + config_file: filepath of the config file to save in the converted model,the saved key in the converted + model is the config filename without extension, and the saved config value is always serialized in + JSON format no matter the original file format is JSON or YAML. it can be a single file or a list + of files. + key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model + weights. if not nested checkpoint, no need to set. kwargs: key arguments for the converter. """ + net = parser.get_parsed_content(net_id) + if has_ignite: + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file) + else: + ckpt = torch.load(ckpt_file) + copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt]) + # Use the given converter to convert a model and save with metadata, config content net = converter(model=net, **kwargs) - saver(model_obj=net, filepath=filepath) + extra_files: dict = {} + for i in ensure_tuple(config_file): + # split the filename and directory + filename = os.path.basename(i) + # remove extension + filename, _ = os.path.splitext(filename) + # because all files are stored as JSON their name parts without extension must be unique + if filename in extra_files: + raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") + # the file may be JSON or YAML but will get loaded and dumped out again as JSON + extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() + + # add .json extension to all extra files which are always encoded as JSON + extra_files = {k + ".json": v for k, v in extra_files.items()} + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream=filepath, + include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None), + more_extra_files=extra_files, + ) logger.info(f"exported to file: {filepath}.") @@ -1283,10 +1327,8 @@ def onnx_export( copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - - onnx_saver = OnnxSaver() - - _export(convert_to_onnx, onnx_saver.save, net=net, filepath=filepath_, **converter_kwargs_) + onnx_model = convert_to_onnx(model=net, **converter_kwargs_) + onnx.save(onnx_model, filepath_) def ckpt_export( @@ -1406,39 +1448,18 @@ def ckpt_export( inputs_: Sequence[Any] | None = [torch.rand(input_shape_)] if input_shape_ else None converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) - - extra_files: dict = {} - for i in ensure_tuple(config_file_): - # split the filename and directory - filename = os.path.basename(i) - # remove extension - filename, _ = os.path.splitext(filename) - # because all files are stored as JSON their name parts without extension must be unique - if filename in extra_files: - raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") - # the file may be JSON or YAML but will get loaded and dumped out again as JSON - extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() - - # add .json extension to all extra files which are always encoded as JSON - extra_files = {k + ".json": v for k, v in extra_files.items()} - - ckpt_saver = CkptSaver( - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, + # Use the given converter to convert a model and save with metadata, config content + _export( + convert_to_torchscript, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, ) - _export(convert_to_torchscript, ckpt_saver.save, net=net, filepath=filepath_, **converter_kwargs_) - def trt_export( net_id: str | None = None, @@ -1599,38 +1620,17 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) - - extra_files: dict = {} - for i in ensure_tuple(config_file_): - # split the filename and directory - filename = os.path.basename(i) - # remove extension - filename, _ = os.path.splitext(filename) - # because all files are stored as JSON their name parts without extension must be unique - if filename in extra_files: - raise ValueError(f"Filename part '{filename}' is given multiple times in config file list.") - # the file may be JSON or YAML but will get loaded and dumped out again as JSON - extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode() - - # add .json extension to all extra files which are always encoded as JSON - extra_files = {k + ".json": v for k, v in extra_files.items()} - - trt_saver = TrtSaver( - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, + _export( + convert_to_trt, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, ) - _export(convert_to_trt, trt_saver.save, net=net, filepath=filepath_, **converter_kwargs_) - def init_bundle( bundle_dir: PathLike, @@ -1901,4 +1901,4 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data.pop("hash_type") lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") - download_url(**lf_data) + download_url(**lf_data) \ No newline at end of file diff --git a/monai/data/saver_type.py b/monai/data/saver_type.py deleted file mode 100644 index aa4a030f10..0000000000 --- a/monai/data/saver_type.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Mapping -from typing import IO, Any - -from monai.data import save_net_with_metadata -from monai.utils.module import optional_import - -onnx, _ = optional_import("onnx") - - -def save_onnx(model_obj: Any, filepath: str | IO[Any]) -> None: - """ - Save the ONNX model to the given file or stream. - - Args: - model_obj: ONNX model to save. - filepath: Filename or file-like stream object to save the ONNX model. - """ - onnx.save(model_obj, filepath) - - -class OnnxSaver: - def save(self, model_obj, filepath): - save_onnx(model_obj=model_obj, filepath=filepath) - - -class CkptSaver: - def __init__( - self, - include_config_vals: bool = True, - append_timestamp: bool = False, - meta_values: Mapping[str, Any] | None = None, - more_extra_files: Mapping[str, bytes] | None = None, - ): - self.include_config_vals = include_config_vals - self.append_timestamp = append_timestamp - self.meta_values = meta_values - self.more_extra_files = more_extra_files - - def save(self, model_obj, filepath): - save_net_with_metadata( - model_obj=model_obj, - filepath=filepath, - include_config_vals=self.include_config_vals, - append_timestamp=self.append_timestamp, - meta_values=self.meta_values, - more_extra_files=self.more_extra_files, - ) - - -class TrtSaver: - def __init__( - self, - include_config_vals: bool = True, - append_timestamp: bool = False, - meta_values: Mapping[str, Any] | None = None, - more_extra_files: Mapping[str, bytes] | None = None, - ): - self.include_config_vals = include_config_vals - self.append_timestamp = append_timestamp - self.meta_values = meta_values - self.more_extra_files = more_extra_files - - def save(self, model_obj, filepath): - save_net_with_metadata( - model_obj=model_obj, - filepath=filepath, - include_config_vals=self.include_config_vals, - append_timestamp=self.append_timestamp, - meta_values=self.meta_values, - more_extra_files=self.more_extra_files, - ) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 89603209e0..d61abeafa7 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -26,15 +26,15 @@ def save_net_with_metadata( - model_obj: torch.nn.Module, - filepath: str | IO[Any], + jit_obj: torch.nn.Module, + filename_prefix_or_stream: str | IO[Any], include_config_vals: bool = True, append_timestamp: bool = False, meta_values: Mapping[str, Any] | None = None, more_extra_files: Mapping[str, bytes] | None = None, ) -> None: """ - Save the JIT object (script or trace produced object) `model_obj` to the given file or stream with metadata + Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will @@ -62,8 +62,8 @@ def save_net_with_metadata( Args: - model_obj: object to save, should be generated by `script` or `trace`. - filepath: filename or file-like stream object, if filename has no extension it becomes `.ts`. + jit_obj: object to save, should be generated by `script` or `trace`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`. include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. @@ -87,17 +87,17 @@ def save_net_with_metadata( if more_extra_files is not None: extra_files.update(more_extra_files) - if isinstance(filepath, str): - filename_no_ext, ext = os.path.splitext(filepath) + if isinstance(filename_prefix_or_stream, str): + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) if ext == "": ext = ".ts" if append_timestamp: - filepath = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") else: - filepath = filename_no_ext + ext + filename_prefix_or_stream = filename_no_ext + ext - torch.jit.save(model_obj, filepath, extra_files) + torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) def load_net_with_metadata( @@ -131,4 +131,4 @@ def load_net_with_metadata( json_data_dict = json.loads(json_data) - return jit_obj, json_data_dict, extra_files + return jit_obj, json_data_dict, extra_files \ No newline at end of file From 6d4b8c999da5d62e8fae7f7e98eba5b31a028b13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:42:30 +0000 Subject: [PATCH 26/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- monai/data/torchscript_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6ee031661f..56146546e8 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1901,4 +1901,4 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data.pop("hash_type") lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") - download_url(**lf_data) \ No newline at end of file + download_url(**lf_data) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d61abeafa7..507cf411d6 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -131,4 +131,4 @@ def load_net_with_metadata( json_data_dict = json.loads(json_data) - return jit_obj, json_data_dict, extra_files \ No newline at end of file + return jit_obj, json_data_dict, extra_files From 42e08f236b47ce03e087bf67f5bdacf65dc1efbc Mon Sep 17 00:00:00 2001 From: Han123su Date: Sat, 20 Jul 2024 16:47:22 +0800 Subject: [PATCH 27/31] Modify to get closer to target --- monai/bundle/scripts.py | 55 +++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6ee031661f..22d1461561 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -35,6 +35,7 @@ from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata +from functools import partial from monai.networks import ( convert_to_onnx, convert_to_torchscript, @@ -1159,6 +1160,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1173,6 +1175,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that takes the converted model and a filepath as input and + saves the model to the specified location. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1212,14 +1216,12 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, + saver( + jit_obj = net, + filename_prefix_or_stream = filepath, + more_extra_files = extra_files, ) + logger.info(f"exported to file: {filepath}.") @@ -1318,17 +1320,27 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: str | IO[Any], + more_extra_files: None = None, + ) -> None: + onnx.save(jit_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1449,8 +1461,14 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None)) + _export( convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1620,8 +1638,13 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, + append_timestamp=False, + meta_values=parser.get().pop("_meta_", None)) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_, From 182e20d9c7ba7c76ee21e9ab88cdc4aa56757a0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Jul 2024 09:17:33 +0000 Subject: [PATCH 28/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 600dfbf11a..d92665b484 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1221,7 +1221,7 @@ def _export( filename_prefix_or_stream = filepath, more_extra_files = extra_files, ) - + logger.info(f"exported to file: {filepath}.") @@ -1329,7 +1329,7 @@ def save_onnx( more_extra_files: None = None, ) -> None: onnx.save(jit_obj, filename_prefix_or_stream) - + _export( convert_to_onnx, save_onnx, @@ -1465,7 +1465,7 @@ def ckpt_export( save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False, meta_values=parser.get().pop("_meta_", None)) - + _export( convert_to_torchscript, save_ts, From 7d6795371385930ce5e0d86980f7914c8f927f78 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:42:30 +0000 Subject: [PATCH 29/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- monai/data/torchscript_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 22d1461561..600dfbf11a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1924,4 +1924,4 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data.pop("hash_type") lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") - download_url(**lf_data) \ No newline at end of file + download_url(**lf_data) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d61abeafa7..507cf411d6 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -131,4 +131,4 @@ def load_net_with_metadata( json_data_dict = json.loads(json_data) - return jit_obj, json_data_dict, extra_files \ No newline at end of file + return jit_obj, json_data_dict, extra_files From 608c407bfb17ee6296350fcc05bdd4dd991a1f6b Mon Sep 17 00:00:00 2001 From: Han123su Date: Sat, 20 Jul 2024 17:28:46 +0800 Subject: [PATCH 30/31] modify Signed-off-by: Han123su --- monai/bundle/scripts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 600dfbf11a..7eacfccb35 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1328,6 +1328,7 @@ def save_onnx( filename_prefix_or_stream: str | IO[Any], more_extra_files: None = None, ) -> None: + onnx.save(jit_obj, filename_prefix_or_stream) _export( From 87d736d29c0484fce18df374f290ebb5a16f0056 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Jul 2024 09:33:44 +0000 Subject: [PATCH 31/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 169f661462..3dfe40f917 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1328,7 +1328,7 @@ def save_onnx( filename_prefix_or_stream: str | IO[Any], more_extra_files: None = None, ) -> None: - + onnx.save(jit_obj, filename_prefix_or_stream) _export(