Persimmon FlashAttention2 [WIP] #26482
Conversation
This test at […] The other test failure is due to the tokenizer not having a pad token; I'm guessing this is related to why the […]
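For reference, a minimal sketch of the usual workaround when a tokenizer ships without a pad token; the checkpoint name and the EOS-as-pad choice are illustrative assumptions, not something this PR prescribes:

```python
# Hedged workaround sketch: if the tokenizer has no pad token, padded batches
# (and the padding-related tests) fail. Reusing EOS as the pad token is a
# common fallback; the checkpoint name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # lets the tokenizer build padded batches
```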
Thanks for your work on this!
Regarding the padding mask, both […] In contrast, neither the […]
Since the […]
Force-pushed from 3f28d30 to 772a732
@younesbelkada
younesbelkada left a comment
Thanks a lot for your great work!
I left 2 comments; let's wait for #26792 to be merged, and then I can help you implement pad/unpad in this PR!
```python
# Not needed for Persimmon
# if padding_mask is not None:
#     batch_size = query_states.shape[0]
#     query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
#         query_states, key_states, value_states, padding_mask, query_length
#     )

#     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
#     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

#     attn_output_unpad = flash_attn_varlen_func(
#         query_states,
#         key_states,
#         value_states,
#         cu_seqlens_q=cu_seqlens_q,
#         cu_seqlens_k=cu_seqlens_k,
#         max_seqlen_q=max_seqlen_in_batch_q,
#         max_seqlen_k=max_seqlen_in_batch_k,
#         dropout_p=dropout,
#         softmax_scale=softmax_scale,
#         causal=True,
#     )

#     attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# else:
```
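For context on the pad/unpad work mentioned above, here is a hedged sketch of the pattern other FlashAttention-2 integrations in transformers follow (mirroring the Llama path); the helper name `_get_unpad_data` and its exact placement are assumptions in the context of this PR:

```python
# Hedged sketch, not the final PR code: the unpad path packs only the real
# (non-pad) tokens before calling flash_attn_varlen_func, using flash_attn's
# bert_padding utilities plus the indexing metadata computed here.
import torch
import torch.nn.functional as F
from flash_attn.bert_padding import pad_input, unpad_input  # noqa: F401


def _get_unpad_data(padding_mask):
    # Number of real tokens per sequence in the batch.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the real tokens, used to gather/scatter the packed tensors.
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths prefixed with 0, as flash_attn_varlen_func expects.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```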
```python
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )

    query_states = query_states.to(torch.float16)
    key_states = key_states.to(torch.float16)
    value_states = value_states.to(torch.float16)
```
Can you apply the same procedure as in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L478-L492?
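For reference, a hedged sketch of that procedure, paraphrased from the linked Llama FlashAttention-2 code rather than quoted: derive the cast target from the autocast context, a recorded pre-quantization dtype, or the projection weights, instead of hardcoding float16. The `self.query_key_value` attribute is an assumption about Persimmon's fused projection.

```python
# Hedged sketch (adapted from the Llama FA2 path, not verbatim); a drop-in for
# the excerpt above, so `self`, `logger`, `input_dtype`, and the *_states come
# from that surrounding attention forward.
if input_dtype == torch.float32:
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        # Assumption: Persimmon's attention uses a fused query_key_value projection
        target_dtype = self.query_key_value.weight.dtype

    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back"
        f" the input in {target_dtype}."
    )

    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)
```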
OK, thanks! Shall we close this PR in favor of #27052 then? What do you think?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Any update on this?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Adds Flash Attention 2 for Persimmon per #26350
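For illustration, a hedged usage sketch of how FlashAttention-2 would be enabled for Persimmon once support lands (here or in #27052); the checkpoint name is illustrative, and depending on the transformers version the switch is either `use_flash_attention_2=True` or `attn_implementation="flash_attention_2"`:

```python
# Hedged usage sketch, not part of this PR's diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "adept/persimmon-8b-base",       # illustrative checkpoint
    torch_dtype=torch.float16,       # FlashAttention-2 needs fp16/bf16 activations
    use_flash_attention_2=True,      # on newer versions: attn_implementation="flash_attention_2"
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
```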
Before submitting
Who can review?
@younesbelkada
Ran tests on an A100 80GB; see the attached requirements.txt for the venv.
requirements.txt