Fix the FA2 logic in the longcat_flash model #42549
Conversation
@vasqu, please help review. Thanks!
```python
if config.qk_head_dim != config.v_head_dim:
    self.skipTest(
        reason="Flash Attention 2 requires qk_head_dim == v_head_dim, but got "
        f"qk_head_dim={config.qk_head_dim}, v_head_dim={config.v_head_dim}"
    )
```
Will this not skip all the tests here? I doubt that the test classes will have different head dims.
I'd rather we properly adjust the sizes than skip - skipping should always be the last resort.
Yes, this test only involves this single model_class. After digging deeper into LongcatFlashForCausalLM, I found that it already implements the padding pre-processing for FA2 internally. However, the fallback FA2 path (kernels) failed to correctly match the parameter naming. This PR updates that part with the fix. Please review it again!
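For reference, here is a minimal sketch of the padding pre-processing described above. This is an illustration, not the actual LongcatFlash code, and `flash_attention_forward` is a hypothetical stand-in for whichever FA2 entry point (native or kernels-based) gets dispatched:

```python
import torch.nn.functional as F

def fa2_with_mla_padding(q, k, v, qk_head_dim, v_head_dim, flash_attention_forward):
    # MLA: Q/K use qk_head_dim, V uses the smaller v_head_dim, but the fused FA2
    # kernel expects a single head dim, so V is zero-padded before the call.
    if qk_head_dim != v_head_dim:
        v = F.pad(v, [0, qk_head_dim - v_head_dim])
    attn_output = flash_attention_forward(q, k, v)
    # The padded columns only ever receive zeros, so slicing restores the true output.
    if qk_head_dim != v_head_dim:
        attn_output = attn_output[..., :v_head_dim]
    return attn_output
```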
The branch was force-pushed from 13038ae to c492337, and then from 13038ae to 11e4ac9.
vasqu left a comment
Thx for iterating! I think we can generalize this some more to include all flash attention variants? They will probably face similar issues.
```python
uses_flash_attention_2 = (
    "flash" in self.config._attn_implementation and self.config._attn_implementation.endswith("2")
)
if uses_flash_attention_2 and self.qk_head_dim != self.v_head_dim:
```
Suggested change:
```diff
-uses_flash_attention_2 = (
-    "flash" in self.config._attn_implementation and self.config._attn_implementation.endswith("2")
-)
-if uses_flash_attention_2 and self.qk_head_dim != self.v_head_dim:
+if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
```
I think we should generalize this here to check for all flavors. FA3 etc. would face the same issue.
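A quick illustration of why the substring check covers more flavors than the `endswith("2")` variant; the implementation strings below are just examples, not an exhaustive list of what transformers accepts:

```python
for impl in ("flash_attention_2", "flash_attention_3", "kernels-community/flash-attn"):
    print(impl, "flash" in impl, impl.endswith("2"))
# flash_attention_2             True  True
# flash_attention_3             True  False  <- missed by the endswith("2") check
# kernels-community/flash-attn  True  False  <- missed as well
```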
Great! Done.
```diff
 )
-if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+if uses_flash_attention_2 and self.qk_head_dim != self.v_head_dim:
```
Suggested change:
```diff
-if uses_flash_attention_2 and self.qk_head_dim != self.v_head_dim:
+if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
```
Same here then
Done.
[For maintainers] Suggested jobs to run (before merge): run-slow: longcat_flash
vasqu left a comment
Perfect, let's merge
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
* Matching FA2 naming under kernels
* make style
* convert model
* Follow the comments
What does this PR do?
FA2 does not support MLA (i.e., cases where the Q/K head dimension differs from the V head dimension), so this test is skipped.
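To make the head-dim mismatch concrete, here is a small illustration with made-up MLA-style shapes (not LongcatFlash's real sizes): eager/SDPA-style attention only needs Q and K to share a head dim, and the output simply follows V's head dim, whereas the fused FA2 kernel assumes a single head dim for Q, K and V.

```python
import torch

bsz, heads, seq, qk_head_dim, v_head_dim = 1, 4, 8, 192, 128  # illustrative sizes only
q = torch.randn(bsz, heads, seq, qk_head_dim)
k = torch.randn(bsz, heads, seq, qk_head_dim)
v = torch.randn(bsz, heads, seq, v_head_dim)

# Eager attention works fine with qk_head_dim != v_head_dim:
scores = torch.matmul(q, k.transpose(-1, -2)) / qk_head_dim**0.5   # (1, 4, 8, 8)
attn_output = torch.matmul(torch.softmax(scores, dim=-1), v)       # (1, 4, 8, 128)
print(attn_output.shape)  # the output head dim follows v_head_dim
```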