Unify strides of o and do in attention backward #1

Merged: Bznkxs merged 1 commit into main from Bznkxs-bwd-stride-unification on Jan 7, 2025

@Bznkxs Bznkxs commented Jan 7, 2025

During the backward pass, o and do sometimes have different strides and fail the stride check:

https://github.com/linxihui/dkernel/blob/main/dkernel/ops/sparse_attn_bwd.py#L675

This happens when training a model with GQA, where the key and value must be repeated before being passed to the kernel. When q, k and v are passed in with size [B, S, H, D] and stride [S*H*D, H*D, D, 1], and B == 1, the output o has the same size and stride, but do has stride [H*D, H*D, D, 1]. I have not tested training a model without GQA.

According to this attention implementation of Megatron-LM, the stride of a dimension of size 1 has no meaning, so the two stride tuples describe the same memory layout. We adapt their solution so that o and do pass the stride check here.
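To illustrate why the two stride tuples are equivalent, here is a minimal sketch in plain Python. It is not the code from this PR or from Megatron-LM. The helper names `canonical_strides` and `strides_equivalent` are hypothetical, and so are the concrete values of S, H and D. The idea is to overwrite the stride of every size-1 dimension with the value a contiguous layout would give it, then compare the results.

```python
from typing import Sequence, Tuple


def canonical_strides(shape: Sequence[int], strides: Sequence[int]) -> Tuple[int, ...]:
    """Return strides with every size-1 dimension rewritten to a canonical value.

    Hypothetical helper for illustration. The index along a size-1 dimension
    is always 0, so its stride never contributes to a memory offset and can
    be replaced freely.
    """
    if len(shape) != len(strides):
        raise ValueError("shape and strides must have the same length")
    out = list(strides)
    last = len(shape) - 1
    # Walk right to left so out[dim + 1] is already canonical when it is read.
    for dim in range(last, -1, -1):
        if shape[dim] == 1:
            out[dim] = 1 if dim == last else out[dim + 1] * shape[dim + 1]
    return tuple(out)


def strides_equivalent(shape: Sequence[int], a: Sequence[int], b: Sequence[int]) -> bool:
    """True if the two stride tuples differ only on size-1 dimensions."""
    return canonical_strides(shape, a) == canonical_strides(shape, b)


# Assumed example sizes; only B == 1 comes from the description above.
B, S, H, D = 1, 128, 8, 64
shape = (B, S, H, D)
o_strides = (S * H * D, H * D, D, 1)   # strides of o
do_strides = (H * D, H * D, D, 1)      # strides of do when B == 1

unified_o = canonical_strides(shape, o_strides)
unified_do = canonical_strides(shape, do_strides)
same_layout = strides_equivalent(shape, o_strides, do_strides)
```

On real tensors, one possible way to apply the unified strides without copying data is `torch.Tensor.as_strided(size, stride)`. Whether the merged change does exactly that is not stated here.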

Bznkxs merged commit ce34e4d into main Jan 7, 2025
Bznkxs deleted the Bznkxs-bwd-stride-unification branch January 7, 2025 20:15