From 4a6fac3242c4f6558a99ea73aed87b6f279f3447 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 10 Apr 2023 18:44:01 +0800
Subject: [PATCH] [chat] add zero2 cpu strategy for sft training

---
 applications/Chat/examples/train_sft.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index c0ac7b177694..22f70e485843 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -35,6 +35,8 @@ def train(args):
         strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
     elif args.strategy == 'colossalai_zero2':
         strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    elif args.strategy == 'colossalai_zero2_cpu':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')
 
@@ -168,7 +170,7 @@ def train(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--strategy',
-                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
                         default='naive')
     parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)