Skip to content

[Gemma2] Support FA2 softcapping#31887

Merged
ArthurZucker merged 3 commits intomainfrom
gemma-fa2
Jul 11, 2024
Merged

[Gemma2] Support FA2 softcapping#31887
ArthurZucker merged 3 commits intomainfrom
gemma-fa2

Conversation

@ArthurZucker
Copy link
Copy Markdown
Collaborator

What does this PR do?

Adds support for the new FA2 softcapping following Dao-AILab/flash-attention#1025

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, looks good to me! 2.6.0 was released 3 hours ago, let's go

Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - thanks for adding!

@ArthurZucker ArthurZucker merged commit f4ec7a2 into main Jul 11, 2024
@ArthurZucker ArthurZucker deleted the gemma-fa2 branch July 11, 2024 09:57
ArthurZucker added a commit that referenced this pull request Jul 11, 2024
* Support softcapping

* strictly greater than

* update
@ShadowTeamCN
Copy link
Copy Markdown

Good to see this. Can we use it for model fine-tuning, or is it just for inference? Google recommends fine-tuning in 'eager' mode.

@ArthurZucker
Copy link
Copy Markdown
Collaborator Author

Now you can use it for finetuning as well if you have the correct version of FA2. Not sure if finetuning "requires" it

@heartkilla
Copy link
Copy Markdown

Great! Any plans for sdpa support as well?

@ArthurZucker
Copy link
Copy Markdown
Collaborator Author

Sdpa is a bit more complicated, we need to use flex attention, did not have time to implement. Do you want to open a PR?

@hiyouga
Copy link
Copy Markdown
Contributor

hiyouga commented Jul 13, 2024

Hi @ArthurZucker, should we also add the sliding window and soft-capping to flash_attn_func

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)

just like
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)

@ArthurZucker
Copy link
Copy Markdown
Collaborator Author

ArthurZucker commented Jul 15, 2024

It should be here on main: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L361 we updated the whole FA2 integration .

On the release branch it was there AFAIK

@hiyouga
Copy link
Copy Markdown
Contributor

hiyouga commented Jul 15, 2024

Get it, thanks for replying!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants