Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,10 @@ def save_unsharded_optimizer(
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
if self.coordinator.is_master():
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict)
f_writer = save_nested(checkpoint, state_dict)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
Expand Down Expand Up @@ -222,16 +220,10 @@ def save_sharded_optimizer(
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
checkpoint_file_path,
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
)
save_nested(f_writer, shard)
f_writer = save_nested(checkpoint_file_path, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
Expand Down
2 changes: 0 additions & 2 deletions colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class CheckpointIO(ABC):
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
"""

N_WRITE_ENTRIES: int = 32

# ======================================
# Public methods
# ======================================
Expand Down
5 changes: 1 addition & 4 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,11 @@ def save_unsharded_model(
pass

if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])

else:
# save the checkpoint
Expand Down Expand Up @@ -196,7 +194,6 @@ def save_sharded_model(
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
Expand Down
4 changes: 1 addition & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,15 +686,13 @@ def save_unsharded_model(
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

Expand Down
8 changes: 2 additions & 6 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def async_save_state_dict_shards(
base_filename: str,
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
n_write_entries: int,
use_pp_format: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Expand All @@ -290,7 +289,6 @@ def async_save_state_dict_shards(
Returns:
int: the total size of shards
"""
from tensornvme.async_file_io import AsyncFileWriter

total_size = 0
shard_filenames = []
Expand All @@ -311,17 +309,15 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)

writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
writers.append(writer)

if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
returned_state_dict.update(sub_pinned_state_dict)

# Only save on master rank.
move_and_save(writer, shard, sub_pinned_state_dict)
writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
writers.append(writer)
shard_filenames.append(shard_file)
del shard

Expand Down
18 changes: 10 additions & 8 deletions colossalai/utils/safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from torch.distributed.distributed_c10d import _pickler, _unpickler

ASYNC_WRITE_ENTRIES = 32


def _object_to_tensor(obj, device):
f = io.BytesIO()
Expand Down Expand Up @@ -149,32 +151,31 @@ def prepare(
return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


def save(
f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> None:
def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
prepared_data, tensors, _ = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)

for tensor in tensors:
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
return f_writer


def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
flatten_data, metadata = _flatten_optim_state_dict(state_dict)
save(f_writer, flatten_data, metadata)
return save(path, flatten_data, metadata)


def move_and_save(
f_writer: AsyncFileWriter,
path: str,
state_dict: Dict[str, torch.Tensor],
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)

Expand All @@ -184,6 +185,7 @@ def move_and_save(
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
else:
f_writer.write_tensor(state_dict[name])
return f_writer


def load_flat(checkpoint_path):
Expand Down
10 changes: 7 additions & 3 deletions colossalai/zero/low_level/bookkeeping/tensor_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def all_gather(self, group=None, fp8_communication: bool = False):
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
write_back_tensor = self._write_back_pairs[tensor]
write_back_tensor.data.copy_(
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
)
rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
if write_back_tensor.is_contiguous():
rec_tensor = rec_tensor.view_as(write_back_tensor)
else:
rec_tensor = rec_tensor.reshape_as(write_back_tensor)
write_back_tensor.data.copy_(rec_tensor)

self.empty()
24 changes: 7 additions & 17 deletions tests/test_checkpoint_io/test_safetensors_async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@
import torch
from safetensors.torch import load_file

from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

try:
from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")


from colossalai.testing import check_state_dict_equal
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested


@clear_cache_before_run()
def test_save_load():
with tempfile.TemporaryDirectory() as tempdir:
optimizer_state_dict = {
Expand Down Expand Up @@ -111,17 +105,15 @@ def test_save_load():
}

optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict)
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer
load_state_dict = load_flat(optimizer_saved_path)
check_state_dict_equal(load_state_dict, optimizer_state_dict)

optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict["state"])
f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer
Expand All @@ -134,8 +126,7 @@ def test_save_load():
"module.weight2": torch.rand((1024, 1024)),
}
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
save(f_writer, model_state_dict)
f_writer = save(model_saved_path, model_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer
Expand All @@ -145,8 +136,7 @@ def test_save_load():
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer
Expand Down
4 changes: 3 additions & 1 deletion tests/test_optimizer/test_dist_lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from colossalai.nn.optimizer import DistributedLamb, Lamb
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
Expand Down Expand Up @@ -108,6 +108,7 @@ def set_dist_grad(
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
@clear_cache_before_run()
def run_dist_lamb_basic(
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
Expand Down Expand Up @@ -177,6 +178,7 @@ def run_dist_lamb_basic(
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
@clear_cache_before_run()
def run_dist_lamb_fwd_bwd(
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
Expand Down