15 changes: 15 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
@@ -533,6 +533,21 @@ def _process_flash_attention_kwargs(
flash_kwargs (`dict`):
A dict of kwargs that are requested and supported.
"""

user_kwargs = {
"dropout_p": dropout,
"window_size": sliding_window,
"deterministic": deterministic,
Member

I am not 100% sure about this one. If someone requests deterministic output, it might still be ok to emit non-deterministic output with a warning?

Contributor Author

Yes, it depends, I think! I will wait for input from @ArthurZucker and @vasqu since they maintain the FA implementation.

Contributor

I don't feel super confident about this:

  • dropout is fair
  • sliding window might still unintentionally work if key_length <= sliding_window
  • deterministic has the env variable as well
  • s_aux is only supported in GPT-OSS and we check that people use it:
    `if value and "flash" in value and value.removeprefix("paged|") != "kernels-community/vllm-flash-attn3":`

Additionally, I'm for being more lenient, at least on deterministic (a warning instead of an error).

TL;DR: the conditions below should be taken into account when we raise the error, and a small test would be nice.
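
For illustration only, a minimal sketch of what the more lenient variant could look like, assuming only `deterministic` is downgraded to a warning; `_check_flash_kwargs` and the `warn_only` split are hypothetical, while `user_kwargs` and `supports_mapping` mirror the diff below:

```python
import warnings

def _check_flash_kwargs(user_kwargs, supports_mapping, warn_only=("deterministic",)):
    """Hypothetical helper: warn for kwargs that can degrade gracefully, raise for the rest."""
    for k, v in user_kwargs.items():
        if not supports_mapping[k] and v is not None:
            if k in warn_only:
                warnings.warn(
                    f"Parameter `{k}` is not supported by this Flash Attention "
                    "implementation and will be ignored."
                )
            else:
                raise ValueError(
                    f"Parameter `{k}` is not supported by this Flash Attention "
                    "implementation but was set; please use a different attention implementation."
                )
```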

"softcap": softcap,
"s_aux": s_aux,
}
# Note 'window_size' in supports_mapping maps to our 'sliding_window' param
for k, v in user_kwargs.items():
if not supports_mapping[k] and v is not None:
raise ValueError(
f"Parameter `{k}` is not supported by this Flash Attention implementation but was set, please use a different attentionimplementation."
)

flash_kwargs = {
"causal": is_causal and not (use_top_left_mask and query_length == 1),
"softmax_scale": softmax_scale,
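
And a sketch of the kind of small test mentioned above, written against the hypothetical `_check_flash_kwargs` helper from the earlier sketch rather than the real `_process_flash_attention_kwargs` signature (which may differ):

```python
import pytest

def test_unsupported_flash_kwarg_raises():
    # softcap is set but flagged as unsupported, so we expect a hard error.
    supports = {"softcap": False, "deterministic": True}
    with pytest.raises(ValueError, match="softcap"):
        _check_flash_kwargs({"softcap": 30.0, "deterministic": None}, supports)

def test_unsupported_deterministic_only_warns():
    # deterministic is unsupported but only downgraded to a warning.
    supports = {"deterministic": False}
    with pytest.warns(UserWarning, match="deterministic"):
        _check_flash_kwargs({"deterministic": True}, supports)
```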