From 177e84208203a7d70b2270064b15dda710f09472 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 13 Nov 2024 02:45:40 +0000 Subject: [PATCH 1/2] [utils] sync save_state_dict from tensornvme --- colossalai/utils/safetensors.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 9aa3558d9926..ec932da2434a 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -1,7 +1,7 @@ # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 import json from dataclasses import asdict, dataclass -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import torch from safetensors.torch import _TYPES @@ -27,10 +27,10 @@ class PreparedData: offset: int -def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]: +def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[str]]: sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) - tensors = [] + tensor_keys = [] metadata = {} offset = 0 @@ -41,7 +41,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten ) offset += n metadata[name] = asdict(tensor_info) - tensors.append(tensor) + tensor_keys.append(name) metadata_buf = json.dumps(metadata).encode("utf-8") @@ -50,15 +50,18 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten n = len(metadata_buf) - return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors + return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensor_keys - -def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None: - prepared_data, tensors = prepare(state_dict) +def save_state_dict(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None) -> None: + prepared_data, tensor_keys = prepare(state_dict) n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset f_writer.write(n.to_bytes(8, byteorder="little")) f_writer.write(header_bytes) - for tensor in tensors: - f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset) + f_writer.register_h2d(len(tensor_keys)) + for name in tensor_keys: + if state_dict_pinned: + f_writer.write_tensor(state_dict[name], state_dict_pinned[name]) + else: + f_writer.write_tensor(state_dict[name]) From 555f5f833afc1d3ec340e8afa34b2101b792b2d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 02:48:40 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/utils/safetensors.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index ec932da2434a..e3c70a5b772a 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -1,7 +1,7 @@ # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 import json from dataclasses import asdict, dataclass -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple import torch from safetensors.torch import _TYPES @@ -52,7 +52,12 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[str]]: return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensor_keys -def save_state_dict(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None) -> None: + +def save_state_dict( + f_writer: AsyncFileWriter, + state_dict: Dict[str, torch.Tensor], + state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None, +) -> None: prepared_data, tensor_keys = prepare(state_dict) n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset