From cd4c7cb5a4e4213ad16b25346dccde642db27370 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 12:50:56 +0200 Subject: [PATCH 01/38] use partial to wrap around `transformers` utils! --- src/transformers/modeling_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6fadc8adf1e5..04cc2bc21a3b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2286,11 +2286,13 @@ def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Un repo_id, kernel_name = attn_implementation.split(":") kernel_name = kernel_name.strip() repo_id = repo_id.strip() - try: kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) - attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + if "flash_attention" in kernel_name: + ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) + else: + ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) + attn_implementation = repo_id except FileNotFoundError as e: logger.warning( f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." From 005f48213bd74c6773e7f52b4b24427b5eb06fe0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 13:01:23 +0200 Subject: [PATCH 02/38] try to refactor? --- .../modeling_flash_attention_utils.py | 656 ++++-------------- 1 file changed, 129 insertions(+), 527 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 1b5476b0ecc1..603680c5b041 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -1,17 +1,3 @@ -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import inspect import os import warnings @@ -19,7 +5,6 @@ import torch import torch.nn.functional as F - from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, @@ -29,388 +14,123 @@ logging, ) - logger = logging.get_logger(__name__) -flash_attn_func = None -def _index_first_axis(tensor, indices): - """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. This is functionally equivalent to - FA2's `index_first_axis` and replaces the need to import it. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] +def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + reshaped = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped[indices] def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - FA3-compatible unpad_input function. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - - return ( - _index_first_axis(hidden_states, indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) + masks = attention_mask + unused_mask if unused_mask is not None else attention_mask + lengths = masks.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(masks.flatten(), as_tuple=False).flatten() + max_len = lengths.max().item() + cu = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)) + return (_index_first_axis(hidden_states, indices), indices, cu, max_len, attention_mask.sum(dim=-1, dtype=torch.int32)) -def _fa3_pad_input(hidden_states, indices, batch, seqlen): - """ - FA3-compatible pad_input function. - - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[1:] - output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) - output[indices] = hidden_states - return output.view(batch, seqlen, *dim) - - -FA_VERSION = None -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func as flash_attn_2_func - from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func - from flash_attn.bert_padding import pad_input as pad_input_fa2 - from flash_attn.bert_padding import unpad_input as unpad_input_fa2 - from flash_attn.layers.rotary import apply_rotary_emb - - HAS_FA2 = True - FA_VERSION = 2 -elif is_torch_npu_available(): - # patch functions in package `flash-attn` when using flash-attention on Ascend NPU. - from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401 - from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func - from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func - from .integrations.npu_flash_attention import pad_input as pad_input_fa2 - from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2 - - HAS_FA2 = True - FA_VERSION = 2 -else: - flash_attn_2_func = None - flash_attn_2_varlen_func = None - pad_input_fa2 = None - unpad_input_fa2 = None - apply_rotary_emb = None - HAS_FA2 = False - -if is_flash_attn_3_available(): - from flash_attn_interface import flash_attn_func as flash_attn_3_func - from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func - - pad_input_fa3 = _fa3_pad_input - unpad_input_fa3 = _fa3_unpad_input - HAS_FA3 = True - FA_VERSION = 3 -else: - flash_attn_3_func = None - flash_attn_3_varlen_func = None - pad_input_fa3 = None - unpad_input_fa3 = None - HAS_FA3 = False - - -# Current Flash Attention implementations -if FA_VERSION: - flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"] - flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"] - unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"] - pad_input = globals()[f"pad_input_fa{FA_VERSION}"] - - -_flash_supports_window_size = False - - -if flash_attn_func: - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +def _fa3_pad_input(hidden_states, indices, batch: int, seqlen: int): + out = torch.zeros((batch * seqlen), *hidden_states.shape[1:], + device=hidden_states.device, dtype=hidden_states.dtype) + out[indices] = hidden_states + return out.view(batch, seqlen, *hidden_states.shape[1:]) -def is_flash_attn_available(): - """Determine whether flash-attention can be used or not.""" - - if is_flash_attn_3_available(): - return True - - # if package `flash-attn` is available, flash-attention can be used natively. - if is_flash_attn_2_available(): - return True - - # flash-attention can be used on Ascend NPU without package `flash-attn` - if is_torch_npu_available(): - return True - - return False - - -def flash_attn_supports_top_left_mask(): - """Determine whether flash-attention uses top-left or down-right mask""" - - if is_flash_attn_3_available(): - return False - - if is_flash_attn_2_available(): - # top-left mask is used in package `flash-attn` with version lower than 2.1.0 - return not is_flash_attn_greater_or_equal_2_10() - - if is_torch_npu_available(): - # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask - - return is_npu_fa2_top_left_aligned_causal_mask() - - return False - - -def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, - # this might cause a graph break - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) +def _get_unpad_data(attn_mask: torch.Tensor): + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() + max_len = seqlens.max().item() + cu = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu, max_len def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - unpad_input_func, + q, k, v, attn_mask: torch.Tensor, q_len: int, unpad_fn ): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - unpad_input_func: - The function to use for unpadding the input tensors. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage - # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores - if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): - key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] - - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = _index_first_axis(key_layer, indices_k) - value_layer = _index_first_axis(value_layer, indices_k) - if query_length == kv_seq_len: - query_layer = _index_first_axis(query_layer, indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + indices_k, cu_k, max_k = _get_unpad_data(attn_mask) + # trim KV if too long + if k.shape[1] > attn_mask.shape[-1]: + k, v = k[:, :attn_mask.shape[-1]], v[:, :attn_mask.shape[-1]] + q = (_index_first_axis(q, indices_k) if q_len == attn_mask.shape[-1] + else _index_first_axis(q.squeeze(1), torch.arange(q.shape[0], device=q.device))) + k = _index_first_axis(k, indices_k) + v = _index_first_axis(v, indices_k) + cu_q = cu_k if q_len == attn_mask.shape[-1] else torch.arange(q.shape[0]+1, dtype=torch.int32, device=q.device) + max_q = max_k if q_len == attn_mask.shape[-1] else 1 + return q, k, v, indices_k, (cu_q, cu_k), (max_q, max_k) + + +def _prepare_from_posids(q, k, v, pos_ids: torch.Tensor): + batch, ql, nh, hd = q.shape + _, kl, nhk, _ = k.shape + q_flat = q.view(-1, nh, hd) + k_flat = k.contiguous().view(-1, nhk, hd) + v_flat = v.contiguous().view(-1, nhk, hd) + pos = pos_ids.flatten() + idx = torch.arange(pos.numel(), device=pos.device, dtype=torch.int32) + cu = torch.cat((idx[pos==0], torch.tensor(idx.size(0), device=pos.device, dtype=torch.int32))) + max_len = cu.diff().max().item() + return q_flat, k_flat, v_flat, idx, (cu, cu), (max_len, max_len) + + +def _prepare_flash_attention_from_position_ids( + query, key, value, position_ids +): + warnings.warn( + "prepare_fa2_from_position_ids is deprecated, use _prepare_flash_attention_from_position_ids", + FutureWarning, ) + return _prepare_from_posids(query, key, value, position_ids) - -def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): - """ - This function returns necessary arguments to call `flash_attn_varlen_func`. - All three query, key, value states will be flattened. - Cumulative lengths of each examples in the batch will be extracted from position_ids. - - NOTE: ideally cumulative lengths should be prepared at the data collator stage - - Arguments: - query (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - position_ids (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - query (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.contiguous().view(-1, key.size(-2), key.size(-1)) - value = value.contiguous().view(-1, value.size(-2), value.size(-1)) - position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - - cu_seq_lens = torch.cat( - ( - indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), +def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): + if target_dtype and q.dtype == torch.float32: + logger.warning_once( + f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility." ) - ) + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + return q, k, v - # NOTE: With torch compile, this will cause a graph break if you don't set - # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call - # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. - # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` - # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. - # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 - # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing - # for some models (e.g. qwen2-vl). - max_length = cu_seq_lens.diff().max().item() - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) +def _lazy_imports(impl: Optional[str]): + # returns funcs and pad/unpad based on impl + is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() + is_fa3 = is_flash_attn_3_available() + if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.layers.rotary import apply_rotary_emb + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + if impl == "flash_attention_3" or (impl is None and is_fa3): + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True + # fallback + raise ValueError(f"Invalid flash-attn implementation: {impl}") -def prepare_fa2_from_position_ids(*args, **kwargs): - warnings.warn( - "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.", - FutureWarning, - ) - return _prepare_flash_attention_from_position_ids(*args, **kwargs) +_flash_supports_window = None -def fa_peft_integration_check( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - target_dtype: Optional[torch.dtype] = None, -): - """ - PEFT usually casts the layer norms in float32 for training stability reasons - therefore the input hidden states gets silently casted in float32. Hence, we need - cast them back in float16 / bfloat16 just to be sure everything works as expected. - This might slowdown training & inference so it is recommended to not cast the LayerNorms! - - Args: - query (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value (`torch.Tensor`): - Input value states to be passed to Flash Attention API - target_dtype (`torch.dtype`, *optional*): - The dtype to convert the attention tensors to. Conversion can be ignored by - not providing the target dtype. - """ - if target_dtype is None: - return query, key, value - - input_dtype = query.dtype - if input_dtype == torch.float32: - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) +def is_flash_attn_available(): + return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - return query, key, value +def flash_attn_supports_top_left_mask(): + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): + return not is_flash_attn_greater_or_equal_2_10() + return from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask -flash_241 = is_flash_attn_greater_or_equal("2.4.1") -deterministic_g = None +class FlashAttentionKwargs(TypedDict, total=False): + cumulative_seqlens_q: Optional[torch.LongTensor] + cumulative_seqlens_k: Optional[torch.LongTensor] -def _flash_attention_forward( +def flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, @@ -429,185 +149,67 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, - attn_implementation: Optional[str] = None, + implementation: Optional[str] = None, **kwargs, ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`, *optional*): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_top_left_mask (`bool`, defaults to `False`): - flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - attn_implementation (`str`, *optional*): - The attention implementation to use. If None, will default to the one based on the environment. - """ - if attn_implementation is None: - _flash_attn_varlen_func = flash_attn_varlen_func - _flash_attn_func = flash_attn_func - _pad_input = pad_input - _unpad_input = unpad_input - _is_fa3 = HAS_FA3 - elif attn_implementation == "flash_attention_3": - _flash_attn_varlen_func = flash_attn_3_varlen_func - _flash_attn_func = flash_attn_3_func - _pad_input = pad_input_fa3 - _unpad_input = unpad_input_fa3 - _is_fa3 = True - elif attn_implementation == "flash_attention_2": - _flash_attn_varlen_func = flash_attn_2_varlen_func - _flash_attn_func = flash_attn_2_func - _pad_input = pad_input_fa2 - _unpad_input = unpad_input_fa2 - _is_fa3 = False - - if not use_top_left_mask: - causal = is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. - causal = is_causal and query_length != 1 - - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window - ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if _is_fa3: - if dropout > 0.0: - logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") - else: + # lazy import + flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + + causal = is_causal and not (use_top_left_mask and query_length == 1) + use_sw = ( + _flash_supports_window or "window_size" in inspect.signature(flash_fn).parameters + ) and sliding_window and key_states.shape[1] > sliding_window + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} + if not is_fa3: flash_kwargs["dropout_p"] = dropout - - if flash_241: - if deterministic is None: - global deterministic_g - if deterministic_g is None: - deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - deterministic = deterministic_g - flash_kwargs["deterministic"] = deterministic - + if is_flash_attn_greater_or_equal("2.4.1"): + det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = det if softcap is not None: flash_kwargs["softcap"] = softcap - # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + # dtype check query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) - # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - # under two cases: - # Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to - # use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility - # to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information - is_fa2_with_position_ids = ( - position_ids is not None - and query_states.shape[0] == 1 - and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) - ) - is_fa2_with_varlen_kwargs = all( - kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) - ) + # select varlen vs fixed + use_varlen = (position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])) - # Contains at least one padding token in the sequence if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length, _unpad_input + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, + out_unpad = flash_varlen_fn( + q, k, v, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, + causal=causal, **flash_kwargs ) - attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) - - elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids: - batch_size = query_states.size(0) - + out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) + elif use_varlen: if cu_seq_lens_q is None or cu_seq_lens_k is None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( - _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids) + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( + query_states, key_states, value_states, position_ids ) - - cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens - max_length_q, max_length_k = max_seq_lens - else: - query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - - attn_output = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, + q = query_states.view(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.view(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.view(-1, value_states.size(-2), value_states.size(-1)) + mq, mk = max_length_q, max_length_k + cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + + out = flash_varlen_fn( + q, k, v, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, + causal=causal, **flash_kwargs ) - - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - + out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) else: - attn_output = _flash_attn_func( - query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + out = flash_fn( + query_states, key_states, value_states, + softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) - if isinstance(attn_output, tuple): - return attn_output[0] - return attn_output + return out[0] if isinstance(out, tuple) else out -class FlashAttentionKwargs(TypedDict, total=False): - """ - Keyword arguments for Flash Attention with Compile. - - Attributes: - cumulative_seqlens_q (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for query state. - cumulative_seqlens_k (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. - """ - - cumulative_seqlens_q: Optional[torch.LongTensor] - cumulative_seqlens_k: Optional[torch.LongTensor] - max_length_q: Optional[int] - max_length_k: Optional[int] From 1b834a4da5c4d81ed4c8ec57ed1e86d351445988 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 13:42:32 +0200 Subject: [PATCH 03/38] revert one wrong change --- .../modeling_flash_attention_utils.py | 168 +++++++++++++----- 1 file changed, 121 insertions(+), 47 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 603680c5b041..a797699c4d49 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -28,12 +28,19 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): indices = torch.nonzero(masks.flatten(), as_tuple=False).flatten() max_len = lengths.max().item() cu = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)) - return (_index_first_axis(hidden_states, indices), indices, cu, max_len, attention_mask.sum(dim=-1, dtype=torch.int32)) + return ( + _index_first_axis(hidden_states, indices), + indices, + cu, + max_len, + attention_mask.sum(dim=-1, dtype=torch.int32), + ) def _fa3_pad_input(hidden_states, indices, batch: int, seqlen: int): - out = torch.zeros((batch * seqlen), *hidden_states.shape[1:], - device=hidden_states.device, dtype=hidden_states.dtype) + out = torch.zeros( + (batch * seqlen), *hidden_states.shape[1:], device=hidden_states.device, dtype=hidden_states.dtype + ) out[indices] = hidden_states return out.view(batch, seqlen, *hidden_states.shape[1:]) @@ -46,49 +53,89 @@ def _get_unpad_data(attn_mask: torch.Tensor): return indices, cu, max_len -def _upad_input( - q, k, v, attn_mask: torch.Tensor, q_len: int, unpad_fn -): +def _upad_input(q, k, v, attn_mask: torch.Tensor, q_len: int, unpad_fn): indices_k, cu_k, max_k = _get_unpad_data(attn_mask) # trim KV if too long if k.shape[1] > attn_mask.shape[-1]: - k, v = k[:, :attn_mask.shape[-1]], v[:, :attn_mask.shape[-1]] - q = (_index_first_axis(q, indices_k) if q_len == attn_mask.shape[-1] - else _index_first_axis(q.squeeze(1), torch.arange(q.shape[0], device=q.device))) + k, v = k[:, : attn_mask.shape[-1]], v[:, : attn_mask.shape[-1]] + q = ( + _index_first_axis(q, indices_k) + if q_len == attn_mask.shape[-1] + else _index_first_axis(q.squeeze(1), torch.arange(q.shape[0], device=q.device)) + ) k = _index_first_axis(k, indices_k) v = _index_first_axis(v, indices_k) - cu_q = cu_k if q_len == attn_mask.shape[-1] else torch.arange(q.shape[0]+1, dtype=torch.int32, device=q.device) + cu_q = cu_k if q_len == attn_mask.shape[-1] else torch.arange(q.shape[0] + 1, dtype=torch.int32, device=q.device) max_q = max_k if q_len == attn_mask.shape[-1] else 1 return q, k, v, indices_k, (cu_q, cu_k), (max_q, max_k) -def _prepare_from_posids(q, k, v, pos_ids: torch.Tensor): - batch, ql, nh, hd = q.shape - _, kl, nhk, _ = k.shape - q_flat = q.view(-1, nh, hd) - k_flat = k.contiguous().view(-1, nhk, hd) - v_flat = v.contiguous().view(-1, nhk, hd) - pos = pos_ids.flatten() - idx = torch.arange(pos.numel(), device=pos.device, dtype=torch.int32) - cu = torch.cat((idx[pos==0], torch.tensor(idx.size(0), device=pos.device, dtype=torch.int32))) - max_len = cu.diff().max().item() - return q_flat, k_flat, v_flat, idx, (cu, cu), (max_len, max_len) +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length = cu_seq_lens.diff().max().item() + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) -def _prepare_flash_attention_from_position_ids( - query, key, value, position_ids -): +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): warnings.warn( "prepare_fa2_from_position_ids is deprecated, use _prepare_flash_attention_from_position_ids", FutureWarning, ) return _prepare_from_posids(query, key, value, position_ids) + def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): if target_dtype and q.dtype == torch.float32: - logger.warning_once( - f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility." - ) + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) return q, k, v @@ -101,11 +148,23 @@ def _lazy_imports(impl: Optional[str]): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.layers.rotary import apply_rotary_emb + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False if impl == "flash_attention_3" or (impl is None and is_fa3): from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True + else: + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) + # fallback raise ValueError(f"Invalid flash-attn implementation: {impl}") @@ -122,7 +181,10 @@ def flash_attn_supports_top_left_mask(): return False if is_flash_attn_2_available(): return not is_flash_attn_greater_or_equal_2_10() - return from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + return is_npu_fa2_top_left_aligned_causal_mask() class FlashAttentionKwargs(TypedDict, total=False): @@ -130,7 +192,7 @@ class FlashAttentionKwargs(TypedDict, total=False): cumulative_seqlens_k: Optional[torch.LongTensor] -def flash_attention_forward( +def _flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, @@ -152,13 +214,13 @@ def flash_attention_forward( implementation: Optional[str] = None, **kwargs, ): - # lazy import flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) - causal = is_causal and not (use_top_left_mask and query_length == 1) use_sw = ( - _flash_supports_window or "window_size" in inspect.signature(flash_fn).parameters - ) and sliding_window and key_states.shape[1] > sliding_window + (_flash_supports_window or "window_size" in inspect.signature(flash_varlen_fn).parameters) + and sliding_window + and key_states.shape[1] > sliding_window + ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} if not is_fa3: flash_kwargs["dropout_p"] = dropout @@ -168,26 +230,33 @@ def flash_attention_forward( if softcap is not None: flash_kwargs["softcap"] = softcap - # dtype check query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) - - # select varlen vs fixed - use_varlen = (position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])) - + use_varlen = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]) if attention_mask is not None: q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) out_unpad = flash_varlen_fn( - q, k, v, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, - max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, - causal=causal, **flash_kwargs + q, + k, + v, + cu_seqlens_q=cu_q, + cu_seqlens_k=cu_k, + max_seqlen_q=mq, + max_seqlen_k=mk, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, ) out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) elif use_varlen: if cu_seq_lens_q is None or cu_seq_lens_k is None: + if position_ids is None: + raise ValueError( + "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." + ) q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( query_states, key_states, value_states, position_ids ) @@ -199,17 +268,22 @@ def flash_attention_forward( cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k out = flash_varlen_fn( - q, k, v, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, - max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, - causal=causal, **flash_kwargs + q, + k, + v, + cu_seqlens_q=cu_q, + cu_seqlens_k=cu_k, + max_seqlen_q=mq, + max_seqlen_k=mk, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, ) out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) else: out = flash_fn( - query_states, key_states, value_states, - softmax_scale=softmax_scale, causal=causal, **flash_kwargs + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) return out[0] if isinstance(out, tuple) else out - From d93f366ef235ea9fef0371016baf32aa338511e3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 14:02:19 +0200 Subject: [PATCH 04/38] just a nit --- src/transformers/modeling_flash_attention_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index a797699c4d49..d31180eee77e 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -104,7 +104,6 @@ def _prepare_from_posids(query, key, value, position_ids): value = value.contiguous().view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - cu_seq_lens = torch.cat( ( indices_q[position_ids == 0], @@ -162,12 +161,9 @@ def _lazy_imports(impl: Optional[str]): getattr(impl, "flash_attn_varlen_func"), pad_input, unpad_input, - True, + False, ) - # fallback - raise ValueError(f"Invalid flash-attn implementation: {impl}") - _flash_supports_window = None @@ -266,7 +262,6 @@ def _flash_attention_forward( v = value_states.view(-1, value_states.size(-2), value_states.size(-1)) mq, mk = max_length_q, max_length_k cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k - out = flash_varlen_fn( q, k, @@ -286,4 +281,3 @@ def _flash_attention_forward( ) return out[0] if isinstance(out, tuple) else out - From 2b7d411d24fd7fbcd92d41532249d140303a0973 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 14:29:14 +0200 Subject: [PATCH 05/38] push --- src/transformers/modeling_flash_attention_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index d31180eee77e..8134d067d741 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -99,7 +99,7 @@ def _prepare_from_posids(query, key, value, position_ids): (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ - query = query.view(-1, query.size(-2), query.size(-1)) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() From affba20d26d03d253570fbc7223439f1f0fd2470 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 15:17:40 +0200 Subject: [PATCH 06/38] reverter watever was wrong! --- .../integrations/flash_attention.py | 1 - src/transformers/masking_utils.py | 2 + .../modeling_flash_attention_utils.py | 159 ++++++++++++++---- src/transformers/modeling_utils.py | 1 + 4 files changed, 126 insertions(+), 37 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 00df0ef0fd66..6e0f04874529 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -38,7 +38,6 @@ def flash_attention_forward( "FlashAttention does not support inputs with dim=0.\n" "Please check your input shapes or use SDPA instead." ) - # FA2 uses non-transposed inputs query = query.transpose(1, 2) key = key.transpose(1, 2) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 10f1a394d5a8..1964bd1e9884 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -689,6 +689,8 @@ def _preprocess_mask_arguments( # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: + if "kernel" in config._attn_implementation: + return True, attention_mask, None, None, None return True, None, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 8134d067d741..485fb0662a16 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -18,7 +18,7 @@ def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - reshaped = tensor.reshape(-1, *tensor.shape[2:]) + reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) return reshaped[indices] @@ -37,40 +37,128 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): ) -def _fa3_pad_input(hidden_states, indices, batch: int, seqlen: int): - out = torch.zeros( - (batch * seqlen), *hidden_states.shape[1:], device=hidden_states.device, dtype=hidden_states.dtype +def _fa3_pad_input(hidden_states, indices, batch, seqlen): + """ + FA3-compatible pad_input function. + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, ) - out[indices] = hidden_states - return out.view(batch, seqlen, *hidden_states.shape[1:]) - - -def _get_unpad_data(attn_mask: torch.Tensor): - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() - max_len = seqlens.max().item() - cu = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) - return indices, cu, max_len - - -def _upad_input(q, k, v, attn_mask: torch.Tensor, q_len: int, unpad_fn): - indices_k, cu_k, max_k = _get_unpad_data(attn_mask) - # trim KV if too long - if k.shape[1] > attn_mask.shape[-1]: - k, v = k[:, : attn_mask.shape[-1]], v[:, : attn_mask.shape[-1]] - q = ( - _index_first_axis(q, indices_k) - if q_len == attn_mask.shape[-1] - else _index_first_axis(q.squeeze(1), torch.arange(q.shape[0], device=q.device)) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) - k = _index_first_axis(k, indices_k) - v = _index_first_axis(v, indices_k) - cu_q = cu_k if q_len == attn_mask.shape[-1] else torch.arange(q.shape[0] + 1, dtype=torch.int32, device=q.device) - max_q = max_k if q_len == attn_mask.shape[-1] else 1 - return q, k, v, indices_k, (cu_q, cu_k), (max_q, max_k) -def _prepare_from_posids(query, key, value, position_ids): +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): """ This function returns necessary arguments to call `flash_attn_varlen_func`. All three query, key, value states will be flattened. @@ -99,18 +187,18 @@ def _prepare_from_posids(query, key, value, position_ids): (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ - query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + query = query.view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seq_lens = torch.cat( ( indices_q[position_ids == 0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), ) ) - # NOTE: With torch compile, this will cause a graph break if you don't set # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. @@ -120,7 +208,6 @@ def _prepare_from_posids(query, key, value, position_ids): # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing # for some models (e.g. qwen2-vl). max_length = cu_seq_lens.diff().max().item() - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) @@ -239,7 +326,7 @@ def _flash_attention_forward( k, v, cu_seqlens_q=cu_q, - cu_seqlens_k=cu_k, + cu_seqlens_k=cu_k.clone(), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, @@ -267,7 +354,7 @@ def _flash_attention_forward( k, v, cu_seqlens_q=cu_q, - cu_seqlens_k=cu_k, + cu_seqlens_k=cu_k.clone(), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 04cc2bc21a3b..04a08ec186f3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -71,6 +71,7 @@ verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING +from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, From 1959eb2812df84f78efe53afd32d241bab36fcbb Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 15:24:45 +0200 Subject: [PATCH 07/38] some nits --- src/transformers/masking_utils.py | 2 -- src/transformers/modeling_flash_attention_utils.py | 4 ++-- src/transformers/modeling_utils.py | 3 +++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 1964bd1e9884..10f1a394d5a8 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -689,8 +689,6 @@ def _preprocess_mask_arguments( # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: - if "kernel" in config._attn_implementation: - return True, attention_mask, None, None, None return True, None, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 485fb0662a16..f672722df9e8 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -158,7 +158,7 @@ def _upad_input( ) -def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): +def _prepare_from_posids(query, key, value, position_ids): """ This function returns necessary arguments to call `flash_attn_varlen_func`. All three query, key, value states will be flattened. @@ -187,7 +187,7 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ - query = query.view(-1, query.size(-2), query.size(-1)) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 04a08ec186f3..ad53a89a4dc9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2293,6 +2293,9 @@ def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Un ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) else: ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) + ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ + "flash_attention_2" + ] attn_implementation = repo_id except FileNotFoundError as e: logger.warning( From 888cd402fefadfc5101af76775bb4fc3434be1fb Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 15:49:14 +0200 Subject: [PATCH 08/38] fixes when there is no attention mask --- .../integrations/flash_attention.py | 1 + .../modeling_flash_attention_utils.py | 20 +++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 6e0f04874529..43c65b46c805 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -75,6 +75,7 @@ def flash_attention_forward( use_top_left_mask=_use_top_left_mask, target_dtype=target_dtype, attn_implementation=module.config._attn_implementation, + layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, **kwargs, ) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f672722df9e8..f5b566ee9739 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -190,12 +190,14 @@ def _prepare_from_posids(query, key, value, position_ids): query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + cu_seqlens_k = torch.cat([torch.tensor([0], device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0) + max_k = torch.max(position_ids, dim=1).values.max().item() + 1 position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) cu_seq_lens = torch.cat( ( - indices_q[position_ids == 0], + torch.tensor([0], device=position_ids.device, dtype=torch.int32), torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), ) ) @@ -208,7 +210,7 @@ def _prepare_from_posids(query, key, value, position_ids): # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing # for some models (e.g. qwen2-vl). max_length = cu_seq_lens.diff().max().item() - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k)) def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): @@ -321,6 +323,13 @@ def _flash_attention_forward( q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) + if kwargs.get("layer_idx", 0) == 0: + print( + cu_q, + cu_k, + mq, + mk, + ) out_unpad = flash_varlen_fn( q, k, @@ -349,6 +358,13 @@ def _flash_attention_forward( v = value_states.view(-1, value_states.size(-2), value_states.size(-1)) mq, mk = max_length_q, max_length_k cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + if kwargs.get("layer_idx", 0) == 0: + print( + cu_q, + cu_k, + mq, + mk, + ) out = flash_varlen_fn( q, k, From 5a7ae11308ca515f4bb7220acdf0881cc470f32c Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 16:46:09 +0200 Subject: [PATCH 09/38] bring the licence back --- .../modeling_flash_attention_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f5b566ee9739..acffa84820e0 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -1,3 +1,16 @@ +# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect import os import warnings @@ -5,6 +18,7 @@ import torch import torch.nn.functional as F + from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, @@ -14,6 +28,7 @@ logging, ) + logger = logging.get_logger(__name__) From c57673bbcd0f81f4e5c7ffb2c20b8c2800d7eb41 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 17 Jul 2025 15:21:06 +0000 Subject: [PATCH 10/38] some fixes --- src/transformers/modeling_flash_attention_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index acffa84820e0..2aa0f4268a2e 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -265,7 +265,7 @@ def _lazy_imports(impl: Optional[str]): getattr(impl, "flash_attn_varlen_func"), pad_input, unpad_input, - False, + True, ) @@ -349,14 +349,16 @@ def _flash_attention_forward( q, k, v, - cu_seqlens_q=cu_q, - cu_seqlens_k=cu_k.clone(), + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.clone().to(torch.int32), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) elif use_varlen: if cu_seq_lens_q is None or cu_seq_lens_k is None: @@ -384,14 +386,16 @@ def _flash_attention_forward( q, k, v, - cu_seqlens_q=cu_q, - cu_seqlens_k=cu_k.clone(), + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.clone().to(torch.int32), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) + if isinstance(out, tuple): + out = out[0] out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) else: out = flash_fn( From 7d69d8345dadc415fec6d5304bbcd87836316f3c Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 17:22:06 +0200 Subject: [PATCH 11/38] nit --- .../modeling_flash_attention_utils.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index acffa84820e0..8c4c0204db3f 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -38,17 +38,32 @@ def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tens def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): - masks = attention_mask + unused_mask if unused_mask is not None else attention_mask - lengths = masks.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(masks.flatten(), as_tuple=False).flatten() - max_len = lengths.max().item() - cu = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)) + """ + FA3-compatible unpad_input function. + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( _index_first_axis(hidden_states, indices), indices, - cu, - max_len, - attention_mask.sum(dim=-1, dtype=torch.int32), + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, ) From 112e2a64af857f595ec7ee1e38c8416462cee0c4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 17:26:16 +0200 Subject: [PATCH 12/38] style --- src/transformers/modeling_flash_attention_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 7f02615f3166..593b2177c391 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -265,7 +265,6 @@ def _lazy_imports(impl: Optional[str]): if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.layers.rotary import apply_rotary_emb return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False if impl == "flash_attention_3" or (impl is None and is_fa3): From 501aa7ea16226387bda32fe30b58ea442717d666 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 17:35:41 +0200 Subject: [PATCH 13/38] remove prints --- src/transformers/modeling_flash_attention_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 593b2177c391..24f0c1186085 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -352,13 +352,6 @@ def _flash_attention_forward( q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) - if kwargs.get("layer_idx", 0) == 0: - print( - cu_q, - cu_k, - mq, - mk, - ) out_unpad = flash_varlen_fn( q, k, @@ -389,13 +382,6 @@ def _flash_attention_forward( v = value_states.view(-1, value_states.size(-2), value_states.size(-1)) mq, mk = max_length_q, max_length_k cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k - if kwargs.get("layer_idx", 0) == 0: - print( - cu_q, - cu_k, - mq, - mk, - ) out = flash_varlen_fn( q, k, From 04088bec61603d2e93e3d69cb8a78e46b8eb7228 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 17:46:13 +0200 Subject: [PATCH 14/38] correct dtype --- src/transformers/modeling_flash_attention_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 24f0c1186085..c8d399f08a6e 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -220,7 +220,9 @@ def _prepare_from_posids(query, key, value, position_ids): query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) - cu_seqlens_k = torch.cat([torch.tensor([0], device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0) + cu_seqlens_k = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0 + ) max_k = torch.max(position_ids, dim=1).values.max().item() + 1 position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) From b1e104b0ed95039faebd3873e4f115c140f45e21 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 17 Jul 2025 18:38:48 +0200 Subject: [PATCH 15/38] fa flags for testing --- tests/causal_lm_tester.py | 2 +- tests/generation/test_utils.py | 4 ++-- tests/test_modeling_common.py | 14 +++++--------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index 9807c8856059..b13f824bf7f4 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -422,7 +422,7 @@ def test_model_rope_scaling(self): @slow def test_flash_attn_2_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 531bf70d5eeb..fab1672b5c86 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2297,8 +2297,8 @@ def _test_attention_implementation(self, attn_implementation): max_new_tokens = 3 support_flag = { "sdpa": "_supports_sdpa", - "flash_attention_2": "_supports_flash_attn_2", - "flash_attention_3": "_supports_flash_attn_3", + "flash_attention_2": "_supports_flash_attn", + "flash_attention_3": "_supports_flash_attn", } set_model_tester_for_less_flaky_test(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fbb8d5f541a4..daafcc60a485 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3471,9 +3471,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( - attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 - ): + if not model_class._supports_flash_attn and (attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -4132,9 +4130,9 @@ def flash_attention_padding_matches_padding_free_with_position_ids( for model_class in self.all_generative_model_classes: if not ( - model_class._supports_flash_attn_2 - if attn_implementation == "flash_attention_2" - else model_class._supports_flash_attn_3 + model_class._supports_flash_attn + #if attn_implementation == "flash_attention_2" + #else model_class._supports_flash_attn_3 ): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") @@ -4256,9 +4254,7 @@ def flash_attn_from_config(self, attn_implementation: str): self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_generative_model_classes: - if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( - attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 - ): + if not model_class._supports_flash_attn and (attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 7087e7b8a3d9e34d0adcc6b29d022cb157cc7520 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 17 Jul 2025 18:42:04 +0200 Subject: [PATCH 16/38] update --- tests/test_modeling_common.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index daafcc60a485..5d48632704a7 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3471,7 +3471,9 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if not model_class._supports_flash_attn and (attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"): + if not model_class._supports_flash_attn and ( + attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3" + ): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -4131,8 +4133,8 @@ def flash_attention_padding_matches_padding_free_with_position_ids( for model_class in self.all_generative_model_classes: if not ( model_class._supports_flash_attn - #if attn_implementation == "flash_attention_2" - #else model_class._supports_flash_attn_3 + # if attn_implementation == "flash_attention_2" + # else model_class._supports_flash_attn_3 ): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") @@ -4254,7 +4256,9 @@ def flash_attn_from_config(self, attn_implementation: str): self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn and (attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"): + if not model_class._supports_flash_attn and ( + attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3" + ): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 6a2996a031306c556426d0b467745a21a770e4b0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 18 Jul 2025 08:45:15 +0200 Subject: [PATCH 17/38] use paged attention if requested! --- src/transformers/integrations/flash_paged.py | 3 +++ src/transformers/modeling_utils.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index b0463d952487..c304b1ff71fe 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -20,6 +20,7 @@ def paged_attention_forward( max_seqlen_q=None, max_seqlen_k=None, block_tables=None, + implementation=None, **kwargs, ) -> torch.Tensor: r"""Perform the forward pass of attention with paged key-value cache. @@ -46,6 +47,8 @@ def paged_attention_forward( """ k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + if implementation is not None: + flash_attn_varlen_func = implementation.flash_attn_varlen_func attn_output = flash_attn_varlen_func( q.transpose(1, 2).squeeze(0), k.transpose(1, 2).squeeze(0), diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 13e29fc73f66..4a1ee23d88fd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2291,6 +2291,8 @@ def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Un kernel = get_kernel(repo_id) if "flash_attention" in kernel_name: ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) + elif "paged_atention" in kernel_name: + ALL_ATTENTION_FUNCTIONS[repo_id] = partial(paged_attention_forward, implementation=kernel) else: ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ From a5862941ed4c65fb90264720188db7d37a619e4b Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 18 Jul 2025 11:07:31 +0200 Subject: [PATCH 18/38] updates --- examples/pytorch/continuous_batching.py | 7 +++++-- src/transformers/integrations/flash_paged.py | 6 +++--- src/transformers/modeling_utils.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 9aaa836f7bae..1cba7cf783eb 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -11,7 +11,10 @@ model_id = "meta-llama/Llama-3.2-3b-Instruct" model = AutoModelForCausalLM.from_pretrained( - model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" + model_id, + attn_implementation="kernels-community/metal-flash-sdpa:paged_attention", + torch_dtype=torch.bfloat16, + device_map="auto", ).eval() tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") @@ -28,7 +31,7 @@ ) train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") - +train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version # --- Example 1: Simple Version using generate_batch --- print("--- Running CB Generation Example ---") diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index c304b1ff71fe..90b1393b9dcd 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -50,9 +50,9 @@ def paged_attention_forward( if implementation is not None: flash_attn_varlen_func = implementation.flash_attn_varlen_func attn_output = flash_attn_varlen_func( - q.transpose(1, 2).squeeze(0), - k.transpose(1, 2).squeeze(0), - v.transpose(1, 2).squeeze(0), + q.transpose(1, 2).squeeze(0).contiguous(), + k.transpose(1, 2).squeeze(0).contiguous(), + v.transpose(1, 2).squeeze(0).contiguous(), cumulative_seqlens_q.to(torch.int32), cumulative_seqlens_k.to(torch.int32), max_seqlen_q, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 670e094c21bd..4c7c228ec8d9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2291,7 +2291,7 @@ def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Un kernel = get_kernel(repo_id) if "flash_attention" in kernel_name: ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) - elif "paged_atention" in kernel_name: + elif "paged_attention" in kernel_name: ALL_ATTENTION_FUNCTIONS[repo_id] = partial(paged_attention_forward, implementation=kernel) else: ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) From 57842f56ea39a1409a12117dbf25b1b91d43346f Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 18 Jul 2025 11:32:01 +0200 Subject: [PATCH 19/38] a clone was needed, not sure why --- src/transformers/integrations/flash_paged.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 90b1393b9dcd..32621f273230 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -54,7 +54,7 @@ def paged_attention_forward( k.transpose(1, 2).squeeze(0).contiguous(), v.transpose(1, 2).squeeze(0).contiguous(), cumulative_seqlens_q.to(torch.int32), - cumulative_seqlens_k.to(torch.int32), + cumulative_seqlens_k.to(torch.int32).clone(), max_seqlen_q, max_seqlen_k, softmax_scale=module.scaling, From 43b7f322e58815e5f1dc3eea6960e1a31612886e Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 18 Jul 2025 13:20:36 +0200 Subject: [PATCH 20/38] automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute --- src/transformers/generation/utils.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 809db679d35c..411d81b54287 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -677,6 +677,35 @@ def prepare_inputs_for_generation( if encoder_attention_mask is not None: model_inputs["attention_mask"] = encoder_attention_mask + if "flash" in self.config._attn_implementation: + tensor_kws = {"dtype": torch.int32, "device": self.device} + cu_seq_lens_q: Optional[torch.LongTensor] = None + cu_seq_lens_k: Optional[torch.LongTensor] = None + max_length_q: Optional[int] = None + max_length_k: Optional[int] = None + + position_ids = model_inputs.get("position_ids") + if position_ids is not None: + last_pos = position_ids[:, -1] + cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), last_pos.cumsum(dim=0).add(1)], dim=0) + max_length_k = last_pos.max().item() + 1 + flat_q = position_ids.flatten() + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), torch.tensor([flat_q.numel()], **tensor_kws)]) + max_length_q = flat_q.numel() + elif attention_mask is not None: + seqlens = attention_mask.sum(dim=-1, **tensor_kws) + cu_seq_lens_k = torch.nn.functional.pad(seqlens.cumsum(dim=0), (1, 0)) + max_length_k = seqlens.max().item() + q_len = cache_position.size(0) + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), torch.tensor(q_len, **tensor_kws)]) + max_length_q = q_len + + model_inputs.update( + cu_seq_lens_q=cu_seq_lens_q, + cu_seq_lens_k=cu_seq_lens_k, + max_length_q=max_length_q, + max_length_k=max_length_k, + ) # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: From 12bad1b3f2fb556f1608bba1efd9528d59a892e4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 18 Jul 2025 15:19:35 +0200 Subject: [PATCH 21/38] simplify and improve? --- src/transformers/generation/utils.py | 34 +++++++------------ .../modeling_flash_attention_utils.py | 6 ++-- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 411d81b54287..c4398e6cb17c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -619,6 +619,7 @@ def prepare_inputs_for_generation( model_input = model_input.clone(memory_format=torch.contiguous_format) model_inputs[model_input_name] = model_input + q_lengths = attention_mask.sum(dim=-1) # 6. Create 4D attention mask is we are using a compilable cache (important for performant compiled forward # pass) if ( @@ -679,30 +680,19 @@ def prepare_inputs_for_generation( if "flash" in self.config._attn_implementation: tensor_kws = {"dtype": torch.int32, "device": self.device} - cu_seq_lens_q: Optional[torch.LongTensor] = None - cu_seq_lens_k: Optional[torch.LongTensor] = None - max_length_q: Optional[int] = None - max_length_k: Optional[int] = None - - position_ids = model_inputs.get("position_ids") - if position_ids is not None: - last_pos = position_ids[:, -1] - cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), last_pos.cumsum(dim=0).add(1)], dim=0) - max_length_k = last_pos.max().item() + 1 - flat_q = position_ids.flatten() - cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), torch.tensor([flat_q.numel()], **tensor_kws)]) - max_length_q = flat_q.numel() - elif attention_mask is not None: - seqlens = attention_mask.sum(dim=-1, **tensor_kws) - cu_seq_lens_k = torch.nn.functional.pad(seqlens.cumsum(dim=0), (1, 0)) - max_length_k = seqlens.max().item() - q_len = cache_position.size(0) - cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), torch.tensor(q_len, **tensor_kws)]) - max_length_q = q_len + pos = model_inputs["position_ids"][:, -1] + + cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0) + max_length_k = int(pos.max()) + 1 + + bs, seq_len = input_ids.size() + q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1) + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0) + max_length_q = int(q_len.max()) model_inputs.update( - cu_seq_lens_q=cu_seq_lens_q, - cu_seq_lens_k=cu_seq_lens_k, + cu_seq_lens_q=cu_seq_lens_q.to(self.device), + cu_seq_lens_k=cu_seq_lens_k.to(self.device), max_length_q=max_length_q, max_length_k=max_length_k, ) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index c8d399f08a6e..74be7ece2612 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -379,9 +379,9 @@ def _flash_attention_forward( query_states, key_states, value_states, position_ids ) else: - q = query_states.view(-1, query_states.size(-2), query_states.size(-1)) - k = key_states.view(-1, key_states.size(-2), key_states.size(-1)) - v = value_states.view(-1, value_states.size(-2), value_states.size(-1)) + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) mq, mk = max_length_q, max_length_k cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k out = flash_varlen_fn( From c0b600a54ab29c82c3639af4c1e3bac6e6792d79 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 21 Jul 2025 15:23:37 +0200 Subject: [PATCH 22/38] flash attention is kinda broken on recent cuda version so allow the opportunity to use something else --- .../modeling_flash_attention_utils.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 74be7ece2612..6884a34cbc21 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -16,9 +16,12 @@ import warnings from typing import Optional, TypedDict +from kernels import get_kernel import torch import torch.nn.functional as F +from transformers.utils.import_utils import is_kernels_available + from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, @@ -265,10 +268,35 @@ def _lazy_imports(impl: Optional[str]): is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() is_fa3 = is_flash_attn_3_available() if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import pad_input, unpad_input + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + + except ImportError as e: + print( + "Official flash attention import did not work, do you want to try to use `kernels-community/flash-attn` (trust remote code)?" + ) + if input("Yes / No") == "Yes": + if not is_kernels_available(): + raise ImportError("You need to install kernels: `pip install kernels`") + from kernels import get_kernel + + impl = get_kernel("kernels-community/flash-attn") + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + else: + raise ImportError( + "Failed to import flash attention 2, please install it or use another implementation." + ) from e if impl == "flash_attention_3" or (impl is None and is_fa3): from flash_attn_interface import flash_attn_func, flash_attn_varlen_func From 11e500010485c631146eebd3e1ec335118802977 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Jul 2025 15:30:07 +0000 Subject: [PATCH 23/38] fix! --- .../modeling_flash_attention_utils.py | 29 +++++++++++++++---- src/transformers/modeling_utils.py | 2 +- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 6884a34cbc21..cabff79b7750 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -275,10 +275,16 @@ def _lazy_imports(impl: Optional[str]): return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False except ImportError as e: - print( - "Official flash attention import did not work, do you want to try to use `kernels-community/flash-attn` (trust remote code)?" - ) - if input("Yes / No") == "Yes": + if not globals().get("use_remote_fa2", None): + use_remote_fa2 = ( + input( + "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " + ) + .strip() + .lower() + ) + globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} + if globals()["use_remote_fa2"]: if not is_kernels_available(): raise ImportError("You need to install kernels: `pip install kernels`") from kernels import get_kernel @@ -358,7 +364,20 @@ def _flash_attention_forward( implementation: Optional[str] = None, **kwargs, ): - flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): + flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + globals()["_flash_fn"] = flash_fn + globals()["_flash_varlen_fn"] = flash_varlen_fn + globals()["_pad_fn"] = pad_fn + globals()["_unpad_fn"] = unpad_fn + globals()["_is_fa3"] = is_fa3 + else: + flash_fn = globals()["_flash_fn"] + flash_varlen_fn = globals()["_flash_varlen_fn"] + pad_fn = globals()["_pad_fn"] + unpad_fn = globals()["_unpad_fn"] + is_fa3 = globals()["_is_fa3"] + causal = is_causal and not (use_top_left_mask and query_length == 1) use_sw = ( (_flash_supports_window or "window_size" in inspect.signature(flash_varlen_fn).parameters) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 874817ac0d72..18b463f848a6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2618,7 +2618,7 @@ def _check_and_adjust_attn_implementation( ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ "flash_attention_2" ] - applicable_attn_implementation = repo_id" + applicable_attn_implementation = repo_id except FileNotFoundError as e: logger.warning_once( f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " From 1c073509927a6824822fa093f69c6f56b24ffa82 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 21 Jul 2025 15:39:15 +0000 Subject: [PATCH 24/38] protect kernels import --- src/transformers/generation/utils.py | 1 - src/transformers/modeling_flash_attention_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a4685574fafc..3c85b92d40ca 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -619,7 +619,6 @@ def prepare_inputs_for_generation( model_input = model_input.clone(memory_format=torch.contiguous_format) model_inputs[model_input_name] = model_input - q_lengths = attention_mask.sum(dim=-1) # 6. Create 4D attention mask is we are using a compilable cache (important for performant compiled forward # pass) if ( diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index cabff79b7750..c2220737c837 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -16,7 +16,6 @@ import warnings from typing import Optional, TypedDict -from kernels import get_kernel import torch import torch.nn.functional as F From cdaa1eb6a8ae15491f4c70533a18ae18709712f1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 13:49:07 +0200 Subject: [PATCH 25/38] update --- .../generation/continuous_batching.py | 3 ++- src/transformers/generation/logits_process.py | 21 ++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 09ee1fe8ce1d..e462e483c24e 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -1119,7 +1119,8 @@ def __init__( self._request_lock = threading.Lock() self.model.generation_config.top_p = None self.do_sample = getattr(generation_config, "do_sample", True) - self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + generation_config = model.generation_config if generation_config is None else generation_config + self.logit_processor = self.model._get_logits_processor(generation_config) self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) self.profile = getattr(generation_config, "profile", False) self.manual_eviction = manual_eviction diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d4c08e270bbc..92dabcb9529f 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -361,13 +361,24 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if self.prompt_ignore_length: input_ids = input_ids[:, self.prompt_ignore_length :] - score = torch.gather(scores, 1, input_ids) + # Select the last generated token for each sequence + last_token_ids = input_ids[:, -1] # shape: (batch_size,) + + # Gather scores for those tokens + batch_indices = torch.arange(scores.size(0), device=scores.device) + token_scores = scores[batch_indices, last_token_ids] + + # Apply penalty + adjusted_scores = torch.where( + token_scores < 0, + token_scores * self.penalty, + token_scores / self.penalty, + ) - # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities - score = torch.where(score < 0, score * self.penalty, score / self.penalty) + # Write back into scores tensor + scores[batch_indices, last_token_ids] = adjusted_scores - scores_processed = scores.scatter(1, input_ids, score) - return scores_processed + return scores class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): From 767d5852c67b3afdac8f563d1d711b392a393524 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 13:58:46 +0200 Subject: [PATCH 26/38] properly parse generation config being passed --- src/transformers/generation/continuous_batching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 09ee1fe8ce1d..f10c5ac7c0d5 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -1119,7 +1119,8 @@ def __init__( self._request_lock = threading.Lock() self.model.generation_config.top_p = None self.do_sample = getattr(generation_config, "do_sample", True) - self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + generation_config = generation_config if generation_config else model.generation_config + self.logit_processor = self.model._get_logits_processor(generation_config) self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) self.profile = getattr(generation_config, "profile", False) self.manual_eviction = manual_eviction From c75c5398b1af3a0cd621bb9b65f86f6dd8faec0a Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:05:28 +0200 Subject: [PATCH 27/38] revert and update --- src/transformers/generation/logits_process.py | 21 +++++-------------- src/transformers/modeling_utils.py | 4 +--- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 92dabcb9529f..d4c08e270bbc 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -361,24 +361,13 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if self.prompt_ignore_length: input_ids = input_ids[:, self.prompt_ignore_length :] - # Select the last generated token for each sequence - last_token_ids = input_ids[:, -1] # shape: (batch_size,) - - # Gather scores for those tokens - batch_indices = torch.arange(scores.size(0), device=scores.device) - token_scores = scores[batch_indices, last_token_ids] - - # Apply penalty - adjusted_scores = torch.where( - token_scores < 0, - token_scores * self.penalty, - token_scores / self.penalty, - ) + score = torch.gather(scores, 1, input_ids) - # Write back into scores tensor - scores[batch_indices, last_token_ids] = adjusted_scores + # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities + score = torch.where(score < 0, score * self.penalty, score / self.penalty) - return scores + scores_processed = scores.scatter(1, input_ids, score) + return scores_processed class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 18b463f848a6..65b5ba7ed41a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2609,10 +2609,8 @@ def _check_and_adjust_attn_implementation( repo_id = repo_id.strip() try: kernel = get_kernel(repo_id) - if "flash_attention" in kernel_name: + if hasattr("flash_attn_varlen", kernel): ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) - elif "paged_attention" in kernel_name: - ALL_ATTENTION_FUNCTIONS[repo_id] = partial(paged_attention_forward, implementation=kernel) else: ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ From a2f3126e7c2e9b458d6932034f1aa218df195de0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:17:02 +0200 Subject: [PATCH 28/38] add two tests --- src/transformers/testing_utils.py | 16 ++++++++++++++++ src/transformers/utils/auto_docstring.py | 3 ++- tests/test_modeling_common.py | 21 ++++++++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1df380b6fd70..a6307ae87893 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -86,6 +86,7 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_kernels_available, is_flash_attn_3_available, is_flax_available, is_flute_available, @@ -586,6 +587,16 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_kernels(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case) + + def require_flash_attn_3(test_case): """ Decorator marking a test that requires Flash Attention 3. @@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_mps(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "mps", "test requires CUDA")(test_case) + + def require_large_cpu_ram(test_case, memory: float = 80): """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory.""" if not is_psutil_available(): diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 11eb382bda99..ab5a393a01bd 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1170,7 +1170,8 @@ def format_args_docstring(docstring, model_name): placeholders_dict = get_placeholders_dict(placeholders, model_name) # replace the placeholders in the docstring with the values from the placeholders_dict for placeholder, value in placeholders_dict.items(): - docstring = docstring.replace(f"{{{placeholder}}}", value) + if placeholder is not None: + docstring = docstring.replace(f"{{{placeholder}}}", value) return docstring diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 778bd9bd4aae..431b4c00ce44 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -85,6 +85,8 @@ require_bitsandbytes, require_deepspeed, require_flash_attn, + require_kernels, + require_torch_mps, require_flash_attn_3, require_non_hpu, require_safetensors, @@ -3563,6 +3565,24 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid else: assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_kernels + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernel_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left") + + @require_flash_attn + @require_torch_mps + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernel_mps_inference_equivalence(self): + self.flash_attn_inference_equivalence( + attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left" + ) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -4255,7 +4275,6 @@ def flash_attn_from_config(self, attn_implementation: str): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn: - self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 85829d7d955b1454fe3d35e36e5eb1ceb604b30b Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:31:12 +0200 Subject: [PATCH 29/38] some fixes --- src/transformers/modeling_utils.py | 20 ++++++++++++++------ src/transformers/testing_utils.py | 2 +- src/transformers/utils/auto_docstring.py | 18 ++++++++++++------ tests/test_modeling_common.py | 6 +++--- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a52b254037b5..f4fd894b320d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2786,19 +2786,25 @@ def _check_and_adjust_attn_implementation( None to sdpa (to potentially eager). """ applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation - if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation): + if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation): if not is_kernels_available(): raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") # Extract repo_id and kernel_name from the string - repo_id, kernel_name = attn_implementation.split(":") - kernel_name = kernel_name.strip() + if ":" in applicable_attn_implementation: + repo_id, kernel_name = attn_implementation.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = attn_implementation + kernel_name = None repo_id = repo_id.strip() try: kernel = get_kernel(repo_id) - if hasattr("flash_attn_varlen", kernel): - ALL_ATTENTION_FUNCTIONS[repo_id] = partial(flash_attention_forward, implementation=kernel) - else: + if hasattr(kernel, "flash_attn_varlen_func"): + ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial( + flash_attention_forward, implementation=kernel + ) + elif kernel_name is not None: ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ "flash_attention_2" @@ -2810,6 +2816,8 @@ def _check_and_adjust_attn_implementation( "default attention implementation instead (sdpa if available, eager otherwise)." ) applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case + finally: + return applicable_attn_implementation if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): message = ( f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are ' diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a6307ae87893..6337f2103d4f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1116,7 +1116,7 @@ def require_torch_gpu(test_case): def require_torch_mps(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" - return unittest.skipUnless(torch_device == "mps", "test requires CUDA")(test_case) + return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case) def require_large_cpu_ram(test_case, memory: float = 80): diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index ab5a393a01bd..4ad75c9b9c97 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict: for placeholder in placeholders: # Infer placeholders from the model name and the auto modules if placeholder in PLACEHOLDER_TO_AUTO_MODULE: - place_holder_value = getattr( - getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), - PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], - ).get(model_name, None) + try: + place_holder_value = getattr( + getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), + PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], + ).get(model_name, None) + except ImportError: + # In case a library is not installed, we don't want to fail the docstring generation + place_holder_value = None if place_holder_value is not None: if isinstance(place_holder_value, (list, tuple)): place_holder_value = place_holder_value[0] @@ -1171,8 +1175,10 @@ def format_args_docstring(docstring, model_name): # replace the placeholders in the docstring with the values from the placeholders_dict for placeholder, value in placeholders_dict.items(): if placeholder is not None: - docstring = docstring.replace(f"{{{placeholder}}}", value) - + try: + docstring = docstring.replace(f"{{{placeholder}}}", value) + except Exception as e: + pass return docstring diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 86c5f671e403..7649d592de13 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3570,15 +3570,15 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid @mark.flash_attn_test @slow @is_flaky() - def test_flash_attn_kernel_inference_equivalence(self): + def test_flash_attn_kernels_inference_equivalence(self): self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left") - @require_flash_attn @require_torch_mps + @require_kernels @mark.flash_attn_test @slow @is_flaky() - def test_flash_attn_kernel_mps_inference_equivalence(self): + def test_flash_attn_kernels_mps_inference_equivalence(self): self.flash_attn_inference_equivalence( attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left" ) From 56981a55d8c973bba158911fe969601b5027e667 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:40:11 +0200 Subject: [PATCH 30/38] fix test FA2 --- tests/test_modeling_common.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7649d592de13..4c11aabb342b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3476,18 +3476,12 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.head_dim = 64 + print(config) model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation - ) - model_fa.to(torch_device) - - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) - dummy_input = inputs_dict[model.main_input_name][:1] if dummy_input.dtype in [torch.float32, torch.float16]: dummy_input = dummy_input.to(torch.bfloat16) @@ -3506,11 +3500,14 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) else: outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, output_hidden_states=True) + model.set_attn_implementation("sdpa") logits = ( outputs.hidden_states[-1] if not model.config.is_encoder_decoder @@ -3534,7 +3531,8 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid other_inputs["attention_mask"] = dummy_attention_mask outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) else: other_inputs = { "output_hidden_states": True, @@ -3543,8 +3541,10 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid other_inputs["attention_mask"] = dummy_attention_mask outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + model.set_attn_implementation("sdpa") logits = ( outputs.hidden_states[-1] if not model.config.is_encoder_decoder @@ -3561,7 +3561,8 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid # check with inference + dropout model.train() - _ = model_fa(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + _ = model(dummy_input, **other_inputs) else: assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) From b3f7a49ea3bb332d98a26b899b557fc0e381c82a Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:47:58 +0200 Subject: [PATCH 31/38] takes comment into account --- src/transformers/generation/utils.py | 2 +- src/transformers/modeling_flash_attention_utils.py | 11 +++++++---- tests/test_modeling_common.py | 3 +-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c468e3d72248..3bffb5fdda91 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -677,7 +677,7 @@ def prepare_inputs_for_generation( if encoder_attention_mask is not None: model_inputs["attention_mask"] = encoder_attention_mask - if "flash" in self.config._attn_implementation: + if "flash" in self.config._attn_implementation and self._supports_attention_backend: tensor_kws = {"dtype": torch.int32, "device": self.device} pos = model_inputs["position_ids"][:, -1] diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index c2220737c837..51734ca2a7d0 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -249,7 +249,7 @@ def _prepare_from_posids(query, key, value, position_ids): def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): warnings.warn( - "prepare_fa2_from_position_ids is deprecated, use _prepare_flash_attention_from_position_ids", + "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids", FutureWarning, ) return _prepare_from_posids(query, key, value, position_ids) @@ -370,16 +370,19 @@ def _flash_attention_forward( globals()["_pad_fn"] = pad_fn globals()["_unpad_fn"] = unpad_fn globals()["_is_fa3"] = is_fa3 + flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters) + globals()["_flash_supports_window"] = flash_supports_window else: flash_fn = globals()["_flash_fn"] flash_varlen_fn = globals()["_flash_varlen_fn"] pad_fn = globals()["_pad_fn"] unpad_fn = globals()["_unpad_fn"] is_fa3 = globals()["_is_fa3"] + flash_supports_window = globals()["_flash_supports_window"] causal = is_causal and not (use_top_left_mask and query_length == 1) use_sw = ( - (_flash_supports_window or "window_size" in inspect.signature(flash_varlen_fn).parameters) + (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window ) @@ -395,7 +398,7 @@ def _flash_attention_forward( query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) - use_varlen = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]) + use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]) if attention_mask is not None: q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( query_states, key_states, value_states, attention_mask, query_length, unpad_fn @@ -415,7 +418,7 @@ def _flash_attention_forward( if isinstance(out_unpad, tuple): out_unpad = out_unpad[0] out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) - elif use_varlen: + elif use_mask: if cu_seq_lens_q is None or cu_seq_lens_k is None: if position_ids is None: raise ValueError( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4c11aabb342b..c4b8bc4c4041 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3476,8 +3476,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.head_dim = 64 - print(config) + config.head_dim = 64 # fa2 does not always support arbitrary headim model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: From 21e07f77b8a4255922f1687d4781e111f24b1a56 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:49:27 +0200 Subject: [PATCH 32/38] fixup --- src/transformers/integrations/flash_paged.py | 2 +- .../modeling_flash_attention_utils.py | 6 +- src/transformers/testing_utils.py | 2 +- src/transformers/utils/auto_docstring.py | 2 +- tests/test_modeling_common.py | 153 +++++++++--------- 5 files changed, 79 insertions(+), 86 deletions(-) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 32621f273230..52d7d1c4503f 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -5,7 +5,7 @@ if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func + pass def paged_attention_forward( diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 51734ca2a7d0..60a19a8df65c 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -370,7 +370,7 @@ def _flash_attention_forward( globals()["_pad_fn"] = pad_fn globals()["_unpad_fn"] = unpad_fn globals()["_is_fa3"] = is_fa3 - flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters) + flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters globals()["_flash_supports_window"] = flash_supports_window else: flash_fn = globals()["_flash_fn"] @@ -382,9 +382,7 @@ def _flash_attention_forward( causal = is_causal and not (use_top_left_mask and query_length == 1) use_sw = ( - (_flash_supports_window or flash_supports_window) - and sliding_window - and key_states.shape[1] > sliding_window + (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} if not is_fa3: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 6337f2103d4f..0e117d71f712 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -86,7 +86,6 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, - is_kernels_available, is_flash_attn_3_available, is_flax_available, is_flute_available, @@ -105,6 +104,7 @@ is_jinja_available, is_jumanpp_available, is_keras_nlp_available, + is_kernels_available, is_levenshtein_available, is_librosa_available, is_liger_kernel_available, diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 4ad75c9b9c97..f277df1af17e 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1177,7 +1177,7 @@ def format_args_docstring(docstring, model_name): if placeholder is not None: try: docstring = docstring.replace(f"{{{placeholder}}}", value) - except Exception as e: + except Exception: pass return docstring diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c4b8bc4c4041..9c4c0da4ee19 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -85,15 +85,15 @@ require_bitsandbytes, require_deepspeed, require_flash_attn, - require_kernels, - require_torch_mps, require_flash_attn_3, + require_kernels, require_non_hpu, require_safetensors, require_torch, require_torch_accelerator, require_torch_gpu, require_torch_greater_or_equal, + require_torch_mps, require_torch_multi_accelerator, require_torch_multi_gpu, require_torch_sdpa, @@ -3479,91 +3479,86 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid config.head_dim = 64 # fa2 does not always support arbitrary headim model = model_class(config) - with tempfile.TemporaryDirectory() as tmpdirname: - model.to(torch_device) - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) + model.to(torch_device) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) - dummy_attention_mask = inputs_dict.get("attention_mask", None) + dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - if padding_side == "left": - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - else: - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - if model.config.is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + if padding_side == "left": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 else: - outputs = model(dummy_input, output_hidden_states=True) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, output_hidden_states=True) - - model.set_attn_implementation("sdpa") - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, output_hidden_states=True) - if model.config.is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - outputs = model(dummy_input, **other_inputs) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs = model(dummy_input, **other_inputs) - model.set_attn_implementation(attn_implementation) - outputs_fa = model(dummy_input, **other_inputs) - - model.set_attn_implementation("sdpa") - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - if padding_side == "left": - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask - # check with inference + dropout - model.train() - model.set_attn_implementation(attn_implementation) - _ = model(dummy_input, **other_inputs) - else: - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + model.set_attn_implementation(attn_implementation) + _ = model(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) @require_kernels @require_torch_gpu From a8b7ec6506c2d37b263bf1d060ab8665919a43c9 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:50:45 +0200 Subject: [PATCH 33/38] revert changes --- examples/pytorch/continuous_batching.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 1cba7cf783eb..9aaa836f7bae 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -11,10 +11,7 @@ model_id = "meta-llama/Llama-3.2-3b-Instruct" model = AutoModelForCausalLM.from_pretrained( - model_id, - attn_implementation="kernels-community/metal-flash-sdpa:paged_attention", - torch_dtype=torch.bfloat16, - device_map="auto", + model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" ).eval() tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") @@ -31,7 +28,7 @@ ) train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") -train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version + # --- Example 1: Simple Version using generate_batch --- print("--- Running CB Generation Example ---") From f111d33850744fa76a1e230e16d57d90d8621c7f Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 14:52:18 +0200 Subject: [PATCH 34/38] revert the clone, it is only needed because the metal kernel is not doing it? --- src/transformers/modeling_flash_attention_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 60a19a8df65c..a8af5d0b45e4 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -406,7 +406,7 @@ def _flash_attention_forward( k, v, cu_seqlens_q=cu_q.to(torch.int32), - cu_seqlens_k=cu_k.clone().to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, @@ -436,7 +436,7 @@ def _flash_attention_forward( k, v, cu_seqlens_q=cu_q.to(torch.int32), - cu_seqlens_k=cu_k.clone().to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), max_seqlen_q=mq, max_seqlen_k=mk, softmax_scale=softmax_scale, From cd98c1fee3167c484944580516edea03fb03f784 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 22 Jul 2025 15:06:43 +0200 Subject: [PATCH 35/38] [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/attention_interface.md | 28 +++++++++++++++++++++++++ docs/source/en/cache_explanation.md | 30 ++++++++++++++++++++++++++- docs/source/en/llm_optims.md | 12 +++++++++-- docs/source/en/perf_infer_gpu_one.md | 8 ++++++- 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index 034686ad2c8a..407a47a7d353 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -72,6 +72,34 @@ model(torch.ones(1, 5, dtype=int)) and it will stop printing the statements, as it now uses the `sdpa` attention. This allows to quickly change an attention function, without needing to reload the model! +## Different attention per backbone in multimodal models + +For multimodal models different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, but are incompatible with FlashAttention. To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below. + +```python +from transformers import AutoModelForImageTextToText + +model_id = "facebook/chameleon-7b" + +attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"} +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone) + +# NOTE: keys in the attention implementation have to be the same as the sub-config names +for key in attention_implementation_per_backbone: + assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`" + +# You can omit certain backbones - the default attention function (SDPA) will be used +# This is equivalent to the previous example +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2"}) + + +# Set the same attention implementation for all backbones with single string, same as in non-multimodal models +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager") + +# Alternatively use a dict with an empty key for global configuration +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"": "eager"}) +``` + ## What about new args needed in my custom attention function? But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 13f310669200..17d35c33ab4c 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -132,6 +132,34 @@ for _ in range(max_new_tokens): print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` + +## Cache position + +The cache position tracks where to insert new tokens in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose you already cached `N` tokens and are now processing `K` new tokens. The cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions - `[N, N + 1, N + 2, ..., N + K - 1]`. + +Cache position is used internally for two purposes: + +1. Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model's `forward`. +2. Storing key/value pairs at the correct positions in the cache. This is especially important for fixed-size caches, like [`StaticCache`], that pre-allocates a specific cache length. + +The generation loop usually takes care of the cache position, but if you're writing a custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots. + + +```py +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + +model_id = "meta-llama/Llama-2-7b-chat-hf" +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0") +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [{"role": "user", "content": "You are a helpful assistant."}] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0") +generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10) + +``` + + ## Legacy cache format Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`]. @@ -157,4 +185,4 @@ generation_outputs = model.generate(**inputs, return_dict_in_generate=True, retu cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values) legacy_format_cache = cache.to_legacy_cache() -``` \ No newline at end of file +``` diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index e8e20dab5db6..0295a5bf1b34 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -341,7 +341,7 @@ A known issue with transformer models is that the self-attention mechanism grows FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to the GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead. -To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. +To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or set with `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface) after the model is loaded. ```py from transformers import AutoModelForCausalLM, BitsAndBytesConfig @@ -353,6 +353,14 @@ model = AutoModelForCausalLM.from_pretrained( torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) + +# Change the model's attention dynamically after loading +model = AutoModelForCausalLM.from_pretrained( + "google/gemma-2b", + quantization_config=quant_config, + torch_dtype=torch.bfloat16 +) +model.set_attention_implementation("flash_attention_2") ``` ### PyTorch scaled dot product attention @@ -360,7 +368,7 @@ model = AutoModelForCausalLM.from_pretrained( Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. > [!TIP] -> SDPA automaticallysupports FlashAttention-2 as long as you have the latest PyTorch version installed. +> SDPA automatically supports FlashAttention-2 as long as you have the latest PyTorch version installed. Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index c3a7ddc8d8af..fa726e1f98b4 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -177,10 +177,16 @@ There are three supported implementations available. SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. +Refer to the [AttentionInterface](./attention_interface) guide to learn how to change the attention implementation after loading a model. + ```py from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa") + +# Change the model's attention dynamically after loading it +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto") +model.set_attention_implementation("sdpa") ``` SDPA selects the most performant implementation available, but you can also explicitly select an implementation with [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager. The example below shows how to enable the FlashAttention2 implementation with `enable_flash=True`. @@ -234,7 +240,7 @@ FlashAttention2 support is currently limited to Instinct MI210, Instinct MI250 a -Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. +Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or by setting `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface). FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. ```py from transformers import AutoModelForCausalLM From f457a085d99e07f8e25b4da3435d6ed45f24850a Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 15:11:10 +0200 Subject: [PATCH 36/38] fix mps on our side for now --- src/transformers/modeling_flash_attention_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index a8af5d0b45e4..848c2a214113 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -401,6 +401,9 @@ def _flash_attention_forward( q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) + # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p + if "mps" in str(q.device): + cu_k = cu_k.clone() out_unpad = flash_varlen_fn( q, k, @@ -431,6 +434,8 @@ def _flash_attention_forward( v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) mq, mk = max_length_q, max_length_k cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + if "mps" in str(q.device): + cu_k = cu_k.clone() out = flash_varlen_fn( q, k, From 38d241b48bf53f0b203f7cdea2f6251d02a98328 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 22 Jul 2025 15:13:21 +0200 Subject: [PATCH 37/38] Update src/transformers/integrations/flash_paged.py --- src/transformers/integrations/flash_paged.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 52d7d1c4503f..32621f273230 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -5,7 +5,7 @@ if is_flash_attn_2_available(): - pass + from flash_attn import flash_attn_varlen_func def paged_attention_forward( From c0f4f0998ccb4d83d592afeeaca808d1815a41c8 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 22 Jul 2025 15:25:53 +0200 Subject: [PATCH 38/38] no qa --- src/transformers/integrations/flash_paged.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 32621f273230..236e216b3ff2 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -5,7 +5,7 @@ if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_varlen_func # noqa: F401 def paged_attention_forward(