Fix CB Accuracy Regression under FA2 #45274
Qubitium wants to merge 11 commits into huggingface:main
Conversation
Hi @Qubitium! Thanks for reporting this issue, which I totally overlooked. I was wondering if the fix shouldn't be to just pad
@remi-or I just checked your alternate PR and it is better in every way. Feel free to close this PR. As far as the
Closing, any further discussion on the topic of FA2 can happen on #45323 :)
What does this PR do?
CUDA graph reuse used the wrong key: replay reuse was keyed only on padded tensor sizes, but FA varlen kernels also depend on non-tensor runtime ints such as max_seqlen_q and max_seqlen_k. That allowed CB (continuous batching) to replay a graph captured for one FA runtime shape against a different one, which is why changing max_batch_tokens could change accuracy.
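A minimal sketch of the keying fix described above. The class and argument names are hypothetical, not the actual transformers internals; the point is that the replay cache key must include the FA varlen runtime ints alongside the padded tensor shape:

```python
# Hypothetical sketch: CUDA-graph replay keyed on the FA varlen runtime
# ints (max_seqlen_q / max_seqlen_k) in addition to the padded tensor
# shape. Names are illustrative only, not the real transformers code.

class DecodeGraphCache:
    def __init__(self):
        self._graphs = {}
        self.captures = 0  # how many distinct graphs were captured

    def get_or_capture(self, padded_shape, max_seqlen_q, max_seqlen_k, capture_fn):
        # The buggy variant effectively used key = tuple(padded_shape)
        # alone, so a graph captured for one (max_seqlen_q, max_seqlen_k)
        # pair could be replayed against a different one.
        key = (tuple(padded_shape), max_seqlen_q, max_seqlen_k)
        if key not in self._graphs:
            self._graphs[key] = capture_fn()
            self.captures += 1
        return self._graphs[key]
```

With this key, two decode steps that share a padded shape but differ in max_seqlen_k capture two separate graphs instead of silently reusing one.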
_ensure_decode_fast_path_is_available() only accepted FA3, but both FA2 and FA3 should be supported. Looking at the comments that installed the FA3-only gates, I strongly suspect the first bug is the root cause of the output variance that led to FA2 being gated off.
Look at the accuracy collapse at max_bt (max batched token size for CB) == 384!
The 896/1024 results are affected by another, unrelated bug that I did not include in this PR, so this one stays clean and isolated.
Who can review?
@remi-or @ArthurZucker @McPatate