diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 2a24c2920466..5087d71a3d62 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -48,7 +48,8 @@ def move(tensor, device): # to save host resources when DP > 1。 if tensor.is_meta: - return torch.empty_like(tensor, device=device) + # Keep tensor in meta device if tensor is meta. + return tensor else: # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). # Using copy=True instead of clone() will help in case of cpu --> cpu.