
Persimmon FlashAttention2 [WIP] #26482

Closed
jeromeku wants to merge 3 commits into huggingface:main from jeromeku:persimmon-FA2

Conversation

@jeromeku

What does this PR do?

Adds Flash Attention 2 for Persimmon per #26350
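For context, a minimal usage sketch (not part of this PR's diff) of how the new code path would be enabled; the use_flash_attention_2 flag and the adept/persimmon-8b-base checkpoint name are assumptions based on the FA2 integration available at the time:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "adept/persimmon-8b-base"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,   # FA2 kernels require fp16/bf16
    use_flash_attention_2=True,  # assumed flag; newer releases use attn_implementation="flash_attention_2"
).to("cuda")

inputs = tokenizer("Persimmon is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))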

Before submitting

Who can review?

@younesbelkada

Ran tests on an A100 80GB; see the attached requirements.txt for the venv.

RUN_SLOW=1 pytest -sv --disable-warnings -k flash_attn_2 tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest
================================================================================= test session starts ==================================================================================
platform linux -- Python 3.9.17, pytest-7.4.2, pluggy-1.3.0 -- /notebooks/virtualenvs/persimmon-fa2/bin/python
cachedir: .pytest_cache
configfile: setup.cfg
collecting ... Using /root/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py39_cu118/cuda_kernel/build.ninja...
Building extension module cuda_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cuda_kernel...
Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
collected 116 items / 110 deselected / 6 selected                                                                                                                                      

tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_conversion You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED
tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_generate_left_padding You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED
tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_generate_padding_right You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED
tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_generate_use_cache You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED
tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_inference You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED
tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_flash_attn_2_inference_padding_right You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
PASSED

==================================================================== 6 passed, 110 deselected, 8 warnings in 5.62s =====================================================================

requirements.txt

@jeromeku
Author

jeromeku commented Oct 2, 2023

@younesbelkada

FAILED tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_pipeline_text_generation - AssertionError: (<class 'RuntimeError'>, <class 'IndexError'>, <class 'ValueError'>, <class 'AssertionError'>) not raised

This test, at tests/pipelines/test_pipelines_text_generation.py:250 in run_pipeline_test (text_generator("This is a test" * 500, max_new_tokens=20)), is failing because Persimmon can handle long sequences: the default config has max_position_embeddings = 16384, so the long-sequence checks in test_pipelines_text_generation, which are run when model_max_length < 1000, do not raise errors and assertRaises fails.

The other test failure is because the tokenizer does not have a pad token:

FAILED tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_pipeline_zero_shot - ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.
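For reference, either workaround suggested by the error message would unblock the zero-shot pipeline test; a minimal sketch (the checkpoint name is assumed):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")  # assumed checkpoint
# Option 1: reuse the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token
# Option 2: add a dedicated pad token (the model embeddings would then need resizing)
# tokenizer.add_special_tokens({"pad_token": "[PAD]"})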

I'm guessing this is related to why the PersimmonDecoderLayer and PersimmonFlashAttention don't have a kwarg for padding_mask compared to the respective implementations for LlamaDecoderLayer and LlamaFlashAttention...

@younesbelkada
Contributor

Thanks for your work on this!
Regarding the failing tests, this is very strange, as the PR does not modify anything unrelated to FA-2, so it should keep all previous behaviour. Can you try merging your branch with main and running the tests again?
Also, can you elaborate on why pad/unpad is not needed for this architecture?

@jeromeku
Author

jeromeku commented Oct 3, 2023

@younesbelkada

Regarding the padding mask, both the LlamaDecoderLayer and LlamaFlashAttention forward signatures take padding_mask as a kwarg.

In contrast, neither the PersimmonDecoderLayer nor the PersimmonFlashAttention forward signatures have a padding_mask as a kwarg.

LlamaFlashAttention2, per your implementation, handles the padded and unpadded cases by calling two different functions of the flash-attention interface: when padding is present it unpads, packs q/k/v, and calls flash_attn_varlen_func; when there is no padding it calls flash_attn_func.

Since the Persimmon layers (per the original implementation) don't have a padding_mask as input, I only used flash_attn_func in the PersimmonFlashAttention2 implementation. Let me know if this needs to be changed.
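For reference, a minimal sketch (not the PR's code) of the two call paths described above, assuming flash-attn's flash_attn_func / flash_attn_varlen_func and its bert_padding helpers, and simplified to the case where queries and keys share the same padding mask; note that newer flash-attn releases return an extra value from unpad_input:

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input

def flash_forward(q, k, v, padding_mask=None, dropout=0.0, softmax_scale=None):
    # q, k, v: [batch, seq_len, num_heads, head_dim]
    if padding_mask is None:
        # No padding: the simple fused kernel suffices (the path this PR uses).
        return flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=True)

    # Padding present: unpad, run the varlen kernel, then re-pad (the Llama-style path).
    batch_size, q_len = q.shape[0], q.shape[1]
    q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, padding_mask)
    k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, padding_mask)
    v_unpad, _, _, _ = unpad_input(v, padding_mask)
    out_unpad = flash_attn_varlen_func(
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        dropout_p=dropout, softmax_scale=softmax_scale, causal=True,
    )
    return pad_input(out_unpad, indices_q, batch_size, q_len)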

@jeromeku
Author

@younesbelkada
Let me know if changes are needed.

Contributor

@younesbelkada left a comment


Thanks a lot for your great work!
I left 2 comments; let's wait for #26792 to be merged, and then I can help you implement pad/unpad in this PR!

Comment on lines +499 to +523
# Not needed for Persimmon
# if padding_mask is not None:
# batch_size = query_states.shape[0]
# query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
# query_states, key_states, value_states, padding_mask, query_length
# )

# cu_seqlens_q, cu_seqlens_k = cu_seq_lens
# max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

# attn_output_unpad = flash_attn_varlen_func(
# query_states,
# key_states,
# value_states,
# cu_seqlens_q=cu_seqlens_q,
# cu_seqlens_k=cu_seqlens_k,
# max_seqlen_q=max_seqlen_in_batch_q,
# max_seqlen_k=max_seqlen_in_batch_k,
# dropout_p=dropout,
# softmax_scale=softmax_scale,
# causal=True,
# )

# attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# else:

Suggested change
# Not needed for Persimmon
# if padding_mask is not None:
# batch_size = query_states.shape[0]
# query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
# query_states, key_states, value_states, padding_mask, query_length
# )
# cu_seqlens_q, cu_seqlens_k = cu_seq_lens
# max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
# attn_output_unpad = flash_attn_varlen_func(
# query_states,
# key_states,
# value_states,
# cu_seqlens_q=cu_seqlens_q,
# cu_seqlens_k=cu_seqlens_k,
# max_seqlen_q=max_seqlen_in_batch_q,
# max_seqlen_k=max_seqlen_in_batch_k,
# dropout_p=dropout,
# softmax_scale=softmax_scale,
# causal=True,
# )
# attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# else:

Comment on lines +441 to +450
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )

    query_states = query_states.to(torch.float16)
    key_states = key_states.to(torch.float16)
    value_states = value_states.to(torch.float16)

@jeromeku
Author

@younesbelkada

Created a new branch and opened PR #27052.

Went ahead and added the 2d-->4d attention mask conversion per #26792 and adjusted the FA2 implementation to accommodate the attention mask.
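For context, an illustrative sketch (not this PR's code) of what the 2d-->4d expansion does: a [batch, src_len] padding mask becomes the additive [batch, 1, tgt_len, src_len] mask the eager attention path consumes; the function and variable names here are made up for illustration:

import torch

def make_4d_causal_mask(padding_mask_2d, tgt_len, dtype=torch.float32):
    # padding_mask_2d: [batch, src_len] with 1 for real tokens and 0 for padding
    bsz, src_len = padding_mask_2d.shape
    min_value = torch.finfo(dtype).min
    # Causal bias: each query may only attend to key positions at or before it.
    causal = torch.full((tgt_len, src_len), min_value, dtype=dtype)
    causal = torch.triu(causal, diagonal=src_len - tgt_len + 1)
    mask = causal[None, None, :, :].expand(bsz, 1, tgt_len, src_len).clone()
    # Padding bias: key positions marked 0 are blocked for every query.
    return mask.masked_fill(padding_mask_2d[:, None, None, :] == 0, min_value)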

@younesbelkada
Contributor

OK, thanks! Let's close this PR in favor of #27052 then. What do you think?

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@xhluca
Contributor

xhluca commented Nov 28, 2023

Any update on this?

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this Jan 1, 2024