diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f301c2dd5c..bd65ffa33e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -851,7 +851,7 @@ def _onnx_trt_compile( # wrap the serialized TensorRT engine back to a TorchScript module. trt_model = torch_tensorrt.ts.embed_engine_in_new_module( f.getvalue(), - device=torch.device(f"cuda:{device}"), + device=torch_tensorrt.Device(f"cuda:{device}"), input_binding_names=input_names, output_binding_names=output_names, )