From d39efef6449511c7ef6c9aaedcbbf963b0e5f5d6 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 15 Oct 2021 15:02:58 +0100 Subject: [PATCH 01/15] Adding Torchscript utility functions Signed-off-by: Eric Kerfoot --- monai/config/__init__.py | 2 + monai/data/__init__.py | 1 + monai/data/torchscript_utils.py | 118 ++++++++++++++++++++++++++++++++ monai/utils/__init__.py | 1 + monai/utils/enums.py | 12 ++++ tests/test_torchscript_utils.py | 90 ++++++++++++++++++++++++ 6 files changed, 224 insertions(+) create mode 100644 monai/data/torchscript_utils.py create mode 100644 tests/test_torchscript_utils.py diff --git a/monai/config/__init__.py b/monai/config/__init__.py index c929cb2362..dbda5a411b 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -12,6 +12,8 @@ from .deviceconfig import ( USE_COMPILED, IgniteInfo, + get_config_values, + get_optional_config_values, get_gpu_info, get_system_info, print_config, diff --git a/monai/data/__init__.py b/monai/data/__init__.py index fca170335b..ba2a981eec 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -37,6 +37,7 @@ from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader +from .torchscript_utils import save_net_with_metadata, load_net_with_metadata from .utils import ( compute_importance_map, compute_shape_offset, diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py new file mode 100644 index 0000000000..88d752e34e --- /dev/null +++ b/monai/data/torchscript_utils.py @@ -0,0 +1,118 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +import torch +from typing import Any, IO, Mapping, Optional, Sequence, Tuple, Union + +from monai.config import get_config_values +from monai.utils import JITMetadataKeys + + +METADATA_FILENAME = "metadata.json" + + +def save_net_with_metadata( + jit_obj: torch.nn.Module, + filename_prefix_or_stream: Union[str, IO[Any]], + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Mapping[str, Any] = {}, + more_extra_files: Mapping[str, bytes] = {}, +) -> None: + """ + Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata + included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used + here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be + compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will + include information about the network applicable to some use case, such as describing the input and output format, + a network name and version, a plain language description of what the network does, and other relevant scientific + information. Clients can use this information to determine automatically how to use the network, and users can + read what the network does and keep track of versions. + + Examples:: + + net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) + + meta = { + "name": "Test UNet", + "used_for": "demonstration purposes", + "input_dims": 2, + "output_dims": 2 + } + + # save the Torchscript bundle with the above dictionary stored as an extra file + save_net_with_metadata(m, "test", meta_values=meta) + + # load the network back, `loaded_meta` has same data as `meta` plus version information + loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") + + + Args: + jit_obj: object to save, should be generated by `script` or `trace`. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes .pt. + include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. + append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. + meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. + more_extra_files: other extra file data items to include in bundle. + """ + + now = datetime.datetime.now() + metadict = {} + + if include_config_vals: + metadict.update(get_config_values()) + metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() + + metadict.update(meta_values) + json_data = json.dumps(metadict) + + extra_files = {METADATA_FILENAME: json_data.encode()} + extra_files.update(more_extra_files) + + if isinstance(filename_prefix_or_stream, str): + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if ext == "": + ext = ".pt" + + if append_timestamp: + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + else: + filename_prefix_or_stream = filename_no_ext + ext + + torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) + + +def load_net_with_metadata( + filename_prefix_or_stream: Union[str, IO[Any]], + map_location: Optional[torch.device] = None, + more_extra_files: Sequence[str] = (), +) -> Tuple[torch.nn.Module, dict, dict]: + """ + Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata + back to a dict object. This will produce an empty dict if the metadata file is not present. + + Args: + filename_prefix_or_stream: filename or file-like stream object. + map_location: network map location as in `torch.jit.load`. + more_extra_files: other extra file data names to load from bundle. + Returns: + Triple containing loaded object, metadata dict, and extra files dict containing other file data if present + """ + extra_files = {f: "" for f in more_extra_files} + extra_files[METADATA_FILENAME] = "" + + jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) + json_data = json.loads(extra_files.pop(METADATA_FILENAME, "{}")) + + return jit_obj, json_data, extra_files diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index dc3922933d..5ce69c24db 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -24,6 +24,7 @@ GridSamplePadMode, InterpolateMode, InverseKeys, + JITMetadataKeys, LossReduction, Method, MetricReduction, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 847df9e2d3..06620fb6d2 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -243,3 +243,15 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" + + +class JITMetadataKeys(Enum): + """ + Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines + and others are optionally provided by users. + """ + NAME = "name" + TIMESTAMP = "timestamp" + VERSION = "version" + DESCRIPTION = "description" + \ No newline at end of file diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py new file mode 100644 index 0000000000..e32d4de54f --- /dev/null +++ b/tests/test_torchscript_utils.py @@ -0,0 +1,90 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.config import get_config_values +from monai.data import save_net_with_metadata, load_net_with_metadata +from monai.utils import JITMetadataKeys + +class TestModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + +class TestTorchscript(unittest.TestCase): + def test_save_net_with_metadata(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_save_net_with_metadata_ext(self): + """Save a network without metadata to a file.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test.zip") + + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) + + def test_save_net_with_metadata_with_extra(self): + """Save a network with simple metadata to a file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + def test_load_net_with_metadata(self): + """Save then load a network with no metadata or other extra files.""" + m = torch.jit.script(TestModule()) + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test") + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + self.assertEqual(meta, get_config_values()) + self.assertEqual(extra_files, {}) + + def test_load_net_with_metadata_with_extra(self): + """Save then load a network with basic metadata. """ + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) + _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") + + del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be + + test_compare = get_config_values() + test_compare.update(test_metadata) + + self.assertEqual(meta, test_compare) + self.assertEqual(extra_files, {}) + + +if __name__ == "__main__": + unittest.main() From 2b59bfbe5a047d2eec24ed8eb25354121c9c2258 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Oct 2021 14:06:50 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/torchscript_utils.py | 28 ++++++++++++++-------------- monai/utils/enums.py | 4 ++-- tests/test_torchscript_utils.py | 12 ++++++------ 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 88d752e34e..d46a097aad 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -37,27 +37,27 @@ def save_net_with_metadata( compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will include information about the network applicable to some use case, such as describing the input and output format, a network name and version, a plain language description of what the network does, and other relevant scientific - information. Clients can use this information to determine automatically how to use the network, and users can - read what the network does and keep track of versions. - + information. Clients can use this information to determine automatically how to use the network, and users can + read what the network does and keep track of versions. + Examples:: - + net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) - - meta = { - "name": "Test UNet", - "used_for": "demonstration purposes", - "input_dims": 2, + + meta = { + "name": "Test UNet", + "used_for": "demonstration purposes", + "input_dims": 2, "output_dims": 2 } - + # save the Torchscript bundle with the above dictionary stored as an extra file save_net_with_metadata(m, "test", meta_values=meta) - + # load the network back, `loaded_meta` has same data as `meta` plus version information loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") - - + + Args: jit_obj: object to save, should be generated by `script` or `trace`. filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes .pt. @@ -101,7 +101,7 @@ def load_net_with_metadata( """ Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata back to a dict object. This will produce an empty dict if the metadata file is not present. - + Args: filename_prefix_or_stream: filename or file-like stream object. map_location: network map location as in `torch.jit.load`. diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 06620fb6d2..4ddf5fa1a1 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -244,7 +244,7 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" - + class JITMetadataKeys(Enum): """ Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines @@ -254,4 +254,4 @@ class JITMetadataKeys(Enum): TIMESTAMP = "timestamp" VERSION = "version" DESCRIPTION = "description" - \ No newline at end of file + diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index e32d4de54f..29c3a22db7 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -31,16 +31,16 @@ def test_save_net_with_metadata(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test") - + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) - + def test_save_net_with_metadata_ext(self): """Save a network without metadata to a file.""" m = torch.jit.script(TestModule()) with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test.zip") - + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) def test_save_net_with_metadata_with_extra(self): @@ -51,7 +51,7 @@ def test_save_net_with_metadata_with_extra(self): with tempfile.TemporaryDirectory() as tempdir: save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) - + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) def test_load_net_with_metadata(self): @@ -63,7 +63,7 @@ def test_load_net_with_metadata(self): _, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be - + self.assertEqual(meta, get_config_values()) self.assertEqual(extra_files, {}) @@ -84,7 +84,7 @@ def test_load_net_with_metadata_with_extra(self): self.assertEqual(meta, test_compare) self.assertEqual(extra_files, {}) - + if __name__ == "__main__": unittest.main() From aa5b29d8539e96d836c3687a0973998076d369e7 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 15 Oct 2021 14:12:27 +0000 Subject: [PATCH 03/15] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/config/__init__.py | 2 +- monai/data/__init__.py | 2 +- monai/data/torchscript_utils.py | 4 ++-- monai/utils/enums.py | 2 +- tests/test_torchscript_utils.py | 5 +++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/config/__init__.py b/monai/config/__init__.py index dbda5a411b..d8aaa5707c 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -13,8 +13,8 @@ USE_COMPILED, IgniteInfo, get_config_values, - get_optional_config_values, get_gpu_info, + get_optional_config_values, get_system_info, print_config, print_debug_info, diff --git a/monai/data/__init__.py b/monai/data/__init__.py index ba2a981eec..be84441bb2 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -37,7 +37,7 @@ from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader -from .torchscript_utils import save_net_with_metadata, load_net_with_metadata +from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( compute_importance_map, compute_shape_offset, diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d46a097aad..87f7b6f1a6 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -12,13 +12,13 @@ import datetime import json import os +from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union + import torch -from typing import Any, IO, Mapping, Optional, Sequence, Tuple, Union from monai.config import get_config_values from monai.utils import JITMetadataKeys - METADATA_FILENAME = "metadata.json" diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 4ddf5fa1a1..ffa7b44340 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -250,8 +250,8 @@ class JITMetadataKeys(Enum): Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines and others are optionally provided by users. """ + NAME = "name" TIMESTAMP = "timestamp" VERSION = "version" DESCRIPTION = "description" - diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index 29c3a22db7..4f29e3d4c0 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -16,9 +16,10 @@ import torch from monai.config import get_config_values -from monai.data import save_net_with_metadata, load_net_with_metadata +from monai.data import load_net_with_metadata, save_net_with_metadata from monai.utils import JITMetadataKeys + class TestModule(torch.nn.Module): def forward(self, x): return x + 10 @@ -68,7 +69,7 @@ def test_load_net_with_metadata(self): self.assertEqual(extra_files, {}) def test_load_net_with_metadata_with_extra(self): - """Save then load a network with basic metadata. """ + """Save then load a network with basic metadata.""" m = torch.jit.script(TestModule()) test_metadata = {"foo": [1, 2], "bar": "string"} From 8d1a32d58783ca477be937647eabae24929876c1 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 15 Oct 2021 15:32:08 +0100 Subject: [PATCH 04/15] Adding Torchscript utility functions Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 87f7b6f1a6..d9b76ed154 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -27,8 +27,8 @@ def save_net_with_metadata( filename_prefix_or_stream: Union[str, IO[Any]], include_config_vals: bool = True, append_timestamp: bool = False, - meta_values: Mapping[str, Any] = {}, - more_extra_files: Mapping[str, bytes] = {}, + meta_values: Optional[Mapping[str, Any]] = None, + more_extra_files: Optional[Mapping[str, bytes]] = None, ) -> None: """ Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata @@ -74,11 +74,15 @@ def save_net_with_metadata( metadict.update(get_config_values()) metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() - metadict.update(meta_values) + if meta_values is not None: + metadict.update(meta_values) + json_data = json.dumps(metadict) extra_files = {METADATA_FILENAME: json_data.encode()} - extra_files.update(more_extra_files) + + if more_extra_files is not None: + extra_files.update(more_extra_files) if isinstance(filename_prefix_or_stream, str): filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) From 3dfd1f82e5b3ba18d825894bae7c9fe98276679f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 18:07:01 +0000 Subject: [PATCH 05/15] Added test for extra files Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 20 +++++++++++++++----- tests/test_torchscript_utils.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d9b76ed154..8542af4b07 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -18,6 +18,7 @@ from monai.config import get_config_values from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after METADATA_FILENAME = "metadata.json" @@ -64,7 +65,7 @@ def save_net_with_metadata( include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. - more_extra_files: other extra file data items to include in bundle. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. """ now = datetime.datetime.now() @@ -79,10 +80,19 @@ def save_net_with_metadata( json_data = json.dumps(metadict) - extra_files = {METADATA_FILENAME: json_data.encode()} + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 6): + extra_files = {METADATA_FILENAME: json_data.encode()} - if more_extra_files is not None: - extra_files.update(more_extra_files) + if more_extra_files is not None: + extra_files.update(more_extra_files) + else: + extra_files = torch._C.ExtraFilesMap() + extra_files[METADATA_FILENAME] = json_data.encode() + + if more_extra_files is not None: + for k, v in more_extra_files.items(): + extra_files[k] = v if isinstance(filename_prefix_or_stream, str): filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) @@ -109,7 +119,7 @@ def load_net_with_metadata( Args: filename_prefix_or_stream: filename or file-like stream object. map_location: network map location as in `torch.jit.load`. - more_extra_files: other extra file data names to load from bundle. + more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`. Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index 4f29e3d4c0..6792c8ef2e 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -86,6 +86,23 @@ def test_load_net_with_metadata_with_extra(self): self.assertEqual(meta, test_compare) self.assertEqual(extra_files, {}) + def test_save_load_more_extra_files(self): + """Save then load extra file data from a torchscript file.""" + m = torch.jit.script(TestModule()) + + test_metadata = {"foo": [1, 2], "bar": "string"} + + more_extra_files = {"test.txt": b"This is test data"} + + with tempfile.TemporaryDirectory() as tempdir: + save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) + + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) + + _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) + + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + if __name__ == "__main__": unittest.main() From be984dc666fb1a4d38b961fc4ea32110dc7a8bf7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Dec 2021 18:07:29 +0000 Subject: [PATCH 06/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/torchscript_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 8542af4b07..d115f5bf3c 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -65,7 +65,7 @@ def save_net_with_metadata( include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. - more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. """ now = datetime.datetime.now() @@ -89,7 +89,7 @@ def save_net_with_metadata( else: extra_files = torch._C.ExtraFilesMap() extra_files[METADATA_FILENAME] = json_data.encode() - + if more_extra_files is not None: for k, v in more_extra_files.items(): extra_files[k] = v From 1c0f3c4da89fe6ba1157abf1bd8d238c2a77564d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 18:17:57 +0000 Subject: [PATCH 07/15] Update Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 2 +- monai/networks/utils.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d115f5bf3c..d679bcb21a 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -81,7 +81,7 @@ def save_net_with_metadata( json_data = json.dumps(metadict) # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object - if pytorch_after(1, 6): + if pytorch_after(1, 7): extra_files = {METADATA_FILENAME: json_data.encode()} if more_extra_files is not None: diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0cff97cf27..dab7066ca1 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -230,12 +230,8 @@ def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_): conv.weight.data.copy_(kernel) -@deprecated_arg( - name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." -) -def pixelshuffle( - x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None -) -> torch.Tensor: +@deprecated_arg(name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") +def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None) -> torch.Tensor: """ Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. From 1c2f6207b03896eeae785adcffbe00063644383d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 18:31:47 +0000 Subject: [PATCH 08/15] Update Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index d679bcb21a..7be1e2d714 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -65,7 +65,7 @@ def save_net_with_metadata( include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. - more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. """ now = datetime.datetime.now() @@ -123,8 +123,16 @@ def load_net_with_metadata( Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ - extra_files = {f: "" for f in more_extra_files} - extra_files[METADATA_FILENAME] = "" + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + if pytorch_after(1, 7): + extra_files = {f: "" for f in more_extra_files} + extra_files[METADATA_FILENAME] = "" + else: + extra_files = torch._C.ExtraFilesMap() + extra_files[METADATA_FILENAME] = "" + + for f in more_extra_files: + extra_files[f] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) json_data = json.loads(extra_files.pop(METADATA_FILENAME, "{}")) From dd8c033e018420b75463880ded62b44173478bdd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Dec 2021 18:32:17 +0000 Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/torchscript_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 7be1e2d714..9de577b737 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -65,7 +65,7 @@ def save_net_with_metadata( include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. - more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. + more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. """ now = datetime.datetime.now() From 0e78ad68d6b9347f5d92a3ed9b62259698954b4f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 21:52:35 +0000 Subject: [PATCH 10/15] Update Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 11 +++++++++-- monai/networks/utils.py | 8 ++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 9de577b737..7b545775db 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -135,6 +135,13 @@ def load_net_with_metadata( extra_files[f] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) - json_data = json.loads(extra_files.pop(METADATA_FILENAME, "{}")) - return jit_obj, json_data, extra_files + if METADATA_FILENAME in extra_files: + json_data = extra_files[METADATA_FILENAME] + del extra_files[METADATA_FILENAME] + else: + json_data = "{}" + + json_data_dict = json.loads(json_data) + + return jit_obj, json_data_dict, extra_files diff --git a/monai/networks/utils.py b/monai/networks/utils.py index dab7066ca1..0cff97cf27 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -230,8 +230,12 @@ def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_): conv.weight.data.copy_(kernel) -@deprecated_arg(name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.") -def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None) -> torch.Tensor: +@deprecated_arg( + name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." +) +def pixelshuffle( + x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None +) -> torch.Tensor: """ Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. From ae80786c7d67561d0af381547826962ecd4a4a87 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 22:13:49 +0000 Subject: [PATCH 11/15] Update Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 7b545775db..ffeba75703 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -87,7 +87,7 @@ def save_net_with_metadata( if more_extra_files is not None: extra_files.update(more_extra_files) else: - extra_files = torch._C.ExtraFilesMap() + extra_files = torch._C.ExtraFilesMap() # ignore: attr-defined extra_files[METADATA_FILENAME] = json_data.encode() if more_extra_files is not None: @@ -128,7 +128,7 @@ def load_net_with_metadata( extra_files = {f: "" for f in more_extra_files} extra_files[METADATA_FILENAME] = "" else: - extra_files = torch._C.ExtraFilesMap() + extra_files = torch._C.ExtraFilesMap() # ignore: attr-defined extra_files[METADATA_FILENAME] = "" for f in more_extra_files: From b994604e7b5188e5a346672859c193054cd1b33e Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 23:03:26 +0000 Subject: [PATCH 12/15] Updates Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 6 ++++-- tests/test_torchscript_utils.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index ffeba75703..41cff3f583 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -87,7 +87,7 @@ def save_net_with_metadata( if more_extra_files is not None: extra_files.update(more_extra_files) else: - extra_files = torch._C.ExtraFilesMap() # ignore: attr-defined + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] extra_files[METADATA_FILENAME] = json_data.encode() if more_extra_files is not None: @@ -128,13 +128,15 @@ def load_net_with_metadata( extra_files = {f: "" for f in more_extra_files} extra_files[METADATA_FILENAME] = "" else: - extra_files = torch._C.ExtraFilesMap() # ignore: attr-defined + extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] extra_files[METADATA_FILENAME] = "" for f in more_extra_files: extra_files[f] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) + + extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap if METADATA_FILENAME in extra_files: json_data = extra_files[METADATA_FILENAME] diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index 6792c8ef2e..51bae4cc7f 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -18,6 +18,7 @@ from monai.config import get_config_values from monai.data import load_net_with_metadata, save_net_with_metadata from monai.utils import JITMetadataKeys +from monai.utils.module import pytorch_after class TestModule(torch.nn.Module): @@ -101,7 +102,10 @@ def test_save_load_more_extra_files(self): _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) - self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + if pytorch_after(1, 6): + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + else: + self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) if __name__ == "__main__": From 123c585801e003be63ff1b98ddc0d4b65553386c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Dec 2021 23:03:56 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/torchscript_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 41cff3f583..4dc8753b4f 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -135,7 +135,7 @@ def load_net_with_metadata( extra_files[f] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) - + extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap if METADATA_FILENAME in extra_files: From 7818227aac501c2b981b2f1eefbfc20552f54780 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sun, 19 Dec 2021 23:51:16 +0000 Subject: [PATCH 14/15] Updates Signed-off-by: Eric Kerfoot --- tests/test_torchscript_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py index 51bae4cc7f..b9840f68f1 100644 --- a/tests/test_torchscript_utils.py +++ b/tests/test_torchscript_utils.py @@ -102,7 +102,7 @@ def test_save_load_more_extra_files(self): _, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) - if pytorch_after(1, 6): + if pytorch_after(1, 7): self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) else: self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) From b9e30abf37e1e72102047489b6743c63c25620c9 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 20 Dec 2021 22:45:03 +0000 Subject: [PATCH 15/15] Updates Signed-off-by: Eric Kerfoot --- monai/data/torchscript_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 4dc8753b4f..17539ce1dc 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -61,7 +61,7 @@ def save_net_with_metadata( Args: jit_obj: object to save, should be generated by `script` or `trace`. - filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes .pt. + filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`.