Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,20 @@ def save_unsharded_optimizer(
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
if use_async:
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
if self.coordinator.is_master():
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
Expand Down Expand Up @@ -192,13 +190,15 @@ def save_sharded_optimizer(
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
if use_async:
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
sharded_state = optimizer.state_dict_shard(
max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
Expand Down Expand Up @@ -227,7 +227,7 @@ def save_sharded_optimizer(
from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb", buffering=0),
checkpoint_file_path,
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
)
Expand Down
1 change: 0 additions & 1 deletion colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(self):
def _sync_io(self):
for writer in self.async_writers:
writer.synchronize()
writer.fp.close()
self.async_writers.clear()

def _sync_d2h(self):
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def save_unsharded_model(
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
4 changes: 1 addition & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,7 @@ def save_unsharded_model(

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)

writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
writers.append(writer)

if pinned_state_dict is not None:
Expand Down
57 changes: 37 additions & 20 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,9 @@ def pack_group(group):

return {"state": packed_state, "param_groups": param_groups}

def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
def state_dict(
self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
) -> Dict:
"""Return a state_dict same with DDP

Returns:
Expand All @@ -785,23 +787,29 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
zero_state = dict()
device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
if not only_on_master or get_nd_rank(pg) == 0:
zero_state[param] = copy.deepcopy(state)
else:
zero_state[param] = {}

if pinned_state_dicts is not None and param not in pinned_state_dicts:
pinned_state_dicts[param] = {}
zero_state[param] = copy.deepcopy(state)

for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()
if not only_on_master or get_nd_rank(pg) == 0:
if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()

states_dict = self._pack_state(zero_state)

Expand Down Expand Up @@ -837,7 +845,10 @@ def load_state_dict(self, state_dict: Dict):
self.optim.load_state_dict(zero_state_dict)

def state_dict_shard(
self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
self,
max_shard_size: int = 1024,
pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
only_on_master: bool = False,
) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Only include the 'state' in state_dict.
Expand All @@ -862,25 +873,31 @@ def state_dict_shard(
cnt += 1
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
pinned_state_dicts[param_idx] = {}
master_param = idx2master[param_idx]
working_param = self.master_to_working_param[id(master_param)]
pg = self.param_to_pg[working_param]
if not only_on_master or get_nd_rank(pg) == 0:
current_block = copy.deepcopy(states)
else:
current_block = {}

for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param_idx][k].copy_(state_tensor)
current_block[k] = pinned_state_dicts[param_idx][k]
else:
current_block[k] = state_tensor.cpu()
if not only_on_master or get_nd_rank(pg) == 0:
if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
pinned_state_dicts[param_idx][k] = torch.empty_like(
state_tensor, pin_memory=True, device="cpu"
)
if pinned_state_dicts is not None:
pinned_state_dicts[param_idx][k].copy_(state_tensor)
current_block[k] = pinned_state_dicts[param_idx][k]
else:
current_block[k] = state_tensor.cpu()
current_block_size += calculate_tensor_size(state_tensor)

if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
Expand Down
17 changes: 9 additions & 8 deletions tests/test_checkpoint_io/test_safetensors_async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")


from colossalai.testing import check_state_dict_equal
from colossalai.utils import get_current_device

Expand Down Expand Up @@ -110,20 +111,20 @@ def test_save_load():
}

optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_flat(optimizer_saved_path)
check_state_dict_equal(load_state_dict, optimizer_state_dict)

optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict["state"])
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

Expand All @@ -133,21 +134,21 @@ def test_save_load():
"module.weight2": torch.rand((1024, 1024)),
}
model_saved_path = f"{tempdir}/save_model.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
save(f_writer, model_state_dict)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)

model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()
del f_writer
load_state_dict = load_file(model_saved_path)
check_state_dict_equal(model_state_dict, load_state_dict)