[FA] Native torch integration #45153
Conversation
Cyrilvallez
left a comment
Nice, just a few comments/questions!
```python
global _flash_api_alternative_names
for name in [original_name, _flash_api_alternative_names.get(original_name, original_name)]:
    if supports_mapping[name]:
        kwargs_dict[name] = obj
        return name
```
Hmm, so we modify the kwargs_dict in-place but return either the object (if None) or the name? Seems quite odd to me, no?
Could we simply forward the whole kwargs dict to this function and remap the kwarg names there?
There is a return section in the docstring exactly for that:
```
Return:
    name (`str`, *optional*):
        The associated name the object was added to the kwargs (if it was added; otherwise None).
```
Yeah, it is a bit odd, but it's the culmination of so many libraries not sticking to a common standard:
- I need the name for `max_seq_len_(q/k)` as they need extra care to avoid device syncs where we can.
- Packing everything at once is a nice idea, but I will push a commit in a second that makes this a bit less elegant.
- At the moment, we silently ignore it when we cannot add the object to the kwargs:
  - Some features are not that important (e.g. dropout, deterministic) --> warn
  - Other features are core --> raise an error
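To illustrate the remapping discussed above, here is a rough sketch of such a helper (the names `set_supported_kwarg`, `_ALTERNATIVE_NAMES`, and `supports_mapping` are illustrative, not necessarily the PR's exact API): it tries the canonical kwarg name first, then the backend-specific alias, fills the kwargs dict in-place, and returns the name actually used so the caller can warn or raise for unsupported features.

```python
from typing import Any, Optional

# Illustrative subset of canonical -> backend-specific aliases.
_ALTERNATIVE_NAMES = {"softmax_scale": "scale", "max_seqlen_q": "max_q"}

def set_supported_kwarg(
    kwargs_dict: dict[str, Any],
    supports_mapping: dict[str, bool],
    original_name: str,
    obj: Any,
) -> Optional[str]:
    """Add `obj` under the first supported name; return that name, or None."""
    for name in (original_name, _ALTERNATIVE_NAMES.get(original_name, original_name)):
        if supports_mapping.get(name, False):
            kwargs_dict[name] = obj
            return name
    # Not supported by this backend: the caller decides whether to warn
    # (minor features like dropout) or raise (core features).
    return None
```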
```python
def _flash_attention_mask_varlen(
    flash_varlen_fn: Callable,
```
Is varlen that different whether we have a mask or not? It's simply a matter of creating the seq_lens, no?
The core difference is in the way the input is prepared:
- Mask varlen has to manually unpad the input into packed sequences, keep track of the indices, and then finally manually pad the output back to the original shape (padded with 0s) --> this is the only way it can (kind of) work with our original caching logic.
- Pure varlen only needs to properly prepare the metadata (where the sequences end); the input is already properly packed.
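For illustration, a minimal sketch of the metadata each path needs (helper names and exact shapes are assumptions, not the PR's code):

```python
import torch
import torch.nn.functional as F

def unpad_with_mask(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """Mask varlen: unpad a padded (batch, seq, ...) batch into one packed sequence."""
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # (batch,)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # kept token positions
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # (batch + 1,)
    packed = hidden.flatten(0, 1)[indices]                                       # (total_tokens, ...)
    return packed, indices, cu_seqlens, int(seqlens.max())

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Pure varlen: the input is already packed, only the sequence boundaries are needed."""
    starts = torch.nonzero(position_ids.flatten() == 0, as_tuple=False).flatten().to(torch.int32)
    total = torch.tensor([position_ids.numel()], dtype=torch.int32)
    return torch.cat([starts, total])
```

On the mask path, the attention output then has to be scattered back to the original `(batch, seq, ...)` shape with zeros in the padded slots, which is the re-padding step mentioned above.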
```python
_use_top_left_mask = flash_attn_supports_top_left_mask()
```
Is it not needed anymore? Have we gone past the versions that required this hack?
It is already silently unused, except for NPU users who set an environment variable. IMO we should deprecate this (maybe in a different PR, but I need some cleanups here either way).
…ed in all but maybe npu), change error + warning logic, some simplifications
```python
_flash_api_alternative_names = {
    "s_aux": "learnable_sink",
    "cu_seqlens_q": "cu_seq_q",
    "cu_seqlens_k": "cu_seq_k",
    "max_seqlen_q": "max_q",
    "max_seqlen_k": "max_k",
    "softmax_scale": "scale",
}
```
why ppl can't agree on using the same naming convention 😿
Yeah, it's hell 😓 I converted this to lists now, just in case it gets even worse (which, at this point, I wouldn't be surprised about).
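A hedged sketch of what the list-based variant mentioned above could look like (illustrative; the exact contents in the PR may differ): each canonical name maps to an ordered list of candidates, and the first one the backend supports wins.

```python
_flash_api_alternative_names = {
    "cu_seqlens_q": ["cu_seq_q"],
    "cu_seqlens_k": ["cu_seq_k"],
    "max_seqlen_q": ["max_q"],
    "max_seqlen_k": ["max_k"],
    "softmax_scale": ["scale"],
    "s_aux": ["learnable_sink"],
}

def candidate_names(original_name: str) -> list[str]:
    # Canonical name first, then any known aliases, in order of preference.
    return [original_name, *_flash_api_alternative_names.get(original_name, [])]
```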
```diff
-if is_flash_attention_requested(self.config):
+if is_flash_attention_requested(self.config, allow_torch=is_flash_attn_torch_available()):
```
I was thinking more of using a single call to attention_interface, so maybe for Qwen we need to override set_attn_implementation and force-set "fa_torch" on the vision config whenever sdpa is requested. WDYT?
Wouldn't we then need a list of exception model types where we force a different attn implementation? Or do you mean to override it in all pretrained ones? I fear modular would get messy across the board.
I also think it's maybe smarter not to modify models in this PR and to move this to a different PR. It's definitely the way forward to make this the default path eventually, IMO; we just need to check against our CI again, etc.
vasqu
left a comment
Self-review on points I wanted to clarify / highlight
```python
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
```
This has not been used for quite a while now and was just hidden behind kwargs.
```diff
-# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
-TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
+# It can set an environment variable `NPU_FA2_SPARSE_MODE` to control this behavior.
+TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2  # Deprecated
```
Deprecating the top-left mask: the minimum FA version of 2.3.3 no longer uses it, and we should just align.
```diff
-cu_seq_lens_q: torch.LongTensor | None
-cu_seq_lens_k: torch.LongTensor | None
+cu_seq_lens_q: torch.IntTensor | None
+cu_seq_lens_k: torch.IntTensor | None
```
Not super important, but this typing was just wrong for quite a while
```python
if implementation in ["sdpa", "flash_attention_torch"]:
    from torch.nn.attention.varlen import varlen_attn as flash_attn_varlen_func

    flash_attn_func = None  # not supported yet
```
I tested using F.scaled_dot_product_attention but it is even more limited (no mask, no different sequence lengths, etc.).
```python
if flash_attn_func is None:
    logger.warning(
        f"The loaded flash attention implementation at `{implementation}` only supports varlen, i.e. "
        "it can only be used with continuous batching and does not support the full functionality for "
        "the base transformers generation methods."
    )
if flash_attn_with_kvcache is None:
    logger.warning(
        f"The loaded flash attention implementation at `{implementation}` does not support block tables, so"
        " the full performances of continuous batching will not be achieved, only the varlen path will be "
        "used."
    )
```
These warnings were over the top; they should happen at request time / run time, if the path is actually used --> error / warning.
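As a rough, hypothetical sketch of "complain only when it matters" (the function, flag names, and which feature warns vs. errors are illustrative, not the PR's decisions):

```python
import warnings

def check_backend_capabilities(flash_attn_func, flash_attn_with_kvcache, *, uses_paged_cache: bool):
    """Hypothetical run-time check: raise for a hard requirement, warn for a soft one."""
    if flash_attn_func is None and not uses_paged_cache:
        # Varlen-only backend hit on a plain generation path: treat as a core feature.
        raise ValueError(
            "This flash attention backend only provides a varlen kernel; "
            "use continuous batching or pick a backend with a base (non-varlen) function."
        )
    if flash_attn_with_kvcache is None and uses_paged_cache:
        # A missing block-table kernel only degrades performance: a warning suffices.
        warnings.warn(
            "No block-table kernel available; only the varlen path will be used, "
            "so continuous batching will not reach full performance."
        )
```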
```python
# Torch varlen can use sliding window but also has to set it to determine causality
if flash_kwargs.get("causal") is None:
    if flash_kwargs.get("window_size") is None:
        flash_kwargs["window_size"] = (-1, 0) if is_causal else (-1, -1)
    elif is_causal:
        flash_kwargs["window_size"] = (flash_kwargs["window_size"][0], 0)
```
Torch does not have a native is_causal kwarg or similar; it determines causality based on the window size... yet another convention out of nowhere.
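In other words, a small illustrative mapping assuming the convention described above (not the PR's exact helper): a right window bound of 0 encodes "causal", (-1, -1) encodes full attention.

```python
from typing import Optional, Tuple

def window_size_for(is_causal: bool, window_size: Optional[Tuple[int, int]] = None) -> Tuple[int, int]:
    """Encode causality via the attention window: a right bound of 0 means 'causal'."""
    if window_size is None:
        return (-1, 0) if is_causal else (-1, -1)  # (-1, -1) == full attention
    return (window_size[0], 0) if is_causal else window_size
```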
```python
if isinstance(out_unpad, tuple):
    out_unpad = out_unpad[0]

return pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
```
This is a refactor that splits this out of the original forward:
- Functions are grouped where they belong: the base FA function and everything processing-related sits above them.
- It is needed for something else; I explain it later, at the normal entrypoint forward, where we decide which of these paths to use (see the sketch below).
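A hedged sketch of what that entrypoint decision could look like (function and argument names are placeholders, not the PR's exact ones):

```python
from typing import Callable, Optional
import torch

def choose_flash_path(
    base_fn: Optional[Callable],
    mask_varlen_fn: Callable,
    varlen_fn: Callable,
    attention_mask: Optional[torch.Tensor],
    cu_seq_lens_q: Optional[torch.Tensor],
) -> Callable:
    """Pick which of the split-out attention paths to run."""
    if cu_seq_lens_q is not None:
        return varlen_fn          # input already packed (e.g. continuous batching)
    if attention_mask is not None:
        return mask_varlen_fn     # padded batch: unpad -> varlen kernel -> re-pad
    # No mask and no packing metadata: use the plain kernel when the backend has
    # one, otherwise fall back to a varlen path (see the warning further down).
    return base_fn if base_fn is not None else varlen_fn
```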
```python
flash_kwargs = partial(
    process_flash_kwargs_fn,
    query_length=query_length,
    key_length=key_states.size(1),
```
They actually are not useful at all:
- The q length was only used for top-left masking to determine causality.
- The k length was sometimes used to skip setting the window size, but honestly there is no real reason to create this overhead.
```python
if flash_fn is None:
    if not is_tracing(query_states):
        logger.warning_once(
            "We detected that your current underlying Flash Attention implementation does not implement a simple base"
            "Flash Attention function (non-varlen). This can lead to slight inefficiencies (generation speed) and "
            "changes in generation."
        )
```
This is what I meant re: refactoring the functions out: torch FA does not have a base (non-varlen) FA function (like I said, I tried using F.scaled_dot_product_attention but it's even worse). This still produces the same outputs, but due to kernel launches / block sizes I suspect there are still slight deviations --> hence the warning.
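For the fallback, a rough sketch of how a regular unpadded batch can still be routed through a varlen-only kernel: treat every row as its own full-length sequence (the helper name is illustrative).

```python
import torch

def trivial_cu_seqlens(batch_size: int, seq_len: int, device=None) -> torch.Tensor:
    """Cumulative sequence lengths for a dense batch where every row has length `seq_len`."""
    return torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device)

# e.g. batch_size=2, seq_len=4 -> tensor([0, 4, 8], dtype=torch.int32)
```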
```diff
 elif flash_attn_version == 2 and not is_flash_attn_greater_or_equal("2.3.3"):
-    raise ImportError(f"{preface} FlashAttention{flash_attn_version} requires at least version `2.3.3`.")
+    raise ImportError(f"{preface} Flash Attention {flash_attn_version} requires at least version `2.3.3`.")
+elif flash_attn_version == "torch" and not is_torch_greater_or_equal("2.11.0"):
```
Theoretically this could have been 2.10, but then we would not have SWA, and I think that's too core at this point. Softcapping, for example, is not as popular.
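A minimal sketch of the resulting guard, assuming the check reuses the existing `is_torch_greater_or_equal` utility (the error message is illustrative):

```python
from transformers.utils import is_torch_greater_or_equal

# Per the discussion above, torch's varlen API only carries sliding-window
# support from 2.11 on, so anything older is rejected up front.
if not is_torch_greater_or_equal("2.11.0"):
    raise ImportError("Flash Attention via the torch backend requires torch >= 2.11.0.")
```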
Ok, I will now start recompiling FA2 and FA3 to make comparisons again against the torch-native version. Edit: I have them all now, FA 2-4 + torch 2.11.
```python
if is_flash_attn_available():
    from ...integrations.flash_attention import get_target_dtype
    from ...modeling_flash_attention_utils import _flash_attention_forward
```
Guards have not been needed for a while now, now that we have proper lazy loading.
```python
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
```
Time to deprecate this as well, along with the top-left mask handling.
@stas00 this should be interesting to you: using native FA within the torch backend.
Still figuring a few details out, but the outputs look fairly reasonable.
[For maintainers] Suggested jobs to run (before merge): run-slow: bark, diffllama, falcon, gpt_neo, gptj, kyutai_speech_to_text, mimi, moshi, nemotron
Thank you for the heads up and for working on this integration as well, Anton! It would be very useful to have it built into PyTorch.
As per the title: with torch releasing the varlen API, we can somewhat use native FA (with limited feature support).
Restrictions
Enables