feat: general fsdp2 on non-MoE models + HF TP plan #352
Conversation
File another issue #413 to trace FSDP2 for MoE models.
@terrykong Thanks very much for pointing this out! I tested with almost the same script as yours before commit fdb565c. It is fixed now, and other models won't be affected since they don't need any special handling.
Thanks @jgerh, I have updated based on your suggestions.
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Thanks for the quick fix @YUki-666. Gemma2 seems to be okay now from a quick run.



What does this PR do?
Adds general FSDP2 support for non-MoE models. The tensor-parallel plan is resolved with the following priority: custom-parallel-plan > opt-parallel-plan (the optimized plans we implemented for certain models in FSDP2) > hf-tp-plan (HF's `_tp_plan`). Convergence tests on LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, Gemma2ForCausalLM, Gemma3ForCausalLM, and Phi3ForCausalLM run well.
Convergence Test Detail
Llama-3.1-8B-Instruct (LlamaForCausalLM)

FSDP2-tp8-opt_plan vs FSDP2-tp8-hf_tp_plan
Qwen2ForCausalLM / Qwen3ForCausalLM
(Qwen2ForCausalLM)
FSDP2-tp4-opt_plan vs FSDP2-tp4-hf_tp_plan
(Qwen3ForCausalLM)
FSDP1 vs FSDP2-tp1
Gemma2ForCausalLM / Gemma3ForCausalLM
(Gemma2ForCausalLM)
FSDP1 vs FSDP2-tp1 vs FSDP2-tp4-hf_tp_plan
(Gemma3ForCausalLM)
FSDP1 vs FSDP2-tp1
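The plan-selection priority described above (custom plan > optimized built-in plan > HF `_tp_plan`) can be sketched as follows. This is an illustrative sketch only; the function and dictionary names are hypothetical, not the actual API introduced by this PR.

```python
# Hypothetical sketch of the TP-plan resolution priority. The registry
# contents and names are illustrative, not the repo's real definitions.
_OPTIMIZED_PLANS = {
    # Hand-tuned plans for specific architectures (example entry only).
    "LlamaForCausalLM": {"model.layers.*.self_attn.q_proj": "colwise"},
}

def resolve_tp_plan(model, custom_plan=None):
    """Pick a TP plan: custom > optimized built-in > HF's _tp_plan."""
    if custom_plan is not None:
        # A user-supplied plan always wins.
        return custom_plan
    arch = type(model).__name__
    if arch in _OPTIMIZED_PLANS:
        # Fall back to the optimized plan shipped for this architecture.
        return _OPTIMIZED_PLANS[arch]
    # Finally, use the plan HuggingFace ships with the model, if any.
    return getattr(model, "_tp_plan", None)
```

Models without an optimized plan (e.g. new HF architectures) thus still get tensor parallelism for free whenever HF provides a `_tp_plan`.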
Issues
Closes #156
Usage
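A custom parallel plan (the highest-priority option above) might look like the sketch below, mapping module-path patterns to parallel styles in the same string form HF uses for `_tp_plan`. The module paths and the config surface for passing this plan in are illustrative assumptions, not this PR's exact API.

```python
# Hypothetical custom parallel plan: module-path patterns -> TP styles,
# using the "colwise"/"rowwise" string convention from HF's _tp_plan.
# The exact paths depend on the model; these follow Llama-style naming.
custom_parallel_plan = {
    # Shard attention input projections column-wise across TP ranks...
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    # ...and the output projection row-wise, so the matmul pair
    # needs only one all-reduce per attention block.
    "model.layers.*.self_attn.o_proj": "rowwise",
}
```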
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information