diff --git a/monai/config/__init__.py b/monai/config/__init__.py index e3f623823c..64fdd3d7f0 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -12,7 +12,9 @@ from .deviceconfig import ( USE_COMPILED, IgniteInfo, + get_config_values, get_gpu_info, + get_optional_config_values, get_system_info, print_config, print_debug_info, diff --git a/monai/data/__init__.py b/monai/data/__init__.py index b12a307663..cbca16af37 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -43,6 +43,7 @@ from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader +from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( compute_importance_map, compute_shape_offset, diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py new file mode 100644 index 0000000000..17539ce1dc --- /dev/null +++ b/monai/data/torchscript_utils.py @@ -0,0 +1,149 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union + +import torch + +from monai.config import get_config_values +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + +METADATA_FILENAME = "metadata.json" + + +def save_net_with_metadata( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: Union[str, IO[Any]], + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Optional[Mapping[str, Any]] = None, + more_extra_files: Optional[Mapping[str, bytes]] = None, +) -> None: + """ + Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata + included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used + here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be + compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will + include information about the network applicable to some use case, such as describing the input and output format, + a network name and version, a plain language description of what the network does, and other relevant scientific + information. Clients can use this information to determine automatically how to use the network, and users can + read what the network does and keep track of versions. + + Examples:: + + net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) + + meta = { + "name": "Test UNet", + "used_for": "demonstration purposes", + "input_dims": 2, + "output_dims": 2 + } + + # save the Torchscript bundle with the above dictionary stored as an extra file + save_net_with_metadata(m, "test", meta_values=meta) + + # load the network back, `loaded_meta` has same data as `meta` plus version information + loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + + + Args: + jit_obj: object to save, should be generated by `script` or `trace`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. + include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. + append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. + meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + """ + + now = datetime.datetime.now() + metadict = {} + + if include_config_vals: + metadict.update(get_config_values()) + metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() + + if meta_values is not None: + metadict.update(meta_values) + + json_data = json.dumps(metadict) + + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {METADATA_FILENAME: json_data.encode()} + + if more_extra_files is not None: + extra_files.update(more_extra_files) + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = json_data.encode() + + if more_extra_files is not None: + for k, v in more_extra_files.items(): + extra_files[k] = v + + if isinstance(filename_prefix_or_stream, str): + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if ext == "": + ext = ".pt" + + if append_timestamp: + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + else: + filename_prefix_or_stream = filename_no_ext + ext + + torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) + + +def load_net_with_metadata( + filename_prefix_or_stream: Union[str, IO[Any]], + map_location: Optional[torch.device] = None, + more_extra_files: Sequence[str] = (), +) -> Tuple[torch.nn.Module, dict, dict]: + """ + Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata + back to a dict object. This will produce an empty dict if the metadata file is not present. + + Args: + filename_prefix_or_stream: filename or file-like stream object. + map_location: network map location as in `torch.jit.load`. + more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`. + Returns: + Triple containing loaded object, metadata dict, and extra files dict containing other file data if present + """ + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {f: "" for f in more_extra_files} + extra_files[METADATA_FILENAME] = "" + else: + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] + extra_files[METADATA_FILENAME] = "" + + for f in more_extra_files: + extra_files[f] = "" + + jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) + + extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap + + if METADATA_FILENAME in extra_files: + json_data = extra_files[METADATA_FILENAME] + del extra_files[METADATA_FILENAME] + else: + json_data = "{}" + + json_data_dict = json.loads(json_data) + + return jit_obj, json_data_dict, extra_files diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index d58fcca32d..eeea64f0cd 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -24,6 +24,7 @@ GridSamplePadMode, InterpolateMode, InverseKeys, + JITMetadataKeys, LossReduction, Method, MetricReduction, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index c059a4a5e2..486a7b81ce 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -266,3 +266,15 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" + + +class JITMetadataKeys(Enum): + """ + Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines + and others are optionally provided by users. + """ + + NAME = "name" + TIMESTAMP = "timestamp" + VERSION = "version" + DESCRIPTION = "description" diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py new file mode 100644 index 0000000000..b9840f68f1 --- /dev/null +++ b/tests/test_torchscript_utils.py @@ -0,0 +1,112 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.config import get_config_values +from monai.data import load_net_with_metadata, save_net_with_metadata +from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after + + +class TestModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + +class TestTorchscript(unittest.TestCase): + def test_save_net_with_metadata(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_save_net_with_metadata_ext(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test.zip") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) + + def test_save_net_with_metadata_with_extra(self): + """Save a network with simple metadata to a file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_load_net_with_metadata(self): + """Save then load a network with no metadata or other extra files.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + self.assertEqual(meta, get_config_values()) + self.assertEqual(extra_files, {}) + + def test_load_net_with_metadata_with_extra(self): + """Save then load a network with basic metadata.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + test_compare = get_config_values() + test_compare.update(test_metadata) + + self.assertEqual(meta, test_compare) + self.assertEqual(extra_files, {}) + + def test_save_load_more_extra_files(self): + """Save then load extra file data from a torchscript file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + more_extra_files = {"test.txt": b"This is test data"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + + if pytorch_after(1, 7): + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + else: + self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) + + +if __name__ == "__main__": + unittest.main()