31 changes: 25 additions & 6 deletions deepspeed/utils/zero_to_fp32.py
@@ -17,6 +17,8 @@
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.

debug = 0


def get_optim_files(checkpoint_dir):

@@ -41,6 +43,8 @@ def parse_optim_states(files):
if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
raise ValueError(f"non zero checkpoint")
zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
param_shapes = state_dicts[0]["param_shapes"]

# the groups are named differently in each stage
if zero_stage == 2:
@@ -50,12 +54,15 @@ def parse_optim_states(files):
else:
raise ValueError(f"unknown zero stage {zero_stage}")

param_shapes = state_dicts[0]["param_shapes"]
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [
state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0]
for i in range(len(state_dicts))
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
0) for i in range(len(state_dicts))
Comment on lines +63 to +64

@stas00 (Collaborator, Author) · Apr 29, 2021
This is the only functional change in this PR. Instead of using just the first element, it now uses them all.

Contributor
This seems fine for now. I agree we have to revisit this, especially for very large models, where it could cause CPU OOM.
]
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]

return zero_stage, world_size, param_shapes, fp32_flat_groups
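
A minimal sketch of the functional change discussed in the thread above, with invented group sizes (the real tensors come out of the pickled optimizer states):

    import torch

    # Pretend one rank's optimizer state has two param groups, each flattened
    # into its own fp32 tensor (sizes are made up for illustration).
    per_group_flat = [torch.zeros(10), torch.ones(4)]

    # Before this PR: only the first group's tensor was taken.
    old = per_group_flat[0]             # numel 10 -- the second group is dropped

    # After this PR: every group is concatenated into a single flat tensor.
    new = torch.cat(per_group_flat, 0)  # numel 14 -- nothing is lost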

@@ -93,6 +100,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
# - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating
# each param, while dealing with padding if any

if debug:
for i in range(world_size):
print(f"fp32_flat_groups[i].shape={fp32_flat_groups[i].shape}")

if zero_stage == 2:
# XXX: memory usage doubles here (zero2)
full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)
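
The comment above mentions that zero3 partitions are padded; later in the diff the script calls zero3_partitioned_param_info for this. Its body isn't shown here, so the following is only a plausible sketch consistent with how it's used, not the actual implementation:

    import math

    def zero3_partitioned_param_info(unpartitioned_numel, world_size):
        # Assumed behavior: each rank holds an equal-sized slice, with the
        # tail padded so partitioned_numel * world_size covers the param.
        remainder = unpartitioned_numel % world_size
        partitioned_padding_numel = (world_size - remainder) if remainder else 0
        partitioned_numel = math.ceil(unpartitioned_numel / world_size)
        return partitioned_numel, partitioned_padding_numel

    # e.g. 10 elements across 4 ranks -> 3 per rank, 2 elements of padding
    assert zero3_partitioned_param_info(10, 4) == (3, 2)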
@@ -107,7 +118,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
total_numel += unpartitioned_numel

if zero_stage == 2:
# print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
if debug:
print(
f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
)
state_dict[name] = full_single_fp32_vector.narrow(
0,
offset,
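
The narrow() call above is cut off by the collapsed region; here is a condensed sketch of the full zero2 recovery loop with invented shapes, showing how narrow() slices each param out of the merged flat vector at a running offset:

    import torch

    full_single_fp32_vector = torch.arange(10, dtype=torch.float32)
    param_shapes = {"w": torch.Size([2, 3]), "b": torch.Size([4])}  # made-up params

    state_dict, offset = {}, 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        # narrow() takes a zero-copy view of [offset, offset + numel);
        # view() restores the param's original shape.
        state_dict[name] = full_single_fp32_vector.narrow(
            0, offset, unpartitioned_numel).view(shape)
        offset += unpartitioned_numel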
@@ -116,7 +130,12 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):

elif zero_stage == 3:
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
# print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}")

if debug:
print(
f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here (zero3)
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0,
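
The zero3 branch above is truncated by the collapsed region; based on the visible narrow/cat calls and the comments, a hedged sketch of the reconsolidation for a single param, with invented sizes:

    import torch

    world_size = 2
    unpartitioned_numel = 5    # true numel of this param
    partitioned_numel = 3      # per-rank slice (last rank padded by 1)
    offset = 0                 # running offset into each rank's flat tensor

    # Invented per-rank flat fp32 tensors.
    fp32_flat_groups = [torch.arange(3.0), torch.arange(3.0, 6.0)]

    # Zip the partitions back together at this param's boundary, then trim
    # the padding and restore the original shape.
    full = torch.cat(tuple(
        fp32_flat_groups[i].narrow(0, offset, partitioned_numel)
        for i in range(world_size)), 0)
    param = full.narrow(0, 0, unpartitioned_numel).view(unpartitioned_numel)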