From 88ddb4924a0b3bcf13e8d6fa175b34dfed15ad91 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 28 Apr 2021 20:07:21 -0700 Subject: [PATCH 1/2] support param groups --- deepspeed/utils/zero_to_fp32.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 3401fd635e7c..59d2420261f7 100644 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -17,6 +17,8 @@ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # DeepSpeed data structures it has to be available in the current python environment. +debug = 0 + def get_optim_files(checkpoint_dir): @@ -41,6 +43,8 @@ def parse_optim_states(files): if not "zero_stage" in state_dicts[0]['optimizer_state_dict']: raise ValueError(f"non zero checkpoint") zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] + world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] + param_shapes = state_dicts[0]["param_shapes"] # the groups are named differently in each stage if zero_stage == 2: @@ -50,12 +54,15 @@ def parse_optim_states(files): else: raise ValueError(f"unknown zero stage {zero_stage}") - param_shapes = state_dicts[0]["param_shapes"] + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor fp32_flat_groups = [ - state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0] + torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], 0) for i in range(len(state_dicts)) ] - world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] return zero_stage, world_size, param_shapes, fp32_flat_groups @@ -93,6 +100,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating # each param, while dealing with padding if any + if debug: + for i in range(world_size): + print(f"fp32_flat_groups[i].shape={fp32_flat_groups[i].shape}") + if zero_stage == 2: # XXX: memory usage doubles here (zero2) full_single_fp32_vector = torch.cat(fp32_flat_groups, 0) @@ -107,7 +118,8 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): total_numel += unpartitioned_numel if zero_stage == 2: - # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") state_dict[name] = full_single_fp32_vector.narrow( 0, offset, @@ -116,7 +128,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): elif zero_stage == 3: partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) - # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}") + + if debug: + print(f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}") + # XXX: memory usage doubles here (zero3) state_dict[name] = torch.cat( tuple(fp32_flat_groups[i].narrow(0, From e0f8b6ef2bc779285595d47ea6d36af0655d8d18 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 28 Apr 2021 20:19:07 -0700 Subject: [PATCH 2/2] terrible autoformatter --- deepspeed/utils/zero_to_fp32.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 59d2420261f7..2d98a39e3fc7 100644 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -60,8 +60,8 @@ def parse_optim_states(files): # XXX: could make the script more memory efficient for when there are multiple groups - it # will require matching the sub-lists of param_shapes for each param group flattened tensor fp32_flat_groups = [ - torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], 0) - for i in range(len(state_dicts)) + torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], + 0) for i in range(len(state_dicts)) ] return zero_stage, world_size, param_shapes, fp32_flat_groups @@ -119,7 +119,9 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): if zero_stage == 2: if debug: - print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) state_dict[name] = full_single_fp32_vector.narrow( 0, offset, @@ -130,7 +132,9 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file): partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) if debug: - print(f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}") + print( + f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) # XXX: memory usage doubles here (zero3) state_dict[name] = torch.cat(