Skip to content

[moe] support openmoe train#4637

Merged
oahzxl merged 11 commits intohpcaitech:feature/moefrom
oahzxl:moe_train
Sep 7, 2023
Merged

[moe] support openmoe train#4637
oahzxl merged 11 commits intohpcaitech:feature/moefrom
oahzxl:moe_train

Conversation

@oahzxl
Copy link
Copy Markdown
Contributor

@oahzxl oahzxl commented Sep 6, 2023

  • support openmoe train for zero2 and zero2+ep
  • rewrite checkpoint for ckpt load and save
  • use chunk and checkpoint to reduce lm_head's high activation memory due to large vocab size
  • update model details to adapt to train

TODO: align result and accuracy

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Sep 6, 2023

The code coverage for the changed files is 5%.

Click me to view the complete report
Name                                                      Stmts   Miss  Cover
-----------------------------------------------------------------------------
colossalai/amp/naive_amp/mixed_precision_optimizer.py        98     98     0%
colossalai/booster/booster.py                                66     66     0%
colossalai/booster/plugin/__init__.py                        11     11     0%
colossalai/booster/plugin/hybrid_parallel_plugin.py         152    152     0%
colossalai/booster/plugin/pp_plugin_base.py                   9      9     0%
colossalai/cluster/__init__.py                                5      0   100%
colossalai/cluster/process_group_mesh.py                     72     46    36%
colossalai/context/moe_context.py                            57     28    51%
colossalai/engine/gradient_handler/__init__.py                6      0   100%
colossalai/interface/optimizer.py                            45     20    56%
colossalai/kernel/cuda_native/__init__.py                     5      0   100%
colossalai/lazy/lazy_init.py                                315    246    22%
colossalai/nn/layer/moe/__init__.py                           6      0   100%
colossalai/nn/layer/moe/_operation.py                       143     98    31%
colossalai/nn/layer/moe/checkpoint.py                        56     37    34%
colossalai/nn/layer/moe/experts.py                           84     65    23%
colossalai/nn/layer/moe/layers.py                            82     61    26%
colossalai/nn/layer/moe/routers.py                          133    110    17%
colossalai/nn/layer/moe/utils.py                             58     38    34%
colossalai/pipeline/p2p.py                                  102    102     0%
colossalai/pipeline/schedule/__init__.py                      3      3     0%
colossalai/pipeline/schedule/_utils.py                       50     50     0%
colossalai/pipeline/schedule/base.py                         10     10     0%
colossalai/pipeline/schedule/one_f_one_b.py                 116    116     0%
colossalai/pipeline/stage_manager.py                         68     68     0%
colossalai/shardformer/_utils.py                             54     54     0%
colossalai/shardformer/layer/__init__.py                      8      8     0%
colossalai/shardformer/layer/embedding.py                   130    130     0%
colossalai/shardformer/layer/linear.py                      181    181     0%
colossalai/shardformer/layer/normalization.py                51     51     0%
colossalai/shardformer/layer/qkv_fused_linear.py            292    292     0%
colossalai/shardformer/layer/utils.py                        84     84     0%
colossalai/shardformer/modeling/bert.py                     431    431     0%
colossalai/shardformer/modeling/blip2.py                     53     53     0%
colossalai/shardformer/modeling/bloom.py                    387    387     0%
colossalai/shardformer/modeling/chatglm.py                  149    149     0%
colossalai/shardformer/modeling/gpt2.py                     293    293     0%
colossalai/shardformer/modeling/jit.py                       19     19     0%
colossalai/shardformer/modeling/llama.py                    204    204     0%
colossalai/shardformer/modeling/opt.py                      285    285     0%
colossalai/shardformer/modeling/sam.py                       94     94     0%
colossalai/shardformer/modeling/t5.py                       297    297     0%
colossalai/shardformer/modeling/vit.py                      149    149     0%
colossalai/shardformer/modeling/whisper.py                   95     95     0%
colossalai/shardformer/policies/auto_policy.py               27     27     0%
colossalai/shardformer/policies/base_policy.py               87     87     0%
colossalai/shardformer/policies/bert.py                     257    257     0%
colossalai/shardformer/policies/blip2.py                     54     54     0%
colossalai/shardformer/policies/bloom.py                    151    151     0%
colossalai/shardformer/policies/chatglm.py                  100    100     0%
colossalai/shardformer/policies/gpt2.py                     181    181     0%
colossalai/shardformer/policies/llama.py                    114    114     0%
colossalai/shardformer/policies/opt.py                      140    140     0%
colossalai/shardformer/policies/sam.py                       32     32     0%
colossalai/shardformer/policies/t5.py                       182    182     0%
colossalai/shardformer/policies/vit.py                      108    108     0%
colossalai/shardformer/policies/whisper.py                   61     61     0%
colossalai/shardformer/shard/shard_config.py                 28     28     0%
colossalai/shardformer/shard/sharder.py                      95     95     0%
colossalai/shardformer/shard/shardformer.py                  15     15     0%
colossalai/shardformer/shard/utils.py                        11     11     0%
colossalai/tensor/d_tensor/api.py                           149    113    24%
colossalai/tensor/moe_tensor/__init__.py                      0      0   100%
colossalai/tensor/moe_tensor/api.py                          20      8    60%
colossalai/tensor/moe_tensor/moe_info.py                     10      7    30%
colossalai/zero/low_level/low_level_optim.py                330    288    13%
tests/kit/model_zoo/transformers/__init__.py                 12     12     0%
tests/kit/model_zoo/transformers/bert.py                     50     50     0%
tests/kit/model_zoo/transformers/blip2.py                    21     21     0%
tests/kit/model_zoo/transformers/bloom.py                    36     36     0%
tests/kit/model_zoo/transformers/chatglm.py                  20     20     0%
tests/kit/model_zoo/transformers/gpt.py                      39     39     0%
tests/kit/model_zoo/transformers/opt.py                      32     32     0%
tests/kit/model_zoo/transformers/sam.py                      14     14     0%
tests/kit/model_zoo/transformers/t5.py                       25     25     0%
tests/kit/model_zoo/transformers/vit.py                      24     24     0%
tests/kit/model_zoo/transformers/whisper.py                  23     23     0%
tests/test_shardformer/test_model/_utils.py                 142    142     0%
tests/test_shardformer/test_model/test_shard_bert.py         62     62     0%
tests/test_shardformer/test_model/test_shard_blip2.py        40     40     0%
tests/test_shardformer/test_model/test_shard_bloom.py        59     59     0%
tests/test_shardformer/test_model/test_shard_chatglm.py      60     60     0%
tests/test_shardformer/test_model/test_shard_gpt2.py         65     65     0%
tests/test_shardformer/test_model/test_shard_llama.py        62     62     0%
tests/test_shardformer/test_model/test_shard_opt.py          62     62     0%
tests/test_shardformer/test_model/test_shard_sam.py          39     39     0%
tests/test_shardformer/test_model/test_shard_t5.py           59     59     0%
tests/test_shardformer/test_model/test_shard_vit.py          61     61     0%
tests/test_shardformer/test_model/test_shard_whisper.py      46     46     0%
tests/test_shardformer/test_shard_utils.py                   21     21     0%
tests/test_shardformer/test_with_torch_ddp.py                52     52     0%
-----------------------------------------------------------------------------
TOTAL                                                      8286   7875     5%

Comment thread colossalai/nn/layer/moe/checkpoint.py Outdated
Comment thread colossalai/context/moe_context.py
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Sep 7, 2023

The code coverage for the changed files is 82%.

Click me to view the complete report
Name                                                                    Stmts   Miss  Cover
-------------------------------------------------------------------------------------------
colossalai/amp/naive_amp/mixed_precision_optimizer.py                      98     20    80%
colossalai/booster/booster.py                                              66      9    86%
colossalai/booster/plugin/__init__.py                                      11      0   100%
colossalai/booster/plugin/hybrid_parallel_plugin.py                       152     17    89%
colossalai/booster/plugin/pp_plugin_base.py                                 9      1    89%
colossalai/cluster/__init__.py                                              5      0   100%
colossalai/cluster/process_group_mesh.py                                   72      1    99%
colossalai/context/moe_context.py                                          57     28    51%
colossalai/engine/gradient_handler/__init__.py                              6      0   100%
colossalai/interface/optimizer.py                                          45      5    89%
colossalai/kernel/cuda_native/__init__.py                                   5      0   100%
colossalai/lazy/lazy_init.py                                              315     44    86%
colossalai/nn/layer/moe/__init__.py                                         6      0   100%
colossalai/nn/layer/moe/_operation.py                                     143     98    31%
colossalai/nn/layer/moe/checkpoint.py                                      56     37    34%
colossalai/nn/layer/moe/experts.py                                         84     65    23%
colossalai/nn/layer/moe/layers.py                                          82     61    26%
colossalai/nn/layer/moe/routers.py                                        133    110    17%
colossalai/nn/layer/moe/utils.py                                           58     38    34%
colossalai/pipeline/p2p.py                                                102      7    93%
colossalai/pipeline/schedule/__init__.py                                    3      0   100%
colossalai/pipeline/schedule/_utils.py                                     50      5    90%
colossalai/pipeline/schedule/base.py                                       10      1    90%
colossalai/pipeline/schedule/one_f_one_b.py                               116      4    97%
colossalai/pipeline/stage_manager.py                                       68      4    94%
colossalai/shardformer/_utils.py                                           54     15    72%
colossalai/shardformer/layer/__init__.py                                    8      0   100%
colossalai/shardformer/layer/embedding.py                                 130     24    82%
colossalai/shardformer/layer/linear.py                                    181     53    71%
colossalai/shardformer/layer/normalization.py                              51     10    80%
colossalai/shardformer/layer/qkv_fused_linear.py                          292     70    76%
colossalai/shardformer/layer/utils.py                                      84     17    80%
colossalai/shardformer/modeling/bert.py                                   431    128    70%
colossalai/shardformer/modeling/blip2.py                                   53      1    98%
colossalai/shardformer/modeling/bloom.py                                  387    122    68%
colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py       30      0   100%
colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py           571    239    58%
colossalai/shardformer/modeling/chatglm.py                                149     34    77%
colossalai/shardformer/modeling/gpt2.py                                   293     83    72%
colossalai/shardformer/modeling/jit.py                                     19      3    84%
colossalai/shardformer/modeling/llama.py                                  204     65    68%
colossalai/shardformer/modeling/opt.py                                    285     65    77%
colossalai/shardformer/modeling/sam.py                                     94      6    94%
colossalai/shardformer/modeling/t5.py                                     297     74    75%
colossalai/shardformer/modeling/vit.py                                    149     23    85%
colossalai/shardformer/modeling/whisper.py                                 95     13    86%
colossalai/shardformer/policies/auto_policy.py                             27      2    93%
colossalai/shardformer/policies/base_policy.py                             87     11    87%
colossalai/shardformer/policies/bert.py                                   257      0   100%
colossalai/shardformer/policies/blip2.py                                   54      2    96%
colossalai/shardformer/policies/bloom.py                                  151      2    99%
colossalai/shardformer/policies/chatglm.py                                100      6    94%
colossalai/shardformer/policies/gpt2.py                                   181      1    99%
colossalai/shardformer/policies/llama.py                                  114      3    97%
colossalai/shardformer/policies/opt.py                                    140      2    99%
colossalai/shardformer/policies/sam.py                                     32      0   100%
colossalai/shardformer/policies/t5.py                                     182      5    97%
colossalai/shardformer/policies/vit.py                                    108      1    99%
colossalai/shardformer/policies/whisper.py                                 61      2    97%
colossalai/shardformer/shard/shard_config.py                               28      0   100%
colossalai/shardformer/shard/sharder.py                                    95      3    97%
colossalai/shardformer/shard/shardformer.py                                15      0   100%
colossalai/shardformer/shard/utils.py                                      11      0   100%
colossalai/tensor/d_tensor/api.py                                         149     24    84%
colossalai/tensor/moe_tensor/__init__.py                                    0      0   100%
colossalai/tensor/moe_tensor/api.py                                        20      7    65%
colossalai/tensor/moe_tensor/moe_info.py                                   10      7    30%
colossalai/zero/low_level/low_level_optim.py                              330     30    91%
tests/kit/model_zoo/transformers/__init__.py                               12      0   100%
tests/kit/model_zoo/transformers/bert.py                                   50      0   100%
tests/kit/model_zoo/transformers/blip2.py                                  21      0   100%
tests/kit/model_zoo/transformers/bloom.py                                  36      0   100%
tests/kit/model_zoo/transformers/chatglm.py                                20      0   100%
tests/kit/model_zoo/transformers/gpt.py                                    39      0   100%
tests/kit/model_zoo/transformers/opt.py                                    32      0   100%
tests/kit/model_zoo/transformers/sam.py                                    14      0   100%
tests/kit/model_zoo/transformers/t5.py                                     25      0   100%
tests/kit/model_zoo/transformers/vit.py                                    24      0   100%
tests/kit/model_zoo/transformers/whisper.py                                23      0   100%
tests/test_booster/test_plugin/test_3d_plugin.py                           64      7    89%
tests/test_booster/test_plugin/test_gemini_plugin.py                       74     10    86%
tests/test_cluster/test_process_group_mesh.py                              86      1    99%
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py                 21      2    90%
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py                    17      1    94%
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py                     17      1    94%
tests/test_lazy/test_models.py                                             14      1    93%
tests/test_pipeline/test_p2p_communication.py                              44      1    98%
tests/test_pipeline/test_schedule/test_oneF_oneB.py                        80      2    98%
tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py          40      0   100%
tests/test_pipeline/test_stage_manager.py                                  52      1    98%
tests/test_shardformer/test_layer/test_embedding.py                        37      1    97%
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py         89      1    99%
tests/test_shardformer/test_layer/test_layernorm.py                        35      1    97%
tests/test_shardformer/test_layer/test_linear_1d.py                       110      1    99%
tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py              51      1    98%
tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py      39      1    97%
tests/test_shardformer/test_model/_utils.py                               142     21    85%
tests/test_shardformer/test_model/test_shard_bert.py                       62      1    98%
tests/test_shardformer/test_model/test_shard_blip2.py                      40      1    98%
tests/test_shardformer/test_model/test_shard_bloom.py                      59      1    98%
tests/test_shardformer/test_model/test_shard_chatglm.py                    60      1    98%
tests/test_shardformer/test_model/test_shard_gpt2.py                       65      1    98%
tests/test_shardformer/test_model/test_shard_llama.py                      62      1    98%
tests/test_shardformer/test_model/test_shard_opt.py                        62      1    98%
tests/test_shardformer/test_model/test_shard_sam.py                        39      1    97%
tests/test_shardformer/test_model/test_shard_t5.py                         59      1    98%
tests/test_shardformer/test_model/test_shard_vit.py                        61      1    98%
tests/test_shardformer/test_model/test_shard_whisper.py                    46      1    98%
tests/test_shardformer/test_shard_utils.py                                 21      0   100%
tests/test_shardformer/test_with_torch_ddp.py                              52      1    98%
tests/test_utils/test_flash_attention.py                                   92      8    91%
-------------------------------------------------------------------------------------------
TOTAL                                                                    9849   1782    82%

@oahzxl oahzxl merged commit 764d397 into hpcaitech:feature/moe Sep 7, 2023
@oahzxl oahzxl deleted the moe_train branch September 12, 2023 02:12
oahzxl added a commit to oahzxl/ColossalAI that referenced this pull request Sep 15, 2023
* init

* update moe ckpt

* update config

* support openmoe infernece

* update config

* remove pdb

* support train

* add ckpt download

* update ckpt loading

* use general ckpt
oahzxl added a commit to oahzxl/ColossalAI that referenced this pull request Sep 15, 2023
* init

* update moe ckpt

* update config

* support openmoe infernece

* update config

* remove pdb

* support train

* add ckpt download

* update ckpt loading

* use general ckpt
oahzxl added a commit to oahzxl/ColossalAI that referenced this pull request Oct 26, 2023
* init

* update moe ckpt

* update config

* support openmoe infernece

* update config

* remove pdb

* support train

* add ckpt download

* update ckpt loading

* use general ckpt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants