
Conversation

@yangjianfengo1 (Contributor) commented Aug 4, 2025

w4afp8

Kernel implementation

We implemented a w4afp8 group GEMM, where matrix A is fp8 (currently only E4M3 is supported) and matrix B is int4, with the following optimizations:

  • For int4 dequantization, we place the four int4 bits directly into the low four bits of the fp8 encoding; the values then satisfy the relation

    $$ Y=2^{-9} \times X \qquad Y\in Fp8E4M3 \qquad X \in uint4 $$

  • On Hopper (H-series) GPUs, the B operand of a GEMM can only be read from shared memory, while the A operand can come from either shared memory or registers. Since the int4 matrix is the B operand, reading it from shared memory would require loading it from global memory to shared memory, loading it from shared memory into registers to dequantize, and then storing it back to shared memory. To avoid this round trip, we compute the transposed product $(B^T*A^T)^T$ instead, and finally store the result to global memory with stmatrix.

  • Suppose matrix A has shape [M,K] and matrix B has shape [N,K]. M of matrix A usually varies dynamically, while N and K are fixed by the weight shape, so after transposition it is effectively N that varies while M and K are fixed. On Hopper, GemmN can be chosen from 8, 16, 24, ..., 256. When N is at most 256 we directly call the GEMM with GemmN = (N+15)/16*16; when N is greater than 256 we call two GEMMs, with GemmN = 256 and GemmN = (N % 256 + 15)/16*16, to minimize wasted computation.
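The tile-selection rule in the last bullet can be sketched as follows. `select_gemm_n` is a hypothetical helper written for illustration, not the actual dispatch code, and it additionally guards the edge case where N is an exact multiple of 256:

```python
def select_gemm_n(n: int) -> list[int]:
    """Return the GemmN tile sizes launched for a given dynamic N.

    N <= 256: a single kernel with GemmN rounded up to a multiple of 16.
    N >  256: a GemmN=256 kernel plus a tail kernel for N % 256.
    (Hypothetical sketch of the rule described above, not the real dispatcher.)
    """
    round16 = lambda x: (x + 15) // 16 * 16
    if n <= 256:
        return [round16(n)]
    tiles = [256]
    if n % 256:  # skip the tail kernel when N divides evenly by 256
        tiles.append(round16(n % 256))
    return tiles

print(select_gemm_n(200))   # [208]
print(select_gemm_n(300))   # [256, 48]
```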

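The dequantization relation above can be checked numerically: placing a uint4 value X into the low four bits of an otherwise-zero E4M3 byte decodes to exactly $2^{-9} \times X$ (X = 0..7 lands in the subnormal range, X = 8..15 in the first normal binade). A standalone sketch with a hand-written E4M3 decoder (assuming the usual exponent bias of 7; this is illustrative, not the kernel code):

```python
def decode_e4m3(bits: int) -> float:
    """Decode an 8-bit FP8 E4M3 pattern (1 sign, 4 exponent, 3 mantissa, bias 7)."""
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF
    mant = bits & 0x7
    if exp == 0:  # subnormal: mant/8 * 2^-6
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

# Place each uint4 value X into the low 4 bits of an fp8 byte and
# verify the claimed relation Y = 2^-9 * X.
for x in range(16):
    assert decode_e4m3(x) == x * 2.0 ** -9
```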
Performance

We benchmarked against DeepGEMM; N and K are fixed at 7168 and 8192 respectively, and M varies.

| M    | w4afp8 (us) | deepgemm (us) |
|------|-------------|---------------|
| 64   | 114         | 184           |
| 128  | 131         | 211           |
| 256  | 209         | 329           |
| 512  | 362         | 448           |
| 1024 | 730         | 799           |
| 2048 | 1440        | 1424          |
| 4096 | 2810        | 2901          |

Usage

See test/operators/test_w4afp8_gemm.py for details.
We support two modes:

  • No token padding: set TokenPadding to 0; tokens is a [group+1] array holding the prefix sum of the token counts.
  • Token padding: set TokenPadding to the actual padding size; tokens is a [group] array holding the actual token counts.
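As an illustration of the two `tokens` layouts (hypothetical per-group token counts, not taken from the test file):

```python
import itertools

counts = [3, 5, 2]  # tokens per expert group (hypothetical example)

# TokenPadding = 0: tokens is a [group+1] prefix-sum array.
tokens_prefix = [0] + list(itertools.accumulate(counts))
print(tokens_prefix)   # [0, 3, 8, 10]

# TokenPadding = <padding size>: tokens is a [group] array of actual counts.
tokens_counts = counts
print(tokens_counts)   # [3, 5, 2]
```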

@paddle-bot commented Aug 4, 2025

Thanks for your contribution!

}}
"""

gemm_case = [

Collaborator: Please also add the ffn2 shape.

Collaborator: For the decoder, this padding parameter should be 2048.

Contributor Author: Fixed.

for case in gemm_case:
    for n in range(16, 257, 16):
        template_head_file.write(
            gemm_template_case.format(M=case[0], K=case[1], N=n, BATCH=case[2], TYPE="BF16", PADDING=case[3], TAILN=0)
Collaborator: Please also add the fp16 type, for compatibility with the moe_ffn call interface.

Contributor Author: Fixed.

@@ -0,0 +1,140 @@
#include "cute/algorithm/copy.hpp"
Collaborator: Please add the Paddle copyright header to the new files.

@qingqing01 qingqing01 changed the title [New Feature] w4afp8 gemm支持 [New Feature] Support W4Afp8 MoE GroupGemm Aug 6, 2025
@qingqing01 (Collaborator) left a comment:

It looks like this is not registered to cpp_extensions.cc via pybind; keep an eye on performance in future use.

)

gap = (out_cuda - out_naive).abs()
assert float(gap.mean()) < 0.07
Collaborator: The unit-test style should be standardized later.

@yangjianfengo1 (Contributor Author) replied:

> It looks like this is not registered to cpp_extensions.cc via pybind; keep an eye on performance in future use.

This is all part of one task, integrating w4afp8 into the model, but since the amount of code would make review and merging difficult, it is split into two PRs; performance will be optimized in the next PR.

@gongshaotian gongshaotian merged commit 8939751 into PaddlePaddle:develop Aug 6, 2025
11 of 14 checks passed