Merged
132 changes: 71 additions & 61 deletions colossalai/zero/low_level/low_level_optim.py
@@ -786,30 +786,36 @@ def state_dict(
"""
zero_state = dict()
device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
if not only_on_master or get_nd_rank(pg) == 0:
zero_state[param] = copy.deepcopy(state)
else:
zero_state[param] = {}

if pinned_state_dicts is not None and param not in pinned_state_dicts:
pinned_state_dicts[param] = {}

for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
if not only_on_master or get_nd_rank(pg) == 0:
if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
if pinned_state_dicts is not None:
pinned_state_dicts[param][k].copy_(param_state)
zero_state[param][k] = pinned_state_dicts[param][k]
else:
zero_state[param][k] = param_state.cpu()
        for param_group in self.optim.param_groups:
            for param in param_group["params"]:
                if param not in self.optim.state:
                    continue
                state = self.optim.state[param]
                working_param = self.master_to_working_param[id(param)]
                pg = self.param_to_pg[working_param]
                if not only_on_master or get_nd_rank(pg) == 0:
                    zero_state[param] = copy.deepcopy(state)
                else:
                    zero_state[param] = {}

                if pinned_state_dicts is not None and param not in pinned_state_dicts:
                    pinned_state_dicts[param] = {}

                for k, v in state.items():
                    if isinstance(v, torch.Tensor) and k != "step":
                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
                        if not only_on_master or get_nd_rank(pg) == 0:
                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
                                pinned_state_dicts[param][k] = torch.empty_like(
                                    param_state, pin_memory=True, device="cpu"
                                )
                            if pinned_state_dicts is not None:
                                pinned_state_dicts[param][k].copy_(param_state)
                                zero_state[param][k] = pinned_state_dicts[param][k]
                            else:
                                zero_state[param][k] = param_state.cpu()

        states_dict = self._pack_state(zero_state)
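
For context on the new traversal: torch.optim.Optimizer.state is a plain dict keyed by parameter, so iterating it yields entries in whatever order their states were first created, while walking param_groups always visits parameters in registration order; the added "if param not in self.optim.state: continue" guard skips parameters that have no state yet. A minimal standalone sketch of that traversal pattern (plain PyTorch, illustrative model and names only, not ColossalAI code):

    import torch

    # Toy model and optimizer, purely for illustration.
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    optim = torch.optim.Adam(model.parameters())

    # One step so Adam lazily creates per-parameter state (exp_avg, exp_avg_sq, step).
    model(torch.randn(2, 4)).sum().backward()
    optim.step()

    # Visit parameters in param_group order, skipping any without state,
    # mirroring the loop structure used in the patched state_dict().
    collected = {}
    for param_group in optim.param_groups:
        for param in param_group["params"]:
            if param not in optim.state:
                continue
            collected[param] = dict(optim.state[param])

    print(len(collected), "parameters with optimizer state")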

@@ -865,48 +871,52 @@ def state_dict_shard(
        device = get_accelerator().get_current_device()
        local_states = self.optim.state_dict()["state"]

        idx2master = {}
        master2idx = {}
        cnt = 0
        for param_group in self.optim.param_groups:
            for param in param_group["params"]:
                idx2master[cnt] = param
                master2idx[param] = cnt
                cnt += 1
        for param_idx, states in local_states.items():
            current_block_size = 0
            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                pinned_state_dicts[param_idx] = {}
            master_param = idx2master[param_idx]
            working_param = self.master_to_working_param[id(master_param)]
            pg = self.param_to_pg[working_param]
            if not only_on_master or get_nd_rank(pg) == 0:
                current_block = copy.deepcopy(states)
            else:
                current_block = {}

            for k, v in states.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
                    if not only_on_master or get_nd_rank(pg) == 0:
                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                            pinned_state_dicts[param_idx][k] = torch.empty_like(
                                state_tensor, pin_memory=True, device="cpu"
                            )
                        if pinned_state_dicts is not None:
                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
                            current_block[k] = pinned_state_dicts[param_idx][k]
                        else:
                            current_block[k] = state_tensor.cpu()
                    current_block_size += calculate_tensor_size(state_tensor)

            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                yield ret_block, ret_block_size
                ret_block = dict()
                ret_block_size = 0

            ret_block[param_idx] = current_block
            ret_block_size += current_block_size
        for param_group in self.optim.param_groups:
            for master_param in param_group["params"]:
                param_idx = master2idx[master_param]
                states = local_states[param_idx]

                current_block_size = 0
                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                    pinned_state_dicts[param_idx] = {}
                working_param = self.master_to_working_param[id(master_param)]
                pg = self.param_to_pg[working_param]
                if not only_on_master or get_nd_rank(pg) == 0:
                    current_block = copy.deepcopy(states)
                else:
                    current_block = {}

                for k, v in states.items():
                    if isinstance(v, torch.Tensor) and k != "step":
                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
                        if not only_on_master or get_nd_rank(pg) == 0:
                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                                pinned_state_dicts[param_idx][k] = torch.empty_like(
                                    state_tensor, pin_memory=True, device="cpu"
                                )
                            if pinned_state_dicts is not None:
                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
                                current_block[k] = pinned_state_dicts[param_idx][k]
                            else:
                                current_block[k] = state_tensor.cpu()
                        current_block_size += calculate_tensor_size(state_tensor)

                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                    yield ret_block, ret_block_size
                    ret_block = dict()
                    ret_block_size = 0

                ret_block[param_idx] = current_block
                ret_block_size += current_block_size

        yield ret_block, ret_block_size
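
A note on the index bookkeeping above: the integer keys of self.optim.state_dict()["state"] are assigned by enumerating parameters across all param_groups in order, which is exactly the counter rebuilt by the idx2master / master2idx dictionaries; master2idx lets the rewritten loop visit parameters in group order while still looking up the matching entry in the flat local state. A small standalone sketch of that correspondence (plain PyTorch, illustrative names only, not ColossalAI code):

    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(
        [
            {"params": [model.weight]},
            {"params": [model.bias], "lr": 1e-4},
        ]
    )

    model(torch.randn(1, 4)).sum().backward()
    optim.step()

    # Rebuild the param -> flat-index mapping the same way the patch does:
    # state_dict()["state"] keys each parameter by its position across all groups.
    master2idx = {}
    cnt = 0
    for param_group in optim.param_groups:
        for param in param_group["params"]:
            master2idx[param] = cnt
            cnt += 1

    local_states = optim.state_dict()["state"]
    for param_group in optim.param_groups:
        for param in param_group["params"]:
            idx = master2idx[param]
            # For Adam: ['exp_avg', 'exp_avg_sq', 'step']
            print(idx, sorted(local_states[idx].keys()))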
