From c42394497b35cd273882e1c802d83db7d29ef950 Mon Sep 17 00:00:00 2001 From: imgaojun Date: Tue, 7 Nov 2023 16:39:17 +0800 Subject: [PATCH 1/2] Fix serialization error with Tensor Parallel state saving --- colossalai/checkpoint_io/utils.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 06dab1fdb72a..d833a9258489 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -293,6 +293,22 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Helper functions for saving state dict # ====================================== +def move_to_cpu(obj): + """ + Recursively move tensors to CPU to avoid serialization issues with CUDA tensors. + """ + if torch.is_tensor(obj): + return obj.cpu() + elif isinstance(obj, dict): + return {k: move_to_cpu(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [move_to_cpu(o) for o in obj] + elif isinstance(obj, tuple): + return tuple(move_to_cpu(o) for o in obj) + elif isinstance(obj, OrderedDict): + return OrderedDict((k, move_to_cpu(v)) for k, v in obj.items()) + else: + return obj def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: """ @@ -303,6 +319,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors checkpoint_file_path (str): path to the checkpoint file. use_safetensors (bool): whether to use safetensors to save the checkpoint. """ + # Move all tensors in the state_dict to CPU before saving to avoid serialization issues + state_dict_cpu = move_to_cpu(state_dict) + if use_safetensors: assert is_safetensors_available(), "safetensors is not available." assert checkpoint_file_path.endswith( @@ -310,9 +329,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors ), "safetensors only supports .safetensors suffix for checkpoint file." from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"}) else: - torch.save(state_dict, checkpoint_file_path) + torch.save(state_dict_cpu, checkpoint_file_path) def save_param_groups(state_dict: dict, group_file_path: str) -> None: From 0fccdc8acb4b1aee9b26b3f52153edc855350d7b Mon Sep 17 00:00:00 2001 From: imgaojun Date: Wed, 8 Nov 2023 16:02:35 +0800 Subject: [PATCH 2/2] Refactor state_dict CPU transfer using tree_map --- colossalai/checkpoint_io/utils.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d833a9258489..e1800f29b0af 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -11,6 +11,7 @@ import torch.nn as nn from packaging.version import Version from torch.optim import Optimizer +from torch.utils._pytree import tree_map from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, @@ -293,23 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Helper functions for saving state dict # ====================================== -def move_to_cpu(obj): - """ - Recursively move tensors to CPU to avoid serialization issues with CUDA tensors. - """ - if torch.is_tensor(obj): - return obj.cpu() - elif isinstance(obj, dict): - return {k: move_to_cpu(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [move_to_cpu(o) for o in obj] - elif isinstance(obj, tuple): - return tuple(move_to_cpu(o) for o in obj) - elif isinstance(obj, OrderedDict): - return OrderedDict((k, move_to_cpu(v)) for k, v in obj.items()) - else: - return obj - def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: """ Save state dict to checkpoint. @@ -320,7 +304,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors use_safetensors (bool): whether to use safetensors to save the checkpoint. """ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues - state_dict_cpu = move_to_cpu(state_dict) + state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) if use_safetensors: assert is_safetensors_available(), "safetensors is not available."