From f6e0cbba534e11dea8ace831b1f87614e82274e6 Mon Sep 17 00:00:00 2001 From: zhurunhua <1281592874@qq.com> Date: Thu, 25 Jul 2024 21:22:25 +0800 Subject: [PATCH 1/2] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends --- applications/Colossal-LLaMA/train.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index 43a360a9a49c..e2904ca49403 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -128,6 +128,11 @@ def main() -> None: parser.add_argument("--zero", type=int, default=1) parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") + parser.add_argument( + "--skip_save_each_epoch", + action="store_true", + default=False, + help="skip saving the model checkpoint after each epoch is completed.") args = parser.parse_args() with open(args.config_file, "w") as f: @@ -370,11 +375,15 @@ def main() -> None: ) total_loss.fill_(0.0) pbar.update() + # Save modeling. - if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( - step + 1 - ) == len(dataloader): + save_model_condition = (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) + + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) + + if save_model_condition: coordinator.print_on_master("\nStart saving model checkpoint with running states") if args.use_neft: From d59faa12f482fd55c7b3c9de60a6adb96c92df0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:02:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/Colossal-LLaMA/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index e2904ca49403..e74aad33c3e3 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -132,7 +132,8 @@ def main() -> None: "--skip_save_each_epoch", action="store_true", default=False, - help="skip saving the model checkpoint after each epoch is completed.") + help="skip saving the model checkpoint after each epoch is completed.", + ) args = parser.parse_args() with open(args.config_file, "w") as f: @@ -375,10 +376,12 @@ def main() -> None: ) total_loss.fill_(0.0) pbar.update() - + # Save modeling. - save_model_condition = (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) + save_model_condition = ( + args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 + ) if not args.skip_save_each_epoch: save_model_condition = save_model_condition or (step + 1) == len(dataloader)