
Conversation

Contributor

@yangjianfengo1 yangjianfengo1 commented Aug 4, 2025

Building on FlashAttention-3, this PR adds support for a mask, where the mask is an int array of shape [q_seq_len]. For the i-th token, the entries in row i of the QK GEMM matrix from column mask[i] onward are masked out, i.e. qkgemm[i, mask[i]:] is masked. For example, to use a causal mask, the mask array is [1, 2, 3, 4, 5, ...].

If mask is None, a causal mask is used by default.
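As a sketch of the mask semantics described above (a NumPy reference implementation for illustration, not the actual FA3 kernel; the function name `naive_attn_with_mask` is hypothetical):

```python
import numpy as np

def naive_attn_with_mask(q, k, v, mask):
    # q, k, v: [seq_len, head_dim]; mask: int array of shape [q_seq_len].
    # For row i of the QK^T matrix, columns mask[i]: are masked out.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    for i, m in enumerate(mask):
        scores[i, m:] = -np.inf
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

# A causal mask is mask = [1, 2, ..., seq_len]:
seq_len, d = 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((seq_len, d)) for _ in range(3))
causal = np.arange(1, seq_len + 1)
out = naive_attn_with_mask(q, k, v, causal)
```

With the causal mask, row 0 can only attend to key 0, so the first output row equals v[0].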

paddle-bot bot commented Aug 4, 2025

Thanks for your contribution!

```python
naive_attn_out = naive_attn(q_input, k_input, v_input, mask)
paddle_attn_out = paddle_flash_attn_mask(q_input, k_input, v_input, mask)

print((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
```
Collaborator

Use an assert here instead, so CI can monitor this unit test.
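A sketch of the suggested assert-based check (the arrays and the 1e-3 tolerance are illustrative stand-ins, not values from the PR):

```python
import numpy as np

# Instead of printing the max difference, fail the test in CI when the
# flash-attention output diverges from the naive reference.
naive = np.array([1.0, 2.0, 3.0])       # stand-in for naive_attn_out
paddle_out = np.array([1.0, 2.0, 3.0])  # stand-in for paddle_attn_out

max_diff = np.abs(paddle_out - naive).max()
assert max_diff < 1e-3, f"flash-attn output mismatch: max diff {max_diff}"
```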

Contributor Author

Fixed.

@yuanlehome yuanlehome merged commit 40f7f3e into PaddlePaddle:develop Aug 5, 2025
12 of 14 checks passed
@yangjianfengo1 yangjianfengo1 deleted the fa3 branch August 6, 2025 04:09
