2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -142,7 +142,7 @@ def save_unsharded_optimizer(
from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
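With this change, `save_unsharded_optimizer` hands the whole optimizer `state_dict` to `save_nested`, which now derives the `param_groups` metadata internally instead of requiring the caller to split it out. A minimal sketch of the new calling convention, mirroring the test further down (the output path is illustrative):

import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import save_nested

# Standard torch.optim layout: per-parameter "state" plus "param_groups".
state_dict = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(1024)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}

f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, state_dict)  # no separate metadata argument anymore
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()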
8 changes: 8 additions & 0 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -81,6 +81,14 @@ def __init__(
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
if "step" in state and isinstance(state["step"], torch.Tensor):
state["step"] = int(state["step"].item())

def torch_adam_update(
self,
data,
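The new `load_state_dict` override matters because checkpoints written through the flattened safetensors path store `step` as a 0-d tensor, while the fused CPU Adam kernel consumes a plain Python int. A small round-trip sketch, assuming ColossalAI was built with BUILD_EXT=1 so the C++ extension is available:

import torch
from colossalai.nn.optimizer import CPUAdam

model = torch.nn.Linear(4, 4)
optim = CPUAdam(model.parameters(), lr=1e-3)
model(torch.rand(2, 4)).sum().backward()
optim.step()  # populate per-parameter optimizer state

sd = optim.state_dict()
sd["state"][0]["step"] = torch.tensor(1.0)  # simulate a step loaded as a tensor
optim.load_state_dict(sd)  # the override casts it back to int
weight = optim.param_groups[0]["params"][0]
assert isinstance(optim.state[weight]["step"], int)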
6 changes: 2 additions & 4 deletions colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict, Tuple
from typing import Any, List, OrderedDict

import torch
import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
v2 = tuple(v2)
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
assert v1 == v2, f"{v1} not equals to {v2}"


def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
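The dropped tuple coercion appears to have compensated for the old flatten path, which cast non-tensor values to tensors and could hand back a different container type on load. With the pickle-based round trip introduced in safetensors.py below, values come back as their original Python objects, so a plain equality assert suffices. A tiny illustration of that type-preserving round trip (`_pickler`/`_unpickler` are just `pickle.Pickler`/`pickle.Unpickler`):

import io
import pickle

groups = [{"lr": 0.001, "betas": (0.9, 0.999), "params": [0, 1, 2]}]
buf = io.BytesIO()
pickle.Pickler(buf).dump(groups)
restored = pickle.Unpickler(io.BytesIO(buf.getvalue())).load()
assert restored == groups
assert isinstance(restored[0]["betas"], tuple)  # container types survive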
141 changes: 86 additions & 55 deletions colossalai/utils/safetensors.py
@@ -1,6 +1,5 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
import warnings
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple

@@ -12,6 +11,26 @@
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()}
import io

from torch.distributed.distributed_c10d import _pickler, _unpickler


def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with `torch.tensor` and a specified dtype.
# Otherwise, it will cause a 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage).to(device)
return byte_tensor


def _tensor_to_object(tensor, tensor_size):
tensor = tensor.cpu()
buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()


@dataclass
@@ -28,49 +47,68 @@ class PreparedData:
offset: int


def flatten_dict(nested_dict, parent_key="", separator="^"):
"""
Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.

nested_dict: The input nested dictionary.
parent_key: The parent key currently being processed.
separator: The separator used to join keys; default is "^".
:return: A flattened dictionary.
"""
"""
items = []
for k, v in nested_dict.items():
new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, separator).items())
else:
v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
items.append((new_key, v))

return dict(items)


def unflatten_dict(flattened_dict, separator="^"):
"""
Restore a flattened dictionary back to a multi-level nested dictionary.

flattened_dict: The flattened dictionary.
separator: The separator used during flattening; default is "^".
:return: The restored nested dictionary.
"""
"""
nested_dict = {}
for key, value in flattened_dict.items():
keys = key.split(separator)
try:
keys[0] = int(keys[0])
except ValueError:
warnings.warn(f"{keys[0]} can't be converted to an integer")
d = nested_dict
for part in keys[:-1]:
if part not in d:
d[part] = {}
d = d[part]
assert isinstance(value, torch.Tensor)
d[keys[-1]] = value

return nested_dict
def _cast_to_tensor(obj):
if isinstance(obj, torch.Tensor):
return obj
return _object_to_tensor(obj, "cpu")


def _cast_to_object(tensor: torch.Tensor):
return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())


def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
flat_dict = {}
non_tensor_keys = []
if "state" in state_dict:
# 3-level dict
states = state_dict["state"]
else:
# 2-level dict, usually for optimizer state dict shard
states = state_dict

for idx, d in states.items():
for k, v in d.items():
nested_key = f"state{seperator}{idx}{seperator}{k}"
if not isinstance(v, torch.Tensor):
non_tensor_keys.append(nested_key)
flat_dict[nested_key] = _cast_to_tensor(v)
if "param_groups" in state_dict:
flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
non_tensor_keys.append("param_groups")
if len(non_tensor_keys) > 0:
metadata = {"non_tensor_keys": non_tensor_keys}
else:
metadata = None
return flat_dict, metadata


def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
state_dict = {}
if metadata is not None:
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
else:
non_tensor_keys = []
flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
if "param_groups" in flat_dict:
# 3-level dict
state_dict["param_groups"] = flat_dict.pop("param_groups")
state_dict["state"] = {}
states = state_dict["state"]
else:
# 2-level dict, usually for optimizer state dict shard
states = state_dict

for k, v in flat_dict.items():
parts = k.split(seperator)
assert len(parts) == 3 and parts[0] == "state"
idx = int(parts[1])
key = parts[2]
if idx not in states:
states[idx] = {}
states[idx][key] = v

return state_dict


def prepare(
@@ -124,10 +162,8 @@ def save(
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


def save_nested(
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
flatten_data = flatten_dict(state_dict)
def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
save(f_writer, flatten_data, metadata)


@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
state_dict = unflatten_dict(state_dict_load)
if metadata is None:
return state_dict
metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
combined_state_dict = {"state": state_dict}
combined_state_dict.update(metadata)
return combined_state_dict
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
return state_dict
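Taken together, `_flatten_optim_state_dict` and `_unflatten_optim_state_dict` fix the on-disk layout: each state entry becomes a flat key `state.<param_idx>.<name>`, non-tensor values are pickled into byte tensors by `_cast_to_tensor`, and the keys that must be unpickled are recorded in the safetensors metadata (stored as JSON strings, hence the `json.loads` on the load side). A minimal in-memory round trip using these private helpers directly (illustrative only; `save_nested`/`load_flat` are the public path):

import json
import torch
from colossalai.utils.safetensors import _flatten_optim_state_dict, _unflatten_optim_state_dict

state_dict = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(8)}},
    "param_groups": [{"lr": 0.001, "params": [0]}],
}
flat, metadata = _flatten_optim_state_dict(state_dict)
# flat keys: "state.0.step", "state.0.exp_avg", "param_groups" (a pickled byte tensor)
# metadata: {"non_tensor_keys": ["param_groups"]}

# Mimic the JSON encoding the save path applies to metadata values:
encoded = {"non_tensor_keys": json.dumps(metadata["non_tensor_keys"])}
restored = _unflatten_optim_state_dict(flat, encoded)
assert restored["param_groups"] == state_dict["param_groups"]
assert torch.equal(restored["state"][0]["exp_avg"], state_dict["state"][0]["exp_avg"])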
64 changes: 45 additions & 19 deletions tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -1,27 +1,39 @@
import tempfile
from copy import deepcopy

import torch
from safetensors.torch import load_file

from colossalai.utils.safetensors import load_flat, save_nested
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

try:
from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device


def test_save_load():
with tempfile.TemporaryDirectory() as tempdir:
optimizer_state_dict = {
0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
}
# group_dict = {"param_groups": [0, 1, 2]}
group_dict = {
"state": {
0: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
1: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
2: {
"step": torch.tensor(1.0),
"exp_avg": torch.rand((1024, 1024)),
"exp_avg_sq": torch.rand((1024, 1024)),
},
},
"param_groups": [
{
"lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
61,
],
}
]
],
}
metadata = deepcopy(group_dict)

optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")

save_nested(f_writer, optimizer_state_dict, metadata)
save_nested(f_writer, optimizer_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()

load_state_dict = load_flat(optimizer_saved_path)
state_dict = load_state_dict["state"]
group = {"param_groups": load_state_dict["param_groups"]}
check_state_dict_equal(optimizer_state_dict, state_dict)
check_state_dict_equal(group_dict, group)
check_state_dict_equal(load_state_dict, optimizer_state_dict)

optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict["state"])
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

model_state_dict = {
"module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@
}
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, model_state_dict)
save(f_writer, model_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)

load_state_dict = load_flat(model_saved_path)
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)
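The last block exercises `move_and_save`, which (as used here) writes a device-resident state dict by staging tensors through caller-provided pinned CPU buffers before the asynchronous write. A condensed sketch of that pattern, assuming a CUDA device and the same (writer, state_dict, pinned_state_dict) call shape as in the test:

import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import move_and_save

state = {"module.weight0": torch.rand(1024, 1024, device="cuda")}
pinned = {k: torch.empty_like(v, device="cpu").pin_memory() for k, v in state.items()}

f_writer = AsyncFileWriter(fp=open("save_model_cuda.safetensors", "wb"), n_entries=191, backend="pthread")
move_and_save(f_writer, state, pinned)  # D2H copy into pinned memory, then async write
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()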