From c1a65cd134e19dc142e5ef4b9afc2421de9ee7ed Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sun, 24 Sep 2023 21:13:03 +0800 Subject: [PATCH 1/2] [fix] fix weekly runing example --- examples/tutorial/new_api/cifar_resnet/train.py | 2 +- examples/tutorial/new_api/cifar_vit/train.py | 2 +- examples/tutorial/new_api/glue_bert/finetune.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 6ae2d8b0412f..4407a51c3153 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -145,7 +145,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 226a4b320961..700e4d2e0cd9 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -165,7 +165,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 7d69dbc066b3..d46bcc36952b 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -141,7 +141,7 @@ def main(): if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() elif args.plugin == "gemini": - plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) From dada93f8958b5a54050cafdca3b08b55e57f5b37 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 25 Sep 2023 14:17:07 +0800 Subject: [PATCH 2/2] [fix] fix weekly runing example --- examples/tutorial/new_api/glue_bert/finetune.py | 2 +- examples/tutorial/new_api/glue_bert/test_ci.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index d46bcc36952b..990822c9feba 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -21,7 +21,7 @@ # ============================== # Prepare Hyperparameters # ============================== -NUM_EPOCHS = 3 +NUM_EPOCHS = 1 BATCH_SIZE = 32 LEARNING_RATE = 2.4e-5 WEIGHT_DECAY = 0.01 diff --git a/examples/tutorial/new_api/glue_bert/test_ci.sh b/examples/tutorial/new_api/glue_bert/test_ci.sh index c2c097f8d026..56dd431f1e60 100755 --- a/examples/tutorial/new_api/glue_bert/test_ci.sh +++ b/examples/tutorial/new_api/glue_bert/test_ci.sh @@ -4,5 +4,5 @@ set -xe pip install -r requirements.txt for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do - torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin + torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.80 --plugin $plugin done