From 4febe32c68a29e0e8a904cb889bd33f24bdec1e7 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 16 Mar 2021 16:24:51 -0700 Subject: [PATCH 1/2] consistent checkpoint filenaming --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e11e2c1d7afc..226659fbeee0 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1348,7 +1348,7 @@ def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank): zero_ckpt_name = os.path.join( checkpoints_path, str(tag), - filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt') + filename + '_mp_rank_{:02d}'.format(mp_rank) + '_optim_states.pt') return zero_ckpt_name def _get_zero_ckpt_name(self, checkpoints_path, tag): From 21f00c5beb68de551e38e9d385545f1e0dad7420 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 16 Mar 2021 17:23:44 -0700 Subject: [PATCH 2/2] backward compatible rename --- deepspeed/runtime/engine.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 226659fbeee0..6b099d3c5398 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1525,13 +1525,20 @@ def _get_all_zero_checkpoints(self, load_dir, tag): mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size) invalid_zero_ckpt_paths = [] - for ckpt_name in zero_ckpt_names: + for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): + # transparently handle the old file pattern for optim_states + if 'optim_states.pt' in ckpt_name: + ckpt_name_try = ckpt_name.replace("_optim_states.pt", + "optim_states.pt") + if os.path.exists(ckpt_name_try): + zero_ckpt_names[i] = ckpt_name_try + continue invalid_zero_ckpt_paths.append(ckpt_name) if len(invalid_zero_ckpt_paths) > 0: logger.warn( - f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist" + f"The following zero checkpoints paths are missing: {invalid_zero_ckpt_paths}" ) return None