diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6197be9d1c8d..20870a3c23a1 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors use_safetensors (bool): whether to use safetensors to save the checkpoint. """ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues - state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) + state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict) if use_safetensors: assert is_safetensors_available(), "safetensors is not available."