diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..4dec09c889 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .utils import ( + convert_to_torchscript, copy_model_state, eval_mode, icnr_init, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 529dfbf977..c60867d8e6 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,12 +15,14 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import torch import torch.nn as nn from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.misc import ensure_tuple, set_determinism +from monai.utils.module import PT_BEFORE_1_7 __all__ = [ "one_hot", @@ -34,6 +36,7 @@ "eval_mode", "train_mode", "copy_model_state", + "convert_to_torchscript", ] @@ -424,3 +427,68 @@ def copy_model_state( if inplace and isinstance(dst, torch.nn.Module): dst.load_state_dict(dst_dict) return dst_dict, updated_keys, unchanged_keys + + +def convert_to_torchscript( + model: nn.Module, + filename_or_obj: Optional[Any] = None, + extra_files: Optional[Dict] = None, + verify: bool = False, + inputs: Optional[Sequence[Any]] = None, + device: Optional[torch.device] = None, + rtol: float = 1e-4, + atol: float = 0.0, + **kwargs, +): + """ + Utility to convert a model into TorchScript model and save to file, + with optional input / output data verification. + + Args: + model: source PyTorch model to save. + filename_or_obj: if not None, specify a file-like object (has to implement write and flush) + or a string containing a file path name to save the TorchScript model. + extra_files: map from filename to contents which will be stored as part of the save model file. + works for PyTorch 1.7 or later. + for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. + verify: whether to verify the input and output of TorchScript model. + if `filename_or_obj` is not None, load the saved TorchScript model and verify. + inputs: input test data to verify model, should be a sequence of data, every item maps to a argument + of `model()` function. + device: target device to verify the model, if None, use CUDA if available. + rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. + atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + + """ + model.eval() + with torch.no_grad(): + script_module = torch.jit.script(model) + if filename_or_obj is not None: + if PT_BEFORE_1_7: + torch.jit.save(m=script_module, f=filename_or_obj) + else: + torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) + + if verify: + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if inputs is None: + raise ValueError("missing input data for verification.") + + inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs] + ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module + ts_model.eval().to(device) + model = model.to(device) + + with torch.no_grad(): + set_determinism(seed=0) + torch_out = ensure_tuple(model(*inputs)) + set_determinism(seed=0) + torchscript_out = ensure_tuple(ts_model(*inputs)) + set_determinism(seed=None) + # compare TorchScript and PyTorch results + for r1, r2 in zip(torch_out, torchscript_out): + if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): + torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol) + + return script_module diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py new file mode 100644 index 0000000000..a772610a04 --- /dev/null +++ b/tests/test_convert_to_torchscript.py @@ -0,0 +1,42 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.networks import convert_to_torchscript +from monai.networks.nets import UNet + + +class TestConvertToTorchScript(unittest.TestCase): + def test_value(self): + model = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) + with tempfile.TemporaryDirectory() as tempdir: + torchscript_model = convert_to_torchscript( + model=model, + filename_or_obj=os.path.join(tempdir, "model.ts"), + extra_files={"foo.txt": b"bar"}, + verify=True, + inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)], + device="cuda" if torch.cuda.is_available() else "cpu", + rtol=1e-3, + atol=1e-4, + ) + self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index c8e0a35c92..c3b1bc259b 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -15,7 +15,8 @@ import torch from monai.transforms.utils_pytorch_numpy_unification import percentile -from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose class TestPytorchNumpyUnification(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index c73a87d141..97e57c6be6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,7 +22,6 @@ import unittest import warnings from functools import partial -from io import BytesIO from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple from urllib.error import HTTPError, URLError @@ -35,7 +34,8 @@ from monai.config.deviceconfig import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.networks import convert_to_torchscript +from monai.utils import optional_import from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq from monai.utils.type_conversion import convert_data_type @@ -578,54 +578,26 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): +def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): """ Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is - forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its - reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without - gradient accumulation. + forward-passed through the original and loaded copy of the network and their results returned. + The forward pass for both is done without gradient accumulation. The test will be performed with CUDA if available, else CPU. """ - if True: - device = "cpu" - else: - # TODO: It would be nice to be able to use GPU if - # available, but this currently causes CI failures. - if not device: - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Convert to device - inputs = [i.to(device) for i in inputs] - - scripted = torch.jit.script(net.cpu()) - buffer = scripted.save_to_buffer() - reloaded_net = torch.jit.load(BytesIO(buffer)).to(device) - net.to(device) - - if eval_nets: - net.eval() - reloaded_net.eval() - - with torch.no_grad(): - set_determinism(seed=0) - result1 = net(*inputs) - result2 = reloaded_net(*inputs) - set_determinism(seed=None) - - # convert results to tuples if needed to allow iterating over pairs of outputs - result1 = ensure_tuple(result1) - result2 = ensure_tuple(result2) - - for i, (r1, r2) in enumerate(zip(result1, result2)): - if None not in (r1, r2): # might be None - np.testing.assert_allclose( - r1.detach().cpu().numpy(), - r2.detach().cpu().numpy(), - rtol=rtol, - atol=0, - err_msg=f"failed on comparison number: {i}", - ) + # TODO: would be nice to use GPU if available, but it currently causes CI failures. + device = "cpu" + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + atol=atol, + ) def query_memory(n=2):