Commit ad83119

update
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent b35675a commit ad83119

1 file changed: +8 -6 lines changed
monai/networks/utils.py

Lines changed: 8 additions & 6 deletions
@@ -854,7 +854,6 @@ class AutoDeviceAdapt:
     calling ``mod(inputs, ...)`` with heterogeneous ``inputs`` on CUDA.
     Since the error handling is expensive, if the OOM happens frequently, ``mod`` should be adjusted instead of
     relying on this class.
-
     """
 
     def __init__(self, mod):
@@ -890,13 +889,16 @@ def __call__(self, inputs, *args, **kwargs):
                 f" please review {type(self.mod)} settings if this message appears frequently."
             )
             if not self.has_cpu_thresh:  # mod doesn't support dynamic input size based device allocation
+                if isinstance(inputs, torch.Tensor):
+                    inputs = inputs.to("cpu")
                 if self.has_device:  # run this call with mod.device="cpu"
                     ori_device = self.mod.device
                     self.mod.device = "cpu"
-                    out = self.mod(inputs.to("cpu"), *args, **kwargs)
+                    out = self.mod(inputs, *args, **kwargs)
                     self.mod.device = ori_device
                     return out
-                if isinstance(inputs, torch.Tensor):
-                    return self.mod(inputs.to("cpu"), *args, **kwargs)  # move inputs to cpu
-                self.mod.cpu_thresh = inputs.shape[2:].numel() - 1  # try change the dynamic threshold before the call
-                return self.mod(inputs, *args, **kwargs)
+                return self.mod(inputs, *args, **kwargs)
+            if isinstance(inputs, torch.Tensor):
+                self.mod.cpu_thresh = inputs.shape[2:].numel() - 1  # try change the dynamic threshold before the call
+                return self.mod(inputs, *args, **kwargs)
+            raise e
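
For context, the retry order that the revised ``__call__`` establishes can be sketched as a standalone helper: run the module once, and on a CUDA out-of-memory error either move tensor inputs to CPU (when the wrapped module has no dynamic ``cpu_thresh``) or lower the threshold before calling again, re-raising anything that cannot be handled. This is only an illustrative sketch: the helper name ``call_with_cpu_fallback`` and the ``hasattr`` probes below are assumptions, whereas the class in the diff keeps the equivalent flags on itself as ``self.has_device`` and ``self.has_cpu_thresh``.

import warnings

import torch


def call_with_cpu_fallback(mod, inputs, *args, **kwargs):
    # Illustrative sketch only; not the MONAI API. AutoDeviceAdapt performs the
    # equivalent steps inside __call__ using its own has_device / has_cpu_thresh flags.
    try:
        return mod(inputs, *args, **kwargs)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise  # only handle CUDA OOM, re-raise everything else
        warnings.warn(f"CUDA out of memory, retrying on CPU; review {type(mod)} settings if this appears frequently.")
        if not hasattr(mod, "cpu_thresh"):  # mod doesn't support dynamic input size based device allocation
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.to("cpu")  # move tensor inputs to CPU before the retry
            if hasattr(mod, "device"):  # run this call with mod.device="cpu"
                ori_device = mod.device
                mod.device = "cpu"
                out = mod(inputs, *args, **kwargs)
                mod.device = ori_device
                return out
            return mod(inputs, *args, **kwargs)
        if isinstance(inputs, torch.Tensor):
            mod.cpu_thresh = inputs.shape[2:].numel() - 1  # lower the dynamic threshold before the retry
            return mod(inputs, *args, **kwargs)
        raise e

In the class itself this logic runs after a failed CUDA attempt inside ``__call__``, so a caller only needs ``AutoDeviceAdapt(mod)(inputs)``, matching the ``__init__(self, mod)`` and ``__call__(self, inputs, *args, **kwargs)`` signatures shown in the diff.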
