Persimmon fa2 attention4d (#27052)
Conversation
lmk if my implementation of the 4D attention mask (#26792) + FA2 needs tweaking. Regarding my previous comment, I'd like to understand HF's current strategy for integrating third-party / OSS libraries and components. Given the rapid pace of innovation in this space, I want to ensure that …
Thanks very much for your great contrib @jeromeku! Sorry for the delay responding on the PR, I will have an extensive look at the PR and your questions by the beginning of next week (from 30th October) 🙏
Hi! Great work, and I don't mean to butt in here, but in case it helps take this home: I was trying to get this to work and ran into some issues with the latest (4.36.dev0) version of transformers after cloning this PR and rebasing on main. I had to do this because of the llama2 tokenizer/optimum import issue that I get using the transformers version as-is on this PR. After scouring GitHub, I came across a fine-tuning repo for Fuyu, and the same author has a working version of FA2 for Persimmon (I was able to train Persimmon on it with FA2):

pip install git+https://github.com/phillip-kravtsov/transformers.git@floating-updates

This is experimental and the work of one dude, so for the record the working SHA is:

Hope this helps!
Let me know how I can improve the PR. Also, I would appreciate your thoughts on the previous query when you get a chance. Thanks!
# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter
class AttnMaskConverter:
Instead, you can import it from modeling_attn_mask_utils:

from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
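For context, this is roughly how the shared helper gets called inside a decoder's forward pass; the dummy shapes below are purely illustrative and not Persimmon-specific:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)
attention_mask_2d = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2D padding mask
past_key_values_length = 0  # non-zero when a KV cache is present

# Expands the 2D padding mask into the additive 4D causal mask the attention layer consumes
causal_mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask_2d,
    (batch_size, seq_length),
    inputs_embeds,            # used for the dtype/device of the expanded mask
    past_key_values_length,
)
print(causal_mask_4d.shape)  # torch.Size([2, 1, 5, 5])
```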
Made the aforementioned changes, but I am getting odd errors with the test_flash_attn_2_generate_padding_right test (all other tests pass).
For example, here are the outputs with and without FA2:
out=tensor([[38, 17, 72, 39, 27, 98, 50, 53],
[88, 83, 53, 72, 98, 43, 42, 90],
[44, 38, 53, 70, 57, 77, 58, 87],
[61, 57, 26, 71, 68, 5, 76, 60],
[29, 8, 61, 62, 66, 12, 52, 63],
[83, 71, 43, 61, 22, 83, 89, 89],
[50, 38, 11, 22, 42, 13, 65, 14],
[71, 71, 0, 9, 3, 37, 10, 84],
[81, 19, 56, 62, 67, 18, 58, 87],
[39, 9, 49, 48, 22, 13, 72, 20],
[26, 96, 7, 26, 54, 46, 32, 6],
[67, 9, 87, 93, 42, 8, 65, 14],
[25, 83, 70, 30, 32, 92, 25, 64]], device='cuda:0')
out_fa=tensor([[38, 17, 72, 39, 27, 98, 50, 95],
[88, 83, 53, 72, 98, 43, 42, 30],
[44, 38, 53, 70, 57, 77, 58, 87],
[61, 57, 26, 71, 68, 5, 76, 60],
[29, 8, 61, 62, 66, 12, 52, 28],
[83, 71, 43, 61, 22, 83, 89, 89],
[50, 38, 11, 22, 42, 13, 65, 10],
[71, 71, 0, 9, 3, 37, 10, 84],
[81, 19, 56, 62, 67, 18, 58, 87],
[39, 9, 49, 48, 22, 13, 72, 10],
[26, 96, 7, 26, 54, 46, 32, 94],
[67, 9, 87, 93, 42, 8, 65, 10],
[25, 83, 70, 30, 32, 92, 25, 64]], device='cuda:0')

Seems like there is some unaccounted-for randomness, as some sequences match while others do not.
This is odd given that the flash_attn_2_inference_padding_{left,right} tests pass, as does the flash_attn_2_generate_left_padding test.
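For what it's worth, a quick (hypothetical, not part of the test suite) helper to list which batch rows disagree between the two generations above:

```python
import torch

def differing_rows(out: torch.Tensor, out_fa: torch.Tensor) -> list:
    mismatch = (out != out_fa).any(dim=-1)              # True where any token differs
    return mismatch.nonzero(as_tuple=True)[0].tolist()  # row indices that changed

# On the tensors printed above: differing_rows(out, out_fa) -> [0, 1, 4, 6, 9, 10, 11]
```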
How should the generate_padding_right test be interpreted? The test pads the last token in each sequence of a batch of input_ids and then generates one token. Since the last input token is masked, does this mean the next generated token should be the prediction at the last position where it attends only to prior tokens in the sequence (as opposed to the typical causal case, where the prediction also includes the token attending to itself)?
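For concreteness, here is a minimal sketch of how I'm reading the test; `model` and `model_fa` stand for the same Persimmon checkpoint loaded without and with FA2, and the shapes/vocab size are just illustrative:

```python
import torch

# Assumption: `model` uses eager attention, `model_fa` is the same checkpoint
# loaded with use_flash_attention_2=True (loading not shown here).
input_ids = torch.randint(0, 100, (13, 7), device="cuda")
attention_mask = torch.ones_like(input_ids)
attention_mask[:, -1] = 0  # "right padding": the final position of every row is masked out

out = model.generate(input_ids, attention_mask=attention_mask,
                     max_new_tokens=1, do_sample=False)
out_fa = model_fa.generate(input_ids, attention_mask=attention_mask,
                           max_new_tokens=1, do_sample=False)

# If right padding is handled consistently, the single newly generated token
# should match row for row: torch.equal(out[:, -1], out_fa[:, -1]) -> True
```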
Also, when I run flash_attn_2 tests for Mistral, the generate_padding_right and inference_padding_right tests both fail as no ValueError is raised.
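If I'm reading those tests right, they expect something like the following (reusing the tensors from the sketch above; `mistral_fa` is an assumed Mistral checkpoint loaded with FA2):

```python
import pytest

# Expectation as I understand it: with FA2 and right padding, generation should
# refuse to run (raise ValueError) rather than silently return wrong results.
with pytest.raises(ValueError):
    mistral_fa.generate(input_ids, attention_mask=attention_mask,
                        max_new_tokens=1, do_sample=False)
```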
Any ideas what could be causing this? I've attached my venv.
Trying to get to the bottom of the issue:
Thoughts? FWIW, here's the output of …
Hi @jeromeku, the …
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Adds Flash Attention 2 for Persimmon per #26350
Adds 2d->4d attention mask per #26792
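Example usage once this is merged, as a rough sketch assuming the current `use_flash_attention_2` loading flag and an arbitrary Persimmon checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical user-facing call; the checkpoint id is just an example.
model = AutoModelForCausalLM.from_pretrained(
    "adept/persimmon-8b-base",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # routes attention through the FA2 path added here
).to("cuda")
```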
Who can review?
@younesbelkada