Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence


To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results.

| accuracy | f1 | loss | GPU number | model shard |

| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
| 0.84589 | 0.88613 | 0.43414 | 4 | True |
| 0.83594 | 0.88064 | 0.43298 | 1 | False |


Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
14 changes: 8 additions & 6 deletions examples/language/bert/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```

### Results on 2-GPU
### Bert-Finetune Results

| Plugin | Accuracy | F1-score | GPU number |
| -------------- | -------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% | 2 |
| torch_ddp_fp16 | 84.7% | 88.8% | 2 |
| gemini | 84.0% | 88.4% | 2 |
| hybrid_parallel | 84.5% | 88.6% | 4 |

| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |

## Benchmark
```
Expand Down