
[FA] Native torch integration #45153

Draft
vasqu wants to merge 23 commits into huggingface:main from vasqu:torch-flash-attn

Conversation

@vasqu
Contributor

@vasqu vasqu commented Mar 31, 2026

As per the title: with torch releasing the varlen API, we can use native FA to some extent (with limited feature support)

Restrictions

  • Unsupported features
    • Dropout
    • Learnable sinks (attention sinks)
    • Determinism
    • Softcap
    • Continuous batching (CB) KV cache native primitives

Enables

  • Packed forward without masks (on the fly) --> qwen vlms and their derivative models (glmv, ernie vl, etc.)
  • Native FA without any external dependencies like the original pkg or kernels
    • CB support out of the box
  • Experimental FA features in torch, e.g. using FA3/4 backends directly through torch

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@Cyrilvallez Cyrilvallez left a comment


Nice, just a few comments/questions!

Comment thread src/transformers/modeling_flash_attention_utils.py
Comment on lines +337 to +341
global _flash_api_alternative_names
for name in [original_name, _flash_api_alternative_names.get(original_name, original_name)]:
    if supports_mapping[name]:
        kwargs_dict[name] = obj
        return name
Member


Humm, so we modify the kwargs_dict in-place but return either the object (if None), or the name? Seems quite odd to me no?
Could we simply forward the whole kwarg dict to this function and remap kwarg names?

Contributor Author


There is a return statement in the docstring exactly for that:

Return:
        name (`str`, *optional*):
            The associated name the object was added to the kwargs (if it was added; otherwise None).

Yea, it is a bit odd but it's a culmination of the bad practice of not keeping to standards across so many libraries

  • I need the name for max_seq_len_(q/k) as they need extra care to avoid device syncs where we can
  • Packing everything at once is a nice idea but I will push a commit in a second that makes this a bit less elegant
    • Atm, we silently ignore it if we cannot set the object on the kwargs
    • Some features are not that important (e.g. dropout, deterministic) --> warn
    • Other features are core --> raise an error (see the sketch below)
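
Roughly, the warn/raise split could look like the following sketch (the helper name, the `required` flag, and the messages are illustrative, not the actual code in this PR):

import logging

logger = logging.getLogger(__name__)

# Mirrors the alternative-name mapping shown further down in this thread.
_ALT_NAMES = {"softmax_scale": "scale", "max_seqlen_q": "max_q", "max_seqlen_k": "max_k"}


def add_flash_kwarg(kwargs_dict, supports_mapping, original_name, obj, required=False):
    # Add `obj` to `kwargs_dict` under whichever name the loaded flash backend supports
    # and return that name; if neither name is supported, warn for nice-to-have features
    # and raise for core ones.
    for name in (original_name, _ALT_NAMES.get(original_name, original_name)):
        if supports_mapping.get(name, False):
            kwargs_dict[name] = obj
            return name
    if required:
        raise ValueError(f"The current flash attention backend does not support `{original_name}`.")
    logger.warning(f"`{original_name}` is not supported by the current flash attention backend and is ignored.")
    return None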

Comment thread src/transformers/modeling_flash_attention_utils.py
Comment on lines +632 to +633
def _flash_attention_mask_varlen(
    flash_varlen_fn: Callable,
Member


Are varlen that different if we have a mask or not? It's simply a matter of creating the seq_lens no?

Contributor Author


The core difference is in the way the input is prepared:

  • Mask varlen has to manually unpad the input into packed sequences, keep track of the indices, and then finally re-pad the output to the original shape (padded with 0s) --> this is the only way it can kind of work with our original caching logic
  • Pure varlen only needs to properly prepare the metadata (where the sequences end), since the input is already properly packed (see the sketch below)
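
For intuition, here is a minimal sketch of the metadata each path needs (illustrative helpers, not the actual functions in this PR; the real code also has to care about dtypes, devices, and avoiding syncs where possible):

import torch
import torch.nn.functional as F


def unpad_metadata_from_mask(attention_mask):
    # Mask varlen path: derive the packing metadata from a (batch, seq_len) padding mask.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of the real tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, len_0, len_0 + len_1, ...]
    max_seqlen = int(seqlens.max())  # note: exactly the kind of device sync mentioned above
    return indices, cu_seqlens, max_seqlen


def cu_seqlens_from_packed(lengths):
    # Pure varlen path: the input is already packed, only the cumulative boundaries are needed.
    return F.pad(torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0), (1, 0))

The mask path then gathers the real tokens with `indices` before the kernel call and scatters the output back to the 0-padded shape afterwards.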

Comment on lines -9 to -10
_use_top_left_mask = flash_attn_supports_top_left_mask()

Member


Is it not needed anymore? We went beyond the versions that required this hack?

Contributor Author


It is already silently not used anymore except for NPU users that set an environment variable. Imo, we should deprecate this (maybe a different PR but I need some cleanups here either way)

vasqu added 2 commits April 1, 2026 12:43
…ed in all but maybe npu), change error + warning logic, some simplifications
Comment on lines +159 to +166
_flash_api_alternative_names = {
    "s_aux": "learnable_sink",
    "cu_seqlens_q": "cu_seq_q",
    "cu_seqlens_k": "cu_seq_k",
    "max_seqlen_q": "max_q",
    "max_seqlen_k": "max_k",
    "softmax_scale": "scale",
}
Member


why ppl can't agree on using the same naming convention 😿

Contributor Author


Yea it's hell 😓 converted this to lists now just in case it gets even worse (which atp I wouldn't be surprised about)

)

if is_flash_attention_requested(self.config):
if is_flash_attention_requested(self.config, allow_torch=is_flash_attn_torch_available()):
Member


i was thinking more of using a single call to attention_interface, so maybe for qwen we need to override the set_attn_implementation and force-set "fa_torch" on vision config whenever sdpa is requested. WDYT?

Contributor Author


Wouldn't we need a list of exceptional model types then where we force a different attn implementation? Or do you mean to overwrite it in all pretrained ones - fearing that modular will get messy across the board maybe?

I also think it's maybe smarter not to modify models in this PR and move this to a different PR. It's definitely the way forward imo to make this the default path then. Just need to check against our CI again etc.

Contributor Author

@vasqu vasqu left a comment


Self-review on points I wanted to clarify / highlight

use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
Contributor Author


This has not been used for quite a while now and was just hidden behind kwargs

# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
# It can set an environment variable `NPU_FA2_SPARSE_MODE` to control this behavior.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2 # Deprecated
Contributor Author


Deprecating the top-left mask: the minimum FA version of 2.3.3 no longer uses it and we should just align

cu_seq_lens_q: torch.LongTensor | None
cu_seq_lens_k: torch.LongTensor | None
cu_seq_lens_q: torch.IntTensor | None
cu_seq_lens_k: torch.IntTensor | None
Contributor Author


Not super important, but this typing was just wrong for quite a while

if implementation in ["sdpa", "flash_attention_torch"]:
    from torch.nn.attention.varlen import varlen_attn as flash_attn_varlen_func

    flash_attn_func = None  # not supported yet
Contributor Author


I tested using F.scaled_dot_product_attention but it is even more limited (no mask, no different seq lengths, etc.)
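
For context, forcing the flash backend for SDPA looks roughly like the sketch below (illustrative only; exact constraints and error behavior depend on the torch version and hardware):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # Fine: purely causal attention with equal sequence lengths.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Not fine: an arbitrary (e.g. padding) mask is not accepted by this backend,
    # which is why it cannot serve as a general base (non-varlen) FA function here.
    # mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
    # F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # -> no available kernel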

Comment on lines -191 to -202
if flash_attn_func is None:
    logger.warning(
        f"The loaded flash attention implementation at `{implementation}` only supports varlen, i.e. "
        "it can only be used with continuous batching and does not support the full functionality for "
        "the base transformers generation methods."
    )
if flash_attn_with_kvcache is None:
    logger.warning(
        f"The loaded flash attention implementation at `{implementation}` does not support block tables, so"
        " the full performances of continuous batching will not be achieved, only the varlen path will be "
        "used."
    )
Contributor Author


These warnings were over the top; they should happen at request time / run time if the path is actually used --> error / warning

Comment on lines +437 to +442
# Torch varlen can use sliding window but also has to set it to determine causality
if flash_kwargs.get("causal") is None:
    if flash_kwargs.get("window_size") is None:
        flash_kwargs["window_size"] = (-1, 0) if is_causal else (-1, -1)
    elif is_causal:
        flash_kwargs["window_size"] = (flash_kwargs["window_size"][0], 0)
Contributor Author


Torch does not have a native is_causal kwarg or similar; it infers causality from the window size... yet another convention out of nowhere

if isinstance(out_unpad, tuple):
    out_unpad = out_unpad[0]

return pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
Contributor Author


This is a refactor to split this from the original forward:

  • More grouped functions where they belong: the base FA fn and everything processing-related sits above them
  • Needed for something else, explained later in the normal entrypoint forward where we decide which of these to use

flash_kwargs = partial(
    process_flash_kwargs_fn,
    query_length=query_length,
    key_length=key_states.size(1),
Contributor Author


They actually are not useful at all:

  • q length was only used for top left masking to determine causality
  • k length was used to sometimes skip setting the window size but honestly there is no real reason to create this overhead

Comment on lines +959 to +965
if flash_fn is None:
    if not is_tracing(query_states):
        logger.warning_once(
            "We detected that your current underlying Flash Attention implementation does not implement a simple base "
            "Flash Attention function (non-varlen). This can lead to slight inefficiencies (generation speed) and "
            "changes in generation."
        )
Contributor Author


This is what I meant re refactoring functions out: Torch FA does not have a base FA (like I said, I tried using F.scaled_dot_product_attention but it's even worse). This still produces the same outputs, but due to kernel launches / block sizes I suspect there are still slight deviations --> hence the warning

elif flash_attn_version == 2 and not is_flash_attn_greater_or_equal("2.3.3"):
    raise ImportError(f"{preface} FlashAttention{flash_attn_version} requires at least version `2.3.3`.")
    raise ImportError(f"{preface} Flash Attention {flash_attn_version} requires at least version `2.3.3`.")
elif flash_attn_version == "torch" and not is_torch_greater_or_equal("2.11.0"):
Contributor Author


Theoretically it could have been 2.10, but then we would not have SWA and I think that's too core at this point. Softcapping, for example, is not as popular

@vasqu
Contributor Author

vasqu commented Apr 1, 2026

Ok, I will now start recompiling FA2 and FA3 to make comparisons again against the torch native version

Edit: Have them all now 2-4 + torch 2.11

Comment on lines -58 to -60
if is_flash_attn_available():
    from ...integrations.flash_attention import get_target_dtype
    from ...modeling_flash_attention_utils import _flash_attention_forward
Contributor Author


Guards have not been needed for a while now that we have proper lazy loading

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
Contributor Author


Time to deprecate this as well, along with the rest of the top-left mask logic

@vasqu
Contributor Author

vasqu commented Apr 1, 2026

@stas00 this should be interesting to you: using native FA within the torch backend

  • Needs torch 2.11
  • You can load via attn_implementation=flash_attention_torch

Still figuring a few details out but outputs look fairly reasonable
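
For anyone who wants to try it, loading with the new implementation should look roughly like this (a minimal sketch; the model id is only a placeholder and this assumes the attn_implementation key stays as described above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder, any supported model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,  # FA kernels want fp16/bf16
    attn_implementation="flash_attention_torch",  # requires torch >= 2.11
    device_map="cuda",
)

inputs = tokenizer("Native torch flash attention says:", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))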

@github-actions
Contributor

github-actions Bot commented Apr 1, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: bark, diffllama, falcon, gpt_neo, gptj, kyutai_speech_to_text, mimi, moshi, nemotron

@stas00
Contributor

stas00 commented Apr 1, 2026

Thank you for the heads up and working on this integration as well, Anton!

That would be very useful to have built into PyTorch.
