Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions colossalai/utils/safetensors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from safetensors.torch import _TYPES
Expand All @@ -27,10 +27,10 @@ class PreparedData:
offset: int


def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[str]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))

tensors = []
tensor_keys = []
metadata = {}
offset = 0

Expand All @@ -41,7 +41,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
)
offset += n
metadata[name] = asdict(tensor_info)
tensors.append(tensor)
tensor_keys.append(name)

metadata_buf = json.dumps(metadata).encode("utf-8")

Expand All @@ -50,15 +50,23 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten

n = len(metadata_buf)

return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensor_keys


def save_state_dict(
    f_writer: AsyncFileWriter,
    state_dict: Dict[str, torch.Tensor],
    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
    """Asynchronously serialize ``state_dict`` in safetensors layout through ``f_writer``.

    Args:
        f_writer: Async writer that receives the raw bytes; must provide
            ``write``, ``register_h2d`` and ``write_tensor``.
        state_dict: Mapping of parameter names to tensors to serialize.
        state_dict_pinned: Optional mapping of the same names to pinned
            host buffers; when given (and non-empty), each tensor is written
            through its pinned counterpart.
            NOTE(review): a non-None but *empty* dict is treated like None
            (truthiness check) — confirm this is intended.

    Returns:
        None. All output goes through ``f_writer`` as a side effect.
    """
    prepared_data, tensor_keys = prepare(state_dict)
    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

    # safetensors header: 8-byte little-endian header length, then the JSON metadata.
    f_writer.write(n.to_bytes(8, byteorder="little"))
    f_writer.write(header_bytes)

    # Announce how many tensor writes will follow (presumably host-to-device
    # bookkeeping — "h2d" — in AsyncFileWriter; verify against its implementation).
    f_writer.register_h2d(len(tensor_keys))
    # Hoist the loop-invariant pinned-buffer check out of the loop; semantics
    # (truthiness of state_dict_pinned) are unchanged.
    if state_dict_pinned:
        for name in tensor_keys:
            f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
    else:
        for name in tensor_keys:
            f_writer.write_tensor(state_dict[name])