From 49bfc827f89b141be6d283bc4635acbbe5610e0d Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Wed, 5 Jul 2023 13:41:42 +0800
Subject: [PATCH] [shardformer] opt support flash attention

---
 ..._benchmark.py => convergence_benchmark.py} |  0
 ..._benchmark.sh => convergence_benchmark.sh} |  2 +-
 .../examples/performance_benchmark.py         | 86 +++++++++++++++++++
 3 files changed, 87 insertions(+), 1 deletion(-)
 rename colossalai/shardformer/examples/{shardformer_benchmark.py => convergence_benchmark.py} (100%)
 rename colossalai/shardformer/examples/{shardformer_benchmark.sh => convergence_benchmark.sh} (76%)
 create mode 100644 colossalai/shardformer/examples/performance_benchmark.py

diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
similarity index 100%
rename from colossalai/shardformer/examples/shardformer_benchmark.py
rename to colossalai/shardformer/examples/convergence_benchmark.py
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh
similarity index 76%
rename from colossalai/shardformer/examples/shardformer_benchmark.sh
rename to colossalai/shardformer/examples/convergence_benchmark.sh
index f42b19a32d35..1c281abcda6d 100644
--- a/colossalai/shardformer/examples/shardformer_benchmark.sh
+++ b/colossalai/shardformer/examples/convergence_benchmark.sh
@@ -1,4 +1,4 @@
-torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
+torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
 --model "bert" \
 --pretrain "bert-base-uncased" \
 --max_epochs 1 \
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
new file mode 100644
index 000000000000..9c7b76bcf0a6
--- /dev/null
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -0,0 +1,86 @@
+"""
+Shardformer Benchmark
+"""
+import torch
+import torch.distributed as dist
+import transformers
+import triton
+
+import colossalai
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+
+def data_gen(batch_size, seq_length):
+    input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
+    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_sequence_classification(batch_size, seq_length):
+    # sequence classification data gen:
+    # each sequence in the batch gets a single class label (a constant 1 here)
+    data = data_gen(batch_size, seq_length)
+    data['labels'] = torch.ones((batch_size, ), dtype=torch.long)
+    return data
+
+
+MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
+                                        hidden_size=128,
+                                        intermediate_size=256,
+                                        num_attention_heads=4,
+                                        max_position_embeddings=128,
+                                        num_labels=16)
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
+model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
+
+# vary seq length for fixed head and batch=4
+configs = [
+    triton.testing.Benchmark(x_names=['N_CTX'],
+                             x_vals=[2**i for i in range(8, 13)],
+                             line_arg='provider',
+                             line_vals=['org_model', 'shard_model'],
+                             line_names=['org_model', 'shard_model'],
+                             styles=[('red', '-'), ('blue', '-')],
+                             ylabel='ms',
+                             plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
+                             args={
+                                 'BATCH': BATCH,
+                                 'dtype': torch.float16,
+                                 'model_func': model_func
+                             })
+]
+
+
+def train(model, data):
+    output = model(**data)
+    loss = output.logits.mean()
+    loss.backward()
+
+
+@triton.testing.perf_report(configs)
+def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
+    warmup = 10
+    rep = 100
+    # prepare data
+    data = data_gen_for_sequence_classification(BATCH, N_CTX)
+    data = {k: v.cuda() for k, v in data.items()}
+    model = model_func().to(device)
+    model.train()
+    if provider == "org_model":
+        fn = lambda: train(model, data)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
+    if provider == "shard_model":
+        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
+        shard_former = ShardFormer(shard_config=shard_config)
+        sharded_model = shard_former.optimize(model).cuda()
+        fn = lambda: train(sharded_model, data)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
+
+
+# start benchmark, command:
+# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
+if __name__ == "__main__":
+    colossalai.launch_from_torch({})
+    bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
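
Note: for readers who want to try the sharded path outside the triton harness, below is a minimal
sketch of the ShardFormer usage the benchmark exercises. It assumes the same colossalai and
transformers versions as this patch (in particular, that ShardFormer.optimize returns the sharded
module directly, as the hunk above does); it is an illustration, not part of the patch.

    # minimal_shardformer_sketch.py -- illustration only, assumptions as noted above
    import transformers

    import colossalai
    from colossalai.shardformer import ShardConfig, ShardFormer

    # initialize the distributed environment, as the benchmark's __main__ does
    colossalai.launch_from_torch({})

    # same toy Llama config as MODEL_CONFIG in the benchmark
    config = transformers.LlamaConfig(num_hidden_layers=4,
                                      hidden_size=128,
                                      intermediate_size=256,
                                      num_attention_heads=4,
                                      max_position_embeddings=128,
                                      num_labels=16)
    model = transformers.LlamaForSequenceClassification(config)

    # enable fused normalization and tensor parallelism, matching the "shard_model" branch
    shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
    sharded_model = ShardFormer(shard_config=shard_config).optimize(model).cuda()

Launched the same way as the benchmark, e.g.
torchrun --standalone --nproc_per_node=2 minimal_shardformer_sketch.py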