
Remove tma padding for fwd inputs #85

Merged
beginlner merged 1 commit into deepseek-ai:main from dianzhangchen:feat/fmha_rm_padding on Aug 25, 2025

Conversation

@dianzhangchen
Contributor

@beginlner merged commit 2d291b0 into deepseek-ai:main on Aug 25, 2025
@SeanLi-OI
Contributor

SeanLi-OI commented Aug 27, 2025

Hi @dianzhangchen, I ran test_fmha_sm100.py with these params:

```
b=2 varlen=False h=48 h_k=48 d=192 dv=128 causal=True
seqlens_q=torch.tensor([7680, 512])
seqlens_k=torch.tensor([7680, 512])
```

and observed a significant difference between flashmla and torch_attn.
[image: screenshot of the output mismatch]

Could you check whether you can reproduce it? Perhaps this hits another corner case.
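
(For context, a minimal sketch of the kind of comparison involved here, not the actual test_fmha_sm100.py: the test checks a fused kernel's output against a plain PyTorch attention reference. `out_flashmla` below is a hypothetical stand-in for the kernel output.)

```python
# Sketch only: compare a kernel output against a PyTorch causal-attention
# reference. Shapes use the head dims from the report but a small sequence
# length for illustration; `out_flashmla` is a hypothetical placeholder.
import torch
import torch.nn.functional as F

b, h, d, dv = 2, 48, 192, 128  # head dims from the report
s = 128                        # small length for illustration (report used 7680)

q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, dv)   # value head dim (dv) may differ from q/k (d)

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

out_flashmla = ref.clone()     # placeholder for the real kernel's output
print((out_flashmla - ref).abs().max().item())  # max abs diff vs. reference
```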

@dianzhangchen
Contributor Author

> I ran test_fmha_sm100.py with these params […] and observed a significant difference between flashmla and torch_attn. Could you check whether you can reproduce it?

I noticed the input sequence lengths vary, but the varlen flag is set to False. You might try setting varlen to True.
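
(For readers following along: a rough illustration of the varlen convention, assuming the usual cumulative-offset layout rather than the repo's exact API. Sequences of different lengths are packed into one buffer and described by offsets, instead of being padded to a common length.)

```python
# Rough illustration (assumed cu_seqlens convention, not the repo's exact
# API): per-sequence lengths become cumulative offsets into one packed
# token buffer, with no padding to a common length.
import torch

seqlens_q = torch.tensor([7680, 512], dtype=torch.int32)
cu_seqlens_q = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(seqlens_q, dim=0, dtype=torch.int32),
])
print(cu_seqlens_q)              # tensor([   0, 7680, 8192], dtype=torch.int32)
total_q = int(cu_seqlens_q[-1])  # 8192 packed tokens in total, no padding
```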

@SeanLi-OI
Contributor

SeanLi-OI commented Aug 27, 2025

@dianzhangchen Sorry for my mistake; there is no precision issue when varlen is True.
But I'm running into another issue with these params:

```
b=44 varlen=True h=48 h_k=48 d=192 dv=128 causal=True
seqlens_q=torch.tensor([8192,  379,  379,  379,  379,  379,  379,  379,  379,  379,  379,  379,
         234,  379,  379,  379,  379,  379,  379,  379,  379,  379,  378,  917,
         917,  917,  919,  918,  918,  852,  917,  917,  523,  697,  698,  698,
         697,  697,  698,  697,  697,  697,  697,  696])
seqlens_k=torch.tensor([8192,  379,  379,  379,  379,  379,  379,  379,  379,  379,  379,  379,
         234,  379,  379,  379,  379,  379,  379,  379,  379,  379,  378,  917,
         917,  917,  919,  918,  918,  852,  917,  917,  523,  697,  698,  698,
         697,  697,  698,  697,  697,  697,  697,  696])
```
[image: screenshot of the failure]

I printed the size and offset at csrc/sm100/common/utils.hpp#L50; they show an extremely large value:

```
num_element: 18446744069955649536
size: 18446744069955649536
offset: 0
```

(BTW, this is the first issue I encountered. It may be something about int32 in the tensor shape, related to my PR #87. During my initial attempt to reproduce it, I mistakenly set varlen to False and was therefore unable to reproduce this problem. 😥)
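
(A side note on the logged value: 18446744069955649536 equals 2**64 - 3753902080, i.e. a negative 64-bit count reinterpreted as an unsigned size_t, which is consistent with the int32 suspicion above. The sketch below is an illustration only; which expression actually overflows in the kernel is an assumption.)

```python
# Illustration only; which expression actually overflows in utils.hpp is an
# assumption. It shows (a) an element-count product for these shapes already
# exceeds INT32_MAX, and (b) a negative 64-bit count read as size_t is huge.
INT32_MAX = 2**31 - 1

b, max_s, h, d = 44, 8192, 48, 192
product = b * max_s * h * d                  # 3_321_888_768 > INT32_MAX
assert product > INT32_MAX

# Two's-complement wraparound if this product were computed in int32:
wrapped = (product + 2**31) % 2**32 - 2**31
print(wrapped)                               # -973078528, a negative "count"

# Any negative 64-bit value reinterpreted as unsigned 64-bit becomes huge:
print((-3_753_902_080) % 2**64)              # 18446744069955649536, as logged
```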

@dianzhangchen
Contributor Author

> Sorry for my mistake; there is no precision issue when varlen is True. But I'm running into another issue with these params […] Maybe it is something about int32 in the tensor shape, related to my PR #87.

The error here doesn’t seem to be caused by the fwd kernel. You may open a new issue to report this. Thanks.

@SeanLi-OI
Contributor

> The error here doesn’t seem to be caused by the fwd kernel. You may open a new issue to report this. Thanks.

Sure, I opened another issue: #90. Thanks for looking into it!
