Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.nn as nn
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils._pytree import tree_map

from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
Expand Down Expand Up @@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.

    Every tensor found in ``state_dict`` is moved to CPU before serialization, and only
    that CPU version is written to disk (exactly one write per call).

    Args:
        state_dict (dict): state dict to save. Non-tensor leaves are passed through unchanged.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Raises:
        AssertionError: if ``use_safetensors`` is True but safetensors is not available, or
            ``checkpoint_file_path`` does not end with ``.safetensors``.
    """
    if use_safetensors:
        # Validate up front so an invalid request fails before any device-to-host copy is made.
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith(
            ".safetensors"
        ), "safetensors only supports .safetensors suffix for checkpoint file."

    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues.
    # tree_map walks nested containers, so tensors at any depth are handled.
    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

    if use_safetensors:
        # Imported lazily so safetensors stays an optional dependency.
        from safetensors.torch import save_file as safe_save_file

        safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict_cpu, checkpoint_file_path)


def save_param_groups(state_dict: dict, group_file_path: str) -> None:
Expand Down