[Bug]Add sequence_parallel in layernorm init to enable 3D parallelism with DeepSpeed for non CUDA device.#468
Conversation
44619fa to
f099692
Compare
|
Hi @tjruwase, I think we have talked about this question before, |
|
Hi @tjruwase , is it possible to have this PR reviewed? This PR is to fix a Megatron-DeepSpeed incompatibility to torch.nn.layernorm. Without it Megatron-DeepSpeed does not work normally for non-CUDA devices. |
…run successfully with DeepSpeed Signed-off-by: yisheng <yi.sheng@intel.com>
|
Hi @tjruwase this PR had been updated, should be ready for merge. Thanks! |
@sfc-gh-truwase in case you mainly use the other github account |
…run successfully with DeepSpeed (deepspeedai#468) Signed-off-by: yisheng <yi.sheng@intel.com> Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu>
…run successfully with DeepSpeed (deepspeedai#468) Signed-off-by: yisheng <yi.sheng@intel.com> Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu>
…nabled (#479) * pass batch_dim_idx to deepspeed sequence parallel distributed attention for supporting batch size larger than 1 Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add fused_rms_norm support on XPU device (#431) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * [LLaMa] Adding support converting checkpoint from mds to hf (#432) * add support converting checkpoint from hf to mds * Fix PP issue * update Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add device check when import ipex (#436) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix TFLOPs calculation (#371) * fix TFLOPs calculation when GQA used, we observe right TFLOPs after this fix. when GQA is not used, huge difference in TFLOPs is solved with selective recompute . some other minor difference will also be observed as logits macs also added. * add copyrights Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix nan issue when running megatron-deepspeed (#434) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * enable empty cache on XPU device (#438) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * [wandb] disable wandb more gracefully (#422) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * [Bug] Fix crash when logging optimizer state to tb (#417) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add FPDT support; add Ulysses rotary position embedding support Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add FPDT support; add Ulysses rotary position embedding support Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add FPDT support; add Ulysses rotary position embedding support Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add FPDT support; add Ulysses rotary position embedding support Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * remove unnecessary files Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * set the warmup length to be FPDT chunk size if enabled Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * Enable Sequence Parallelism (#429) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * grad_wei can't be NoneType when running with DeepSpeed, for zero3 will divided the gradient (#428) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix init issue for rms_norm in squence_parallel (#448) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * enable profiler for specific ranks (#451) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix init issue for silently ignoring the deepspeed config (#452) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix moe tflops (#445) Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * [tool]GQA convert support (#454) * [tools]GQA convert support * fix readme Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * Fix import error in `deepspeed_to_megatron.py` (#455) Previously, `deepspeed_to_megatron.py` would raise an import error due to the relative import. This commit fixes this issue by changing from the relative import to the absolute import like in `deepspeed_to_transformers.py`. Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * Update references to new GitHub org (deepspeedai) (#462) Signed-off-by: Logan Adams <loadams@microsoft.com> Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * add sequence_parallel in layernorm init to enable 3D parallelism can run successfully with DeepSpeed (#468) Signed-off-by: yisheng <yi.sheng@intel.com> Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> * fix bug when FPDT is disabled but with original Ulysses Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> Signed-off-by: jinghan yao yjhmitweb@gmail.com Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> --------- Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu> Signed-off-by: Logan Adams <loadams@microsoft.com> Signed-off-by: yisheng <yi.sheng@intel.com> Signed-off-by: jinghan yao yjhmitweb@gmail.com Co-authored-by: Jinghan Yao <yjhmitweb@ascend-rw02.ten.osc.edu> Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn> Co-authored-by: billishyahao <yahao.he@gmail.com> Co-authored-by: Polisetty V R K Jyothendra Varma <jvarma@habana.ai> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Jinghan Yao <yjhmitweb@ascend-rw01.ten.osc.edu> Co-authored-by: ranzhejiang <zhejiang.ran@intel.com> Co-authored-by: Xinyu Lian <lian7@illinois.edu> Co-authored-by: inkcherry <mingzhi.liu@intel.com> Co-authored-by: hotsuyuki <hotsuyuki.kawanishi@gmail.com> Co-authored-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu>
When you running on non-CUDA device, for 3D parallelism with DeepSpeed you will got this error, can see below:
[rank19]: File "/home/yisheng/anaconda3/envs/llm_pt_25/lib/python3.10/site-packages/deepspeed/runtime/pipe/module.py", line 214, in init
[rank19]: self._build()
[rank19]: File "/home/yisheng/anaconda3/envs/llm_pt_25/lib/python3.10/site-packages/deepspeed/runtime/pipe/module.py", line 270, in _build
[rank19]: module = layer.build()
[rank19]: File "/home/yisheng/anaconda3/envs/llm_pt_25/lib/python3.10/site-packages/deepspeed/runtime/pipe/module.py", line 74, in build
[rank19]: return self.typename(*self.module_args, **self.module_kwargs)
[rank19]: TypeError: LayerNorm.init() got an unexpected keyword argument 'sequence_parallel'
cause for Megatron-DeepSpeed, sequence_parallel is added in Megatron-DeepSpeed for layernorm, for current implementation, non-CUDA device is using from torch.nn import LayerNorm for layernorm, there is no attr named sequence_parallel, will cause init error for non-CUDA device.