[New Feature] Support W4Afp8 MoE GroupGemm #3171
Conversation
Thanks for your contribution!
}}
"""

gemm_case = [
Please also add the ffn2 shapes.
The decoder padding parameter should be 2048.
Fixed.
for case in gemm_case:
    for n in range(16, 257, 16):
        template_head_file.write(
            gemm_template_case.format(M=case[0], K=case[1], N=n, BATCH=case[2], TYPE="BF16", PADDING=case[3], TAILN=0)
Please add the fp16 type as well, for compatibility with the moe_ffn calling interface.
Fixed.
@@ -0,0 +1,140 @@
#include "cute/algorithm/copy.hpp"
Please add the Paddle copyright header to newly added files.
qingqing01 left a comment:
It looks like this isn't registered into cpp_extensions.cc via pybind; keep an eye on performance when it is used later.
)

gap = (out_cuda - out_naive).abs()
assert float(gap.mean()) < 0.07
The unit-test style should be standardized in a follow-up.
This is really a single piece of work, namely adding w4afp8 support to the model, but the amount of code would make review and merging difficult, so it is split into two PRs. Performance will be optimized in the next PR.
w4afp8
Kernel implementation
We implemented a w4afp8 group GEMM, in which matrix A is fp8 (currently only E4M3 is supported) and matrix B is int4, with the following optimizations.
For int4 dequantization, we place the int4's four bits directly into the low four bits of the fp8 value, so the two representations are related by a simple bit-level correspondence.
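As a rough illustration of this nibble-placement idea, here is a minimal numpy sketch; the variable names and the round-trip check are ours, not the kernel's actual fp8 bit manipulation:

```python
import numpy as np

# Hypothetical demo: keep a 4-bit weight in the low four bits of a byte and
# recover the signed value with a shift pair. The real kernel performs the
# analogous bit tricks on fp8 registers in CUDA.
w_int4 = np.array([-8, -3, 0, 5, 7], dtype=np.int8)   # int4 range [-8, 7]
packed = (w_int4 & 0x0F).astype(np.uint8)             # int4 bits -> low nibble

# Sign-extend the low nibble: shift it up into the sign bit, then
# arithmetic-shift back down.
unpacked = (packed << 4).view(np.int8) >> 4
assert np.array_equal(unpacked, w_int4)
```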
On Hopper (H-series) GPUs, the GEMM's B operand can only be read from shared memory, while the A operand can be read either from shared memory or from registers. Because the int4 matrix is the B operand, reading it from shared memory would require this sequence: load from global memory into shared memory, load from shared memory into registers to dequantize, then write back to shared memory. To avoid this, we instead compute the transposed product $(B^T A^T)^T$, and finally store the result to global memory via stmatrix.
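A quick numpy check of the identity behind this operand swap (illustration only; the real kernel operates on fp8/int4 tiles):

```python
import numpy as np

# (B^T @ A^T)^T == A @ B: swapping the operands makes the int4 matrix the
# A operand, which on Hopper can be held in registers after dequantization.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))   # plays the fp8 activation role
B = rng.standard_normal((8, 3))   # plays the int4 weight role
assert np.allclose((B.T @ A.T).T, A @ B)
```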
Assume matrix A has shape [M, K] and matrix B has shape [N, K]. A's M dimension typically varies dynamically, while N and K are fixed by the weight shapes; after the transposition it is effectively N that varies dynamically while M and K are fixed. On Hopper GPUs, GemmN can be chosen from 8, 16, 24, ..., 256. When N is less than 256, we directly launch a GEMM with GemmN = (N + 15) / 16 * 16; when N is greater than 256, we launch two GEMMs, with GemmN = 256 and GemmN = (N % 256 + 15) / 16 * 16 respectively, to minimize wasted computation. A sketch of this selection rule follows.
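The tile-selection rule above as a small Python sketch; the function name and return format are illustrative, not a helper from this PR:

```python
def select_gemm_n(n: int, max_tile: int = 256, align: int = 16):
    """Pick GemmN tile widths for a dynamic N (illustrative sketch only).

    Returns a list of (gemm_n, cols_covered) pairs.
    """
    if n <= max_tile:
        # One GEMM, with N rounded up to the next multiple of 16.
        return [((n + align - 1) // align * align, n)]
    tail = n % max_tile
    tiles = [(max_tile, n - tail)]                 # full GemmN=256 launches
    if tail:
        # Second GEMM covering the remainder, rounded up to a multiple of 16.
        tiles.append(((tail + align - 1) // align * align, tail))
    return tiles

# Example: N = 600 -> a GemmN=256 GEMM covering 512 columns, plus a
# GemmN=96 GEMM for the remaining 88 columns.
print(select_gemm_n(600))   # [(256, 512), (96, 88)]
```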
Performance
We benchmarked against DeepGEMM, with N and K fixed at 7168 and 8192 respectively and M varying dynamically.
Usage
See test/operators/test_w4afp8_gemm.py for details.
We support two ways of calling it.