Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
Expand Down
86 changes: 86 additions & 0 deletions colossalai/shardformer/examples/performance_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen(batch_size, seq_length)
data['labels'] = torch.ones((batch_size), dtype=torch.long)
return data


MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(x_names=['N_CTX'],
x_vals=[2**i for i in range(8, 13)],
line_arg='provider',
line_vals=['org_model', 'shard_model'],
line_names=['org_model', 'shard_model'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'lama_for_sequence_classification-batch-{BATCH}',
args={
'BATCH': BATCH,
'dtype': torch.float16,
'model_func': model_func
})
]


def train(model, data):
output = model(**data)
loss = output.logits.mean()
loss.backward()


@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
warmup = 10
rep = 100
# prepare data
data = data_gen_for_sequence_classification(BATCH, N_CTX)
data = {k: v.cuda() for k, v in data.items()}
model = model_func().to(device)
model.train()
if provider == "org_model":
fn = lambda: train(model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "shard_model":
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()
fn = lambda: train(sharded_model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)