From 75518f7042d925cba02f939a66ab49a62193b998 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Oct 2021 00:19:01 +0800 Subject: [PATCH 1/7] [DLMED] add utitlity Signed-off-by: Nic Ma --- monai/networks/__init__.py | 1 + monai/networks/utils.py | 44 +++++++++++++++++++++++++++++++ tests/test_save_to_torchscript.py | 40 ++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 tests/test_save_to_torchscript.py diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..65e7b26331 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -18,6 +18,7 @@ one_hot, pixelshuffle, predict_segmentation, + save_to_torchscript, slice_channels, to_norm_affine, train_mode, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 529dfbf977..94828a15e6 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -34,6 +34,7 @@ "eval_mode", "train_mode", "copy_model_state", + "save_to_torchscript", ] @@ -424,3 +425,46 @@ def copy_model_state( if inplace and isinstance(dst, torch.nn.Module): dst.load_state_dict(dst_dict) return dst_dict, updated_keys, unchanged_keys + + +def save_to_torchscript( + model: nn.Module, + output_path: str, + verify: bool = False, + input_shape: Optional[Sequence[int]] = None, + device: Optional[torch.device] = None, + rtol: float = 1e-4, + atol: float = 0.0, +): + """ + Utility to save a model into TorchScript model with optional input / output data verification. + + Args: + model: source PyTorch model to save. + output_path: the path to save converted TorchScript model. + verify: whether to verify the input and output of TorchScript model. + input_shape: shape of the input data to verify model. + device: target device to verify the model. + rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. + atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + + """ + model.eval() + with torch.no_grad(): + script_module = torch.jit.script(model) + script_module.save(output_path) + + if verify: + if input_shape is None: + raise ValueError("missing input_shape argument for verification.") + dummy_input = torch.randn(tuple(input_shape), requires_grad=False).to(device) + torchscript_model = torch.jit.load(output_path).eval().to(device) + model = model.to(device) + + with torch.no_grad(): + torch_out = model(dummy_input) + torchscript_out = torchscript_model(dummy_input) + # compare TorchScript and PyTorch results + torch.testing.assert_allclose(torch_out, torchscript_out, rtol=rtol, atol=atol) + + return script_module diff --git a/tests/test_save_to_torchscript.py b/tests/test_save_to_torchscript.py new file mode 100644 index 0000000000..fe7be5cc8c --- /dev/null +++ b/tests/test_save_to_torchscript.py @@ -0,0 +1,40 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tempfile +import unittest + +import torch + +from monai.networks import save_to_torchscript +from monai.networks.nets import UNet + + +class TestArrayDataset(unittest.TestCase): + def test_shape(self): + model = UNet( + spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 + ) + with tempfile.TemporaryDirectory() as tempdir: + torchscript_model = save_to_torchscript( + model=model, + output_path=os.path.join(tempdir, "model.ts"), + verify=True, + input_shape=(16, 1, 32, 32), + device="cuda" if torch.cuda.is_available() else "cpu", + rtol=1e-3, + ) + self.assertTrue(isinstance(torchscript_model, torch.jit._script.RecursiveScriptModule)) + + +if __name__ == "__main__": + unittest.main() From 2d69164e4febf04f54b072c4a9ec16956ebb57d1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Oct 2021 01:13:15 +0800 Subject: [PATCH 2/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/networks/__init__.py | 2 +- monai/networks/utils.py | 23 +++++++++++-------- ...ript.py => test_convert_to_torchscript.py} | 6 ++--- tests/utils.py | 3 ++- 4 files changed, 20 insertions(+), 14 deletions(-) rename tests/{test_save_to_torchscript.py => test_convert_to_torchscript.py} (90%) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 65e7b26331..4dec09c889 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .utils import ( + convert_to_torchscript, copy_model_state, eval_mode, icnr_init, @@ -18,7 +19,6 @@ one_hot, pixelshuffle, predict_segmentation, - save_to_torchscript, slice_channels, to_norm_affine, train_mode, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 94828a15e6..77ce2a2277 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -34,7 +34,7 @@ "eval_mode", "train_mode", "copy_model_state", - "save_to_torchscript", + "convert_to_torchscript", ] @@ -427,43 +427,48 @@ def copy_model_state( return dst_dict, updated_keys, unchanged_keys -def save_to_torchscript( +def convert_to_torchscript( model: nn.Module, - output_path: str, + output_path: Optional[str] = None, verify: bool = False, input_shape: Optional[Sequence[int]] = None, device: Optional[torch.device] = None, rtol: float = 1e-4, atol: float = 0.0, + **kwargs, ): """ - Utility to save a model into TorchScript model with optional input / output data verification. + Utility to convert a model into TorchScript model and save to file, + with optional input / output data verification. Args: model: source PyTorch model to save. - output_path: the path to save converted TorchScript model. + output_path: if not None, specify the path to save converted TorchScript model. verify: whether to verify the input and output of TorchScript model. + if `output_path` is not None, load the saved TorchScript model and verify. input_shape: shape of the input data to verify model. device: target device to verify the model. rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + kwargs: besides input data, other arguments for the call function of model. 
""" model.eval() with torch.no_grad(): script_module = torch.jit.script(model) - script_module.save(output_path) + if output_path is not None: + script_module.save(output_path) if verify: if input_shape is None: raise ValueError("missing input_shape argument for verification.") dummy_input = torch.randn(tuple(input_shape), requires_grad=False).to(device) - torchscript_model = torch.jit.load(output_path).eval().to(device) + ts_model = (torch.jit.load(output_path) if output_path is not None else script_module).eval().to(device) model = model.to(device) with torch.no_grad(): - torch_out = model(dummy_input) - torchscript_out = torchscript_model(dummy_input) + torch_out = model(dummy_input, **kwargs) + torchscript_out = ts_model(dummy_input, **kwargs) # compare TorchScript and PyTorch results torch.testing.assert_allclose(torch_out, torchscript_out, rtol=rtol, atol=atol) diff --git a/tests/test_save_to_torchscript.py b/tests/test_convert_to_torchscript.py similarity index 90% rename from tests/test_save_to_torchscript.py rename to tests/test_convert_to_torchscript.py index fe7be5cc8c..31648511ab 100644 --- a/tests/test_save_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -15,7 +15,7 @@ import torch -from monai.networks import save_to_torchscript +from monai.networks import convert_to_torchscript from monai.networks.nets import UNet @@ -25,7 +25,7 @@ def test_shape(self): spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 ) with tempfile.TemporaryDirectory() as tempdir: - torchscript_model = save_to_torchscript( + torchscript_model = convert_to_torchscript( model=model, output_path=os.path.join(tempdir, "model.ts"), verify=True, @@ -33,7 +33,7 @@ def test_shape(self): device="cuda" if torch.cuda.is_available() else "cpu", rtol=1e-3, ) - self.assertTrue(isinstance(torchscript_model, torch.jit._script.RecursiveScriptModule)) + self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index c73a87d141..a12d9f8ea6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,6 +35,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d +from monai.networks import convert_to_torchscript from monai.utils import ensure_tuple, optional_import, set_determinism from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq @@ -598,7 +599,7 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): # Convert to device inputs = [i.to(device) for i in inputs] - scripted = torch.jit.script(net.cpu()) + scripted = convert_to_torchscript(net.cpu(), output_path=None) buffer = scripted.save_to_buffer() reloaded_net = torch.jit.load(BytesIO(buffer)).to(device) net.to(device) From c6e2d89b9b5040813b805a70fad3215aac499e92 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Oct 2021 08:29:12 +0800 Subject: [PATCH 3/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/networks/utils.py | 17 +++++++++++------ tests/test_convert_to_torchscript.py | 8 +++++--- tests/utils.py | 7 ++++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 77ce2a2277..f054b27992 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,7 @@ import warnings from collections import OrderedDict from contextlib import 
contextmanager -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import torch import torch.nn as nn @@ -429,7 +429,8 @@ def copy_model_state( def convert_to_torchscript( model: nn.Module, - output_path: Optional[str] = None, + filename_or_obj: Optional[Any] = None, + extra_files: Optional[Dict] = None, verify: bool = False, input_shape: Optional[Sequence[int]] = None, device: Optional[torch.device] = None, @@ -443,7 +444,10 @@ def convert_to_torchscript( Args: model: source PyTorch model to save. - output_path: if not None, specify the path to save converted TorchScript model. + filename_or_obj: if not None, specify a file-like object (has to implement write and flush) + or a string containing a file path name to save the TorchScript model. + extra_files: map from filename to contents which will be stored as part of the save model file. + for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. verify: whether to verify the input and output of TorchScript model. if `output_path` is not None, load the saved TorchScript model and verify. input_shape: shape of the input data to verify model. @@ -456,14 +460,15 @@ def convert_to_torchscript( model.eval() with torch.no_grad(): script_module = torch.jit.script(model) - if output_path is not None: - script_module.save(output_path) + if filename_or_obj is not None: + torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) if verify: if input_shape is None: raise ValueError("missing input_shape argument for verification.") dummy_input = torch.randn(tuple(input_shape), requires_grad=False).to(device) - ts_model = (torch.jit.load(output_path) if output_path is not None else script_module).eval().to(device) + ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module + ts_model.eval().to(device) model = model.to(device) with torch.no_grad(): diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index 31648511ab..b313038f61 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -19,19 +19,21 @@ from monai.networks.nets import UNet -class TestArrayDataset(unittest.TestCase): - def test_shape(self): +class TestConvertToTorchScript(unittest.TestCase): + def test_value(self): model = UNet( spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 ) with tempfile.TemporaryDirectory() as tempdir: torchscript_model = convert_to_torchscript( model=model, - output_path=os.path.join(tempdir, "model.ts"), + filename_or_obj=os.path.join(tempdir, "model.ts"), + extra_files={"foo.txt": b"bar"}, verify=True, input_shape=(16, 1, 32, 32), device="cuda" if torch.cuda.is_available() else "cpu", rtol=1e-3, + atol=1e-4, ) self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) diff --git a/tests/utils.py b/tests/utils.py index a12d9f8ea6..38c7dc3f33 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -599,9 +599,10 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): # Convert to device inputs = [i.to(device) for i in inputs] - scripted = convert_to_torchscript(net.cpu(), output_path=None) - buffer = scripted.save_to_buffer() - reloaded_net = torch.jit.load(BytesIO(buffer)).to(device) + buffer = BytesIO() + convert_to_torchscript(net.cpu(), filename_or_obj=buffer, verify=False) + buffer.seek(0) + reloaded_net = torch.jit.load(buffer).to(device) 
net.to(device) if eval_nets: From c1f00d6663bbb080f123bd78bde5815232bbb8c1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 23 Oct 2021 08:47:04 +0800 Subject: [PATCH 4/7] [DLMED] fix CI for old PyTorch Signed-off-by: Nic Ma --- monai/networks/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f054b27992..0e9ad0708c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -21,6 +21,7 @@ import torch.nn as nn from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.module import PT_BEFORE_1_7 __all__ = [ "one_hot", @@ -447,6 +448,7 @@ def convert_to_torchscript( filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a file path name to save the TorchScript model. extra_files: map from filename to contents which will be stored as part of the save model file. + works for PyTorch 1.7 or later. for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. verify: whether to verify the input and output of TorchScript model. if `output_path` is not None, load the saved TorchScript model and verify. @@ -461,7 +463,10 @@ def convert_to_torchscript( with torch.no_grad(): script_module = torch.jit.script(model) if filename_or_obj is not None: - torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) + if PT_BEFORE_1_7: + torch.jit.save(m=script_module, f=filename_or_obj) + else: + torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) if verify: if input_shape is None: From 54300691f06723e19606ea8ecf782ad33a3fd198 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 24 Oct 2021 09:01:49 +0800 Subject: [PATCH 5/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/networks/utils.py | 29 +++++++++----- tests/test_convert_to_torchscript.py | 2 +- tests/utils.py | 59 ++++++---------------------- 3 files changed, 33 insertions(+), 57 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0e9ad0708c..0e8d2b8f79 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -21,6 +21,7 @@ import torch.nn as nn from monai.utils.deprecate_utils import deprecated_arg +from monai.utils.misc import ensure_tuple, set_determinism from monai.utils.module import PT_BEFORE_1_7 __all__ = [ @@ -433,7 +434,7 @@ def convert_to_torchscript( filename_or_obj: Optional[Any] = None, extra_files: Optional[Dict] = None, verify: bool = False, - input_shape: Optional[Sequence[int]] = None, + inputs: Optional[Sequence[Any]] = None, device: Optional[torch.device] = None, rtol: float = 1e-4, atol: float = 0.0, @@ -452,11 +453,11 @@ def convert_to_torchscript( for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. verify: whether to verify the input and output of TorchScript model. if `output_path` is not None, load the saved TorchScript model and verify. - input_shape: shape of the input data to verify model. - device: target device to verify the model. + inputs: input test data to verify model, should be a sequence of data, every item maps to a argument + of `model()` function. + device: target device to verify the model, if None, use CUDA if available. rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. - kwargs: besides input data, other arguments for the call function of model. 
""" model.eval() @@ -469,17 +470,25 @@ def convert_to_torchscript( torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files) if verify: - if input_shape is None: - raise ValueError("missing input_shape argument for verification.") - dummy_input = torch.randn(tuple(input_shape), requires_grad=False).to(device) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if inputs is None: + raise ValueError("missing input data for verification.") + + inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs] ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module ts_model.eval().to(device) model = model.to(device) with torch.no_grad(): - torch_out = model(dummy_input, **kwargs) - torchscript_out = ts_model(dummy_input, **kwargs) + set_determinism(seed=0) + torch_out = ensure_tuple(model(*inputs)) + set_determinism(seed=0) + torchscript_out = ensure_tuple(ts_model(*inputs)) + set_determinism(seed=None) # compare TorchScript and PyTorch results - torch.testing.assert_allclose(torch_out, torchscript_out, rtol=rtol, atol=atol) + for r1, r2 in zip(torch_out, torchscript_out): + if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): + torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol) return script_module diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index b313038f61..a772610a04 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -30,7 +30,7 @@ def test_value(self): filename_or_obj=os.path.join(tempdir, "model.ts"), extra_files={"foo.txt": b"bar"}, verify=True, - input_shape=(16, 1, 32, 32), + inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)], device="cuda" if torch.cuda.is_available() else "cpu", rtol=1e-3, atol=1e-4, diff --git a/tests/utils.py b/tests/utils.py index 38c7dc3f33..19bf9977e6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,7 +22,6 @@ import unittest import warnings from functools import partial -from io import BytesIO from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple from urllib.error import HTTPError, URLError @@ -36,7 +35,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d from monai.networks import convert_to_torchscript -from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils import optional_import from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq from monai.utils.type_conversion import convert_data_type @@ -579,55 +578,23 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): +def test_script_save(net, *inputs, device=None, rtol=1e-4): """ Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is - forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its - reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without - gradient accumulation. + forward-passed through the original and loaded copy of the network and their results returned. + The forward pass for both is done without gradient accumulation. The test will be performed with CUDA if available, else CPU. 
""" - if True: - device = "cpu" - else: - # TODO: It would be nice to be able to use GPU if - # available, but this currently causes CI failures. - if not device: - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Convert to device - inputs = [i.to(device) for i in inputs] - - buffer = BytesIO() - convert_to_torchscript(net.cpu(), filename_or_obj=buffer, verify=False) - buffer.seek(0) - reloaded_net = torch.jit.load(buffer).to(device) - net.to(device) - - if eval_nets: - net.eval() - reloaded_net.eval() - - with torch.no_grad(): - set_determinism(seed=0) - result1 = net(*inputs) - result2 = reloaded_net(*inputs) - set_determinism(seed=None) - - # convert results to tuples if needed to allow iterating over pairs of outputs - result1 = ensure_tuple(result1) - result2 = ensure_tuple(result2) - - for i, (r1, r2) in enumerate(zip(result1, result2)): - if None not in (r1, r2): # might be None - np.testing.assert_allclose( - r1.detach().cpu().numpy(), - r2.detach().cpu().numpy(), - rtol=rtol, - atol=0, - err_msg=f"failed on comparison number: {i}", - ) + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + ) def query_memory(n=2): From 7fd52581512b19360cb5c1059c461409b49ec964 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 24 Oct 2021 09:10:18 +0800 Subject: [PATCH 6/7] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/networks/utils.py | 4 ++-- tests/test_utils_pytorch_numpy_unification.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0e8d2b8f79..c60867d8e6 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -452,7 +452,7 @@ def convert_to_torchscript( works for PyTorch 1.7 or later. for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html. verify: whether to verify the input and output of TorchScript model. - if `output_path` is not None, load the saved TorchScript model and verify. + if `filename_or_obj` is not None, load the saved TorchScript model and verify. inputs: input test data to verify model, should be a sequence of data, every item maps to a argument of `model()` function. device: target device to verify the model, if None, use CUDA if available. 
@@ -471,7 +471,7 @@ def convert_to_torchscript( if verify: if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if inputs is None: raise ValueError("missing input data for verification.") diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index c8e0a35c92..c3b1bc259b 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -15,7 +15,8 @@ import torch from monai.transforms.utils_pytorch_numpy_unification import percentile -from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose class TestPytorchNumpyUnification(unittest.TestCase): From 727b3431798ff985cd03924c4da7583ec1a804cf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 24 Oct 2021 09:56:05 +0800 Subject: [PATCH 7/7] [DLMED] fix CI issue Signed-off-by: Nic Ma --- tests/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 19bf9977e6..97e57c6be6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -578,7 +578,7 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, *inputs, device=None, rtol=1e-4): +def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): """ Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is forward-passed through the original and loaded copy of the network and their results returned. @@ -586,6 +586,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4): The test will be performed with CUDA if available, else CPU. """ + # TODO: would be nice to use GPU if available, but it currently causes CI failures. + device = "cpu" with tempfile.TemporaryDirectory() as tempdir: convert_to_torchscript( model=net, @@ -594,6 +596,7 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4): inputs=inputs, device=device, rtol=rtol, + atol=atol, )
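Below is a minimal usage sketch of the convert_to_torchscript utility as it stands after PATCH 7/7. It mirrors the unit test in tests/test_convert_to_torchscript.py from this series; the UNet configuration, the "foo.txt" extra file, the input shape, and the tolerance values are taken from that test rather than being recommendations.

import os
import tempfile

import torch

from monai.networks import convert_to_torchscript
from monai.networks.nets import UNet

# the same small 2D UNet used by the unit test in this patch series
model = UNet(
    spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)

with tempfile.TemporaryDirectory() as tempdir:
    script_module = convert_to_torchscript(
        model=model,
        filename_or_obj=os.path.join(tempdir, "model.ts"),  # optional: also save the scripted model to disk
        extra_files={"foo.txt": b"bar"},  # stored inside the saved file (used on PyTorch 1.7+ only)
        verify=True,  # reload the saved file and compare its outputs against the eager model
        inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
        device="cuda" if torch.cuda.is_available() else "cpu",
        rtol=1e-3,
        atol=1e-4,
    )
    # the returned object is the scripted module itself and can be used like any nn.Module
    print(isinstance(script_module, torch.nn.Module))  # True

When filename_or_obj is None, no file is written and verification (if requested) runs directly against the in-memory ScriptModule, which is how tests/utils.py uses the helper after PATCH 5/7.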