Fix FA2 tests #29909
Conversation
ArthurZucker
left a comment
AH. That's a great catch. Thanks for it!
```diff
- model = model_class.from_pretrained(
-     tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
```
let's update the name to `test_flash_attn_2_inference_equivalence` or something like that!
Will do!
On a side note, how do we make sure that every model using FA2 still passes? The tests are slow, so I'm not actually sure the CI is totally green?
You'll need to run the tests manually. On a GPU setup, you can select just the flash attention tests by doing something like:

```shell
RUN_SLOW=1 pytest tests/models -k "flash_attn"
```
amyeroberts
left a comment
Good spot - thanks for fixing!
```diff
- model = model_class.from_pretrained(
-     tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
- )
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
```
> You'll need to run the tests manually. On a GPU setup, you can select just the flash attention tests by doing something like `RUN_SLOW=1 pytest tests/models -k "flash_attn"`.
I've run them; I'll open an issue to keep track of the different failures. Should I still merge the PR in the meantime?
@ylacombe Thanks for running and sharing the results! Merging depends on whether the same tests are failing on main: if they are, then merging is fine; if not, the tests will need to be fixed :)
Testing this right now then!
Well, the same tests fail, except for qwen2 and stablelm, which are introduced by this PR. But this makes sense, since the FA2 tests weren't actually testing FA2.
Feel free to merge!

😨😨😨😨😨
Thanks a lot ❤️ for the fix and great catch! One nit: it would be really nice 🙏 if you could mention, in the PR description, a bit about why the previous testing was done improperly. Something as simple as
This way, it's super clear what the PR is doing even before diving into the changes.
AFAIK many FA2 tests were already failing (they are not in the CI) due to diffs in logits.
@fxmarty I think we, or you (?), have run those tests before merging. Do you know why we have many failing FA2 tests? Or are those many failing tests only for (many) newly added models?
Oh, they are not run on T4 GPUs.
@ydshieh When I used to run these tests locally (some months ago), it was because the diff tolerance between eager/fa2 was too low. Some models (such as whisper) somehow require a large diff tolerance.
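The eager-vs-FA2 tolerance check being discussed can be sketched like this. This is an illustrative, torch-free `allclose`-style comparison with made-up numbers, not the actual test values or the real `torch.allclose` implementation:

```python
# Sketch of an allclose-style tolerance check like the ones FA2 equivalence
# tests rely on. The logit values below are invented for illustration only.
def allclose(a, b, rtol=1e-5, atol=1e-8):
    # Mirrors the usual elementwise criterion: |x - y| <= atol + rtol * |y|
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

eager = [1.000, -0.500, 2.250]
fa2 = [1.001, -0.501, 2.252]  # small numerical drift from the fused kernel

# At tight default tolerances the comparison fails...
assert not allclose(eager, fa2)
# ...but passes once the absolute tolerance is loosened, which is why some
# models need a larger diff tolerance than others.
assert allclose(eager, fa2, atol=4e-3)
```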
What does this PR do?
#26572 introduced a bug that prevented properly testing inference with Flash Attention 2: the model that was supposed to be loaded without Flash Attention 2 (as a reference for comparison) was in fact using Flash Attention 2!
cc @fxmarty @ArthurZucker @amyeroberts
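The flaw described above can be sketched as follows. Note that `from_pretrained` here is a stand-in stub, not the real transformers API; the numbers are invented to show why comparing FA2 against itself makes the test vacuous:

```python
# Illustrative sketch (not the actual test code) of why the old test was
# vacuous: both models were loaded with attn_implementation="flash_attention_2",
# so the "reference" outputs came from the very code path under test.
def from_pretrained(path, attn_implementation="eager"):
    # Stub: pretend FA2 and eager differ by a tiny numerical deviation.
    base = [0.10, 0.20, 0.30]
    if attn_implementation == "flash_attention_2":
        return [x + 1e-4 for x in base]
    return base

fa2 = from_pretrained("tmpdir", attn_implementation="flash_attention_2")

# Old (buggy) setup: the reference also used FA2, so the diff is always 0
# and the test passes no matter what FA2 computes.
ref_buggy = from_pretrained("tmpdir", attn_implementation="flash_attention_2")
assert max(abs(a - b) for a, b in zip(ref_buggy, fa2)) == 0.0

# Fixed setup: the reference uses the default attention path, so the
# comparison actually exercises FA2 against an independent implementation.
ref_fixed = from_pretrained("tmpdir")
diff = max(abs(a - b) for a, b in zip(ref_fixed, fa2))
assert 0.0 < diff < 1e-3  # equivalence only within a nonzero tolerance
```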