diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index 3401fd635e7c..2d98a39e3fc7 100644
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -17,6 +17,8 @@
 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
 # DeepSpeed data structures it has to be available in the current python environment.
 
+debug = 0
+
 
 def get_optim_files(checkpoint_dir):
@@ -41,6 +43,8 @@ def parse_optim_states(files):
     if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
         raise ValueError(f"non zero checkpoint")
     zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
+    world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
+    param_shapes = state_dicts[0]["param_shapes"]
 
     # the groups are named differently in each stage
     if zero_stage == 2:
@@ -50,12 +54,15 @@ def parse_optim_states(files):
     else:
         raise ValueError(f"unknown zero stage {zero_stage}")
 
-    param_shapes = state_dicts[0]["param_shapes"]
+    # if there is more than one param group, there will be multiple flattened tensors - one
+    # flattened tensor per group - for simplicity merge them into a single tensor
+    #
+    # XXX: could make the script more memory efficient for when there are multiple groups - it
+    # will require matching the sub-lists of param_shapes for each param group flattened tensor
     fp32_flat_groups = [
-        state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0]
-        for i in range(len(state_dicts))
+        torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
+                  0) for i in range(len(state_dicts))
     ]
-    world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
 
     return zero_stage, world_size, param_shapes, fp32_flat_groups
@@ -93,6 +100,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
     # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating
     #   each param, while dealing with padding if any
 
+    if debug:
+        for i in range(world_size):
+            print(f"fp32_flat_groups[i].shape={fp32_flat_groups[i].shape}")
+
     if zero_stage == 2:
         # XXX: memory usage doubles here (zero2)
         full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)
@@ -107,7 +118,10 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
         total_numel += unpartitioned_numel
 
         if zero_stage == 2:
-            # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+            if debug:
+                print(
+                    f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
+                )
             state_dict[name] = full_single_fp32_vector.narrow(
                 0,
                 offset,
@@ -116,7 +130,12 @@ def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
 
         elif zero_stage == 3:
             partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
-            # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}")
+
+            if debug:
+                print(
+                    f"{name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+                )
+
             # XXX: memory usage doubles here (zero3)
             state_dict[name] = torch.cat(
                 tuple(fp32_flat_groups[i].narrow(0,
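For context on the zero3 branch at the end of this diff: ZeRO-3 shards each parameter evenly across the world_size ranks, padding the last shard so every rank holds the same number of elements, which is why re-consolidation narrows each rank's flat tensor at the param's offset and concatenates the slices before trimming the padding. Below is a minimal, self-contained sketch of that mechanic, not part of the PR: the body of zero3_partitioned_param_info is an assumption based on this equal-partition scheme (only its name and call site appear in the diff), and the toy tensors stand in for real checkpoint data.

import math

import torch


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    # Assumed behavior: split the param into world_size equal slices,
    # padding the tail so each rank's slice has the same numel.
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


# Toy example: one (2, 5) param sharded across 4 ranks.
world_size = 4
shape = (2, 5)
unpartitioned_numel = 10
partitioned_numel, padding = zero3_partitioned_param_info(unpartitioned_numel, world_size)  # (3, 2)

# Each rank's flat fp32 tensor; here it holds just this one padded slice.
fp32_flat_groups = [
    torch.arange(partitioned_numel, dtype=torch.float32) + 100 * i
    for i in range(world_size)
]

offset = 0  # where this param starts inside each rank's flat tensor
full = torch.cat(
    tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel)
          for i in range(world_size)),
    0).narrow(0, 0, unpartitioned_numel).view(shape)  # drop tail padding, restore shape
print(full.shape)  # torch.Size([2, 5])

The same narrow-then-cat pattern is what the truncated torch.cat(...) call at the end of the last hunk performs per parameter, with the offset advancing through each rank's flat tensor as parameters are consumed.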