@@ -854,7 +854,6 @@ class AutoDeviceAdapt:
854854 calling ``mod(inputs, ...)`` with heterogeneous ``inputs`` on CUDA.
855855 Since the error handling is expensive, if the OOM happens frequently, ``mod`` should be adjusted instead of
856856 relying on this class.
857-
858857 """
859858
860859 def __init__ (self , mod ):
@@ -890,13 +889,16 @@ def __call__(self, inputs, *args, **kwargs):
890889 f" please review { type (self .mod )} settings if this message appears frequently."
891890 )
892891 if not self .has_cpu_thresh : # mod doesn't support dynamic input size based device allocation
892+ if isinstance (inputs , torch .Tensor ):
893+ inputs = inputs .to ("cpu" )
893894 if self .has_device : # run this call with mod.device="cpu"
894895 ori_device = self .mod .device
895896 self .mod .device = "cpu"
896- out = self .mod (inputs . to ( "cpu" ) , * args , ** kwargs )
897+ out = self .mod (inputs , * args , ** kwargs )
897898 self .mod .device = ori_device
898899 return out
899- if isinstance (inputs , torch .Tensor ):
900- return self .mod (inputs .to ("cpu" ), * args , ** kwargs ) # move inputs to cpu
901- self .mod .cpu_thresh = inputs .shape [2 :].numel () - 1 # try change the dynamic threshold before the call
902- return self .mod (inputs , * args , ** kwargs )
900+ return self .mod (inputs , * args , ** kwargs )
901+ if isinstance (inputs , torch .Tensor ):
902+ self .mod .cpu_thresh = inputs .shape [2 :].numel () - 1 # try change the dynamic threshold before the call
903+ return self .mod (inputs , * args , ** kwargs )
904+ raise e
0 commit comments