Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from .utils import (
convert_to_torchscript,
copy_model_state,
eval_mode,
icnr_init,
Expand Down
70 changes: 69 additions & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Callable, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union

import torch
import torch.nn as nn

from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.misc import ensure_tuple, set_determinism
from monai.utils.module import PT_BEFORE_1_7

__all__ = [
"one_hot",
Expand All @@ -34,6 +36,7 @@
"eval_mode",
"train_mode",
"copy_model_state",
"convert_to_torchscript",
]


Expand Down Expand Up @@ -424,3 +427,68 @@ def copy_model_state(
if inplace and isinstance(dst, torch.nn.Module):
dst.load_state_dict(dst_dict)
return dst_dict, updated_keys, unchanged_keys


def convert_to_torchscript(
model: nn.Module,
filename_or_obj: Optional[Any] = None,
extra_files: Optional[Dict] = None,
verify: bool = False,
inputs: Optional[Sequence[Any]] = None,
device: Optional[torch.device] = None,
rtol: float = 1e-4,
atol: float = 0.0,
**kwargs,
):
"""
Utility to convert a model into TorchScript model and save to file,
with optional input / output data verification.

Args:
model: source PyTorch model to save.
filename_or_obj: if not None, specify a file-like object (has to implement write and flush)
or a string containing a file path name to save the TorchScript model.
extra_files: map from filename to contents which will be stored as part of the save model file.
works for PyTorch 1.7 or later.
for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.
verify: whether to verify the input and output of TorchScript model.
if `filename_or_obj` is not None, load the saved TorchScript model and verify.
inputs: input test data to verify model, should be a sequence of data, every item maps to a argument
of `model()` function.
device: target device to verify the model, if None, use CUDA if available.
rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.

"""
model.eval()
with torch.no_grad():
script_module = torch.jit.script(model)
if filename_or_obj is not None:
if PT_BEFORE_1_7:
torch.jit.save(m=script_module, f=filename_or_obj)
else:
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)

if verify:
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if inputs is None:
raise ValueError("missing input data for verification.")

inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs]
ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module
ts_model.eval().to(device)
model = model.to(device)

with torch.no_grad():
set_determinism(seed=0)
torch_out = ensure_tuple(model(*inputs))
set_determinism(seed=0)
torchscript_out = ensure_tuple(ts_model(*inputs))
set_determinism(seed=None)
# compare TorchScript and PyTorch results
for r1, r2 in zip(torch_out, torchscript_out):
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol)

return script_module
42 changes: 42 additions & 0 deletions tests/test_convert_to_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import torch

from monai.networks import convert_to_torchscript
from monai.networks.nets import UNet


class TestConvertToTorchScript(unittest.TestCase):
def test_value(self):
model = UNet(
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)
with tempfile.TemporaryDirectory() as tempdir:
torchscript_model = convert_to_torchscript(
model=model,
filename_or_obj=os.path.join(tempdir, "model.ts"),
extra_files={"foo.txt": b"bar"},
verify=True,
inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
device="cuda" if torch.cuda.is_available() else "cpu",
rtol=1e-3,
atol=1e-4,
)
self.assertTrue(isinstance(torchscript_model, torch.nn.Module))


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion tests/test_utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import torch

from monai.transforms.utils_pytorch_numpy_unification import percentile
from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism
from monai.utils import set_determinism
from tests.utils import TEST_NDARRAYS, assert_allclose


class TestPytorchNumpyUnification(unittest.TestCase):
Expand Down
62 changes: 17 additions & 45 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import unittest
import warnings
from functools import partial
from io import BytesIO
from subprocess import PIPE, Popen
from typing import Callable, Optional, Tuple
from urllib.error import HTTPError, URLError
Expand All @@ -35,7 +34,8 @@
from monai.config.deviceconfig import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.networks import convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import is_module_ver_at_least
from monai.utils.module import version_leq
from monai.utils.type_conversion import convert_data_type
Expand Down Expand Up @@ -578,54 +578,26 @@ def setUp(self):
self.segn = torch.tensor(self.segn)


def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4):
def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):
"""
Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is
forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its
reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without
gradient accumulation.
forward-passed through the original and loaded copy of the network and their results returned.
The forward pass for both is done without gradient accumulation.

The test will be performed with CUDA if available, else CPU.
"""
if True:
device = "cpu"
else:
# TODO: It would be nice to be able to use GPU if
# available, but this currently causes CI failures.
if not device:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Convert to device
inputs = [i.to(device) for i in inputs]

scripted = torch.jit.script(net.cpu())
buffer = scripted.save_to_buffer()
reloaded_net = torch.jit.load(BytesIO(buffer)).to(device)
net.to(device)

if eval_nets:
net.eval()
reloaded_net.eval()

with torch.no_grad():
set_determinism(seed=0)
result1 = net(*inputs)
result2 = reloaded_net(*inputs)
set_determinism(seed=None)

# convert results to tuples if needed to allow iterating over pairs of outputs
result1 = ensure_tuple(result1)
result2 = ensure_tuple(result2)

for i, (r1, r2) in enumerate(zip(result1, result2)):
if None not in (r1, r2): # might be None
np.testing.assert_allclose(
r1.detach().cpu().numpy(),
r2.detach().cpu().numpy(),
rtol=rtol,
atol=0,
err_msg=f"failed on comparison number: {i}",
)
# TODO: would be nice to use GPU if available, but it currently causes CI failures.
device = "cpu"
with tempfile.TemporaryDirectory() as tempdir:
convert_to_torchscript(
model=net,
filename_or_obj=os.path.join(tempdir, "model.ts"),
verify=True,
inputs=inputs,
device=device,
rtol=rtol,
atol=atol,
)


def query_memory(n=2):
Expand Down