Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def _onnx_trt_compile(
output_names = [] if not output_names else output_names

# set up the TensorRT builder
torch_tensorrt.set_device(device)
torch.cuda.set_device(device)
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
Expand Down Expand Up @@ -931,7 +931,7 @@ def convert_to_trt(
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")

device = device if device else 0
target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
target_device = torch.device(f"cuda:{device}")
convert_precision = torch.float32 if precision == "fp32" else torch.half
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]

Expand Down Expand Up @@ -986,7 +986,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
ir_model,
inputs=input_placeholder,
enabled_precisions=convert_precision,
device=target_device,
device=torch_tensorrt.Device(f"cuda:{device}"),
ir="torchscript",
**kwargs,
)
Expand Down
1 change: 0 additions & 1 deletion monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def load_submodules(
loader = mod_spec.loader
loader.exec_module(mod)
submodules.append(mod)

except OptionalImportError:
pass # could not import the optional deps., they are ignored
except ImportError as e:
Expand Down