Merged
Changes from all commits

46 commits
cd4c7cb
use partial to wrap around `transformers` utils!
ArthurZucker Jul 17, 2025
005f482
try to refactor?
ArthurZucker Jul 17, 2025
1b834a4
revert one wrong change
ArthurZucker Jul 17, 2025
d93f366
just a nit
ArthurZucker Jul 17, 2025
2b7d411
push
ArthurZucker Jul 17, 2025
affba20
reverter watever was wrong!
ArthurZucker Jul 17, 2025
1959eb2
some nits
ArthurZucker Jul 17, 2025
888cd40
fixes when there is no attention mask
ArthurZucker Jul 17, 2025
8f5e62b
Merge branch 'main' of github.com:huggingface/transformers into kerne…
ArthurZucker Jul 17, 2025
5a7ae11
bring the licence back
ArthurZucker Jul 17, 2025
c57673b
some fixes
ArthurZucker Jul 17, 2025
7d69d83
nit
ArthurZucker Jul 17, 2025
7e94910
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 17, 2025
112e2a6
style
ArthurZucker Jul 17, 2025
501aa7e
remove prints
ArthurZucker Jul 17, 2025
04088be
correct dtype
ArthurZucker Jul 17, 2025
b1e104b
fa flags for testing
vasqu Jul 17, 2025
7087e7b
update
ArthurZucker Jul 17, 2025
cc58aca
Merge branch 'main' into kernels-flash-attn
ArthurZucker Jul 17, 2025
6a2996a
use paged attention if requested!
ArthurZucker Jul 18, 2025
8ddc525
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 18, 2025
a586294
updates
ArthurZucker Jul 18, 2025
57842f5
a clone was needed, not sure why
ArthurZucker Jul 18, 2025
43b7f32
automatically create cu seq lens when input is flash, this at least m…
ArthurZucker Jul 18, 2025
12bad1b
simplify and improve?
ArthurZucker Jul 18, 2025
c0b600a
flash attention is kinda broken on recent cuda version so allow the o…
ArthurZucker Jul 21, 2025
5c64874
Merge branch 'main' into kernels-flash-attn
ArthurZucker Jul 21, 2025
11e5000
fix!
ArthurZucker Jul 21, 2025
1c07350
protect kernels import
ArthurZucker Jul 21, 2025
cdaa1eb
update
ArthurZucker Jul 22, 2025
767d585
properly parse generation config being passed
ArthurZucker Jul 22, 2025
10f866e
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 22, 2025
c75c539
revert and update
ArthurZucker Jul 22, 2025
a2f3126
add two tests
ArthurZucker Jul 22, 2025
63b01c3
Merge branch 'main' of github.com:huggingface/transformers into kerne…
ArthurZucker Jul 22, 2025
85829d7
some fixes
ArthurZucker Jul 22, 2025
56981a5
fix test FA2
ArthurZucker Jul 22, 2025
b3f7a49
takes comment into account
ArthurZucker Jul 22, 2025
21e07f7
fixup
ArthurZucker Jul 22, 2025
a8b7ec6
revert changes
ArthurZucker Jul 22, 2025
f111d33
revert the clone, it is only needed because the metal kernel is not d…
ArthurZucker Jul 22, 2025
cd98c1f
[docs] update attention implementation and cache docs (#39547)
zucchini-nlp Jul 22, 2025
f457a08
fix mps on our side for now
ArthurZucker Jul 22, 2025
38d241b
Update src/transformers/integrations/flash_paged.py
ArthurZucker Jul 22, 2025
cb58187
Merge branches 'main' and 'kernels-flash-attn' of github.com:huggingf…
ArthurZucker Jul 22, 2025
c0f4f09
no qa
ArthurZucker Jul 22, 2025
28 changes: 28 additions & 0 deletions docs/source/en/attention_interface.md
@@ -72,6 +72,34 @@ model(torch.ones(1, 5, dtype=int))
and it will stop printing the statements, as it now uses the `sdpa` attention.
This allows you to quickly switch attention functions without needing to reload the model!

## Different attention per backbone in multimodal models

For multimodal models, different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32 but are incompatible with FlashAttention. To keep using FlashAttention while leaving the vision encoder in fp32, create a dict that maps each sub-config name to an attention implementation, as shown below.

```python
from transformers import AutoModelForImageTextToText

model_id = "facebook/chameleon-7b"

attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone)

# NOTE: keys in the attention implementation have to be the same as the sub-config names
for key in attention_implementation_per_backbone:
    assert key in model.config.sub_configs, f"Invalid key in `attn_implementation`: {key}"

# You can omit certain backbones - the default attention function (SDPA) will be used
# This is equivalent to the previous example
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2"})


# Set the same attention implementation for all backbones with a single string, as in non-multimodal models
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")

# Alternatively use a dict with an empty key for global configuration
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"": "eager"})
```

## What about new args needed in my custom attention function?

But what if the new function requires a new argument to be used properly? It's no issue! Models supporting the
30 changes: 29 additions & 1 deletion docs/source/en/cache_explanation.md
@@ -132,6 +132,34 @@ for _ in range(max_new_tokens):
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA,"
```

## Cache position

The cache position tracks where to insert new tokens in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose you already cached `N` tokens and are now processing `K` new tokens. The cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions `[N, N + 1, N + 2, ..., N + K - 1]`.

Cache position is used internally for two purposes:

1. Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model's `forward`.
2. Storing key/value pairs at the correct positions in the cache. This is especially important for fixed-size caches, like [`StaticCache`], which pre-allocate a specific cache length.

The generation loop usually takes care of the cache position, but if you're writing a custom generation method, it is important that cache positions are accurate since they are used to read and write key/value states at fixed slots.


```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [{"role": "user", "content": "You are a helpful assistant."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```
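
When you drive the loop yourself, pass `cache_position` explicitly so key/value states land in the right slots. A minimal sketch, reusing `model`, `tokenizer`, and `inputs` from the snippet above:

```py
import torch
from transformers import DynamicCache

past_key_values = DynamicCache()
input_ids = inputs["input_ids"]

# Prefill: N prompt tokens occupy cache positions [0, ..., N - 1]
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
outputs = model(input_ids, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)

# Decode step: K = 1 new token goes to absolute position N
next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
cache_position = cache_position[-1:] + 1
outputs = model(next_token, cache_position=cache_position, past_key_values=outputs.past_key_values, use_cache=True)
```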


## Legacy cache format

Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`].
@@ -157,4 +185,4 @@ generation_outputs = model.generate(**inputs, return_dict_in_generate=True, retu

cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()
```
```
12 changes: 10 additions & 2 deletions docs/source/en/llm_optims.md
@@ -341,7 +341,7 @@ A known issue with transformer models is that the self-attention mechanism grows

FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over the sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead.

To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`].
To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`], or call `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface) after the model is loaded.

```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
@@ -353,14 +353,22 @@ model = AutoModelForCausalLM.from_pretrained(
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Change the model's attention dynamically after loading
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
model.set_attention_implementation("flash_attention_2")
```

### PyTorch scaled dot product attention

Scaled dot product attention (SDPA) is natively available in PyTorch 2.0 and later, and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.

> [!TIP]
> SDPA automaticallysupports FlashAttention-2 as long as you have the latest PyTorch version installed.
> SDPA automatically supports FlashAttention-2 as long as you have the latest PyTorch version installed.

Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention.
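
For example, a minimal sketch (assuming a CUDA device, a recent PyTorch, and a model loaded with `attn_implementation="sdpa"`; `inputs` stands in for your tokenized batch) that restricts SDPA to the FlashAttention backend:

```py
from torch.nn.attention import SDPBackend, sdpa_kernel

# Only the FlashAttention backend is allowed inside this context
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs, max_new_tokens=20)
```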

8 changes: 7 additions & 1 deletion docs/source/en/perf_infer_gpu_one.md
@@ -177,10 +177,16 @@ There are three supported implementations available.

SDPA is used by default for PyTorch v2.1.1 and greater when an implementation is available. You can still explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`]. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and return a warning that Transformers will fall back to the (slower) eager implementation.

Refer to the [AttentionInterface](./attention_interface) guide to learn how to change the attention implementation after loading a model.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa")

# Change the model's attention dynamically after loading it
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto")
model.set_attention_implementation("sdpa")
```

SDPA selects the most performant implementation available, but you can also explicitly select an implementation with [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager. The example below shows how to enable the FlashAttention implementation with `SDPBackend.FLASH_ATTENTION`.
@@ -234,7 +240,7 @@ FlashAttention2 support is currently limited to Instinct MI210, Instinct MI250 a
</hfoption>
</hfoptions>

Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first.
Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`], or by calling `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface). FlashAttention2 is only supported for models with the fp16 or bf16 torch dtype. Make sure to cast your model to the appropriate dtype first.

```py
import torch
from transformers import AutoModelForCausalLM

# cast to bf16 (fp32 is unsupported by FlashAttention2); any fp16/bf16-capable checkpoint works
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
3 changes: 2 additions & 1 deletion src/transformers/generation/continuous_batching.py
@@ -1119,7 +1119,8 @@ def __init__(
        self._request_lock = threading.Lock()
        self.model.generation_config.top_p = None
        self.do_sample = getattr(generation_config, "do_sample", True)
        self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
        generation_config = model.generation_config if generation_config is None else generation_config
        self.logit_processor = self.model._get_logits_processor(generation_config)
        self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
        self.profile = getattr(generation_config, "profile", False)
        self.manual_eviction = manual_eviction
18 changes: 18 additions & 0 deletions src/transformers/generation/utils.py
@@ -677,6 +677,24 @@ def prepare_inputs_for_generation(
        if encoder_attention_mask is not None:
            model_inputs["attention_mask"] = encoder_attention_mask

        if "flash" in self.config._attn_implementation and self._supports_attention_backend:
            tensor_kws = {"dtype": torch.int32, "device": self.device}
            pos = model_inputs["position_ids"][:, -1]

            cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
            max_length_k = int(pos.max()) + 1

            bs, seq_len = input_ids.size()
            q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
            cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
            max_length_q = int(q_len.max())

            model_inputs.update(
                cu_seq_lens_q=cu_seq_lens_q.to(self.device),
                cu_seq_lens_k=cu_seq_lens_k.to(self.device),
                max_length_q=max_length_q,
                max_length_k=max_length_k,
            )
        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
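The new block above packs the batch for flash attention's varlen interface. For intuition, cumulative sequence lengths mark where each sequence starts and ends once the batch is flattened; a small illustrative sketch (not part of this diff):

```python
import torch

# Three sequences of lengths 3, 5, and 2 flattened into one varlen call
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0, dtype=torch.int32)])
print(cu_seq_lens)                # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_length = int(seq_lens.max())  # 5, the longest sequence in the batch
```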
2 changes: 1 addition & 1 deletion src/transformers/integrations/flash_attention.py
@@ -38,7 +38,6 @@ def flash_attention_forward(
"FlashAttention does not support inputs with dim=0.\n"
"Please check your input shapes or use SDPA instead."
)

# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -76,6 +75,7 @@ def flash_attention_forward(
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

13 changes: 8 additions & 5 deletions src/transformers/integrations/flash_paged.py
@@ -5,7 +5,7 @@


if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn import flash_attn_varlen_func  # noqa: F401


def paged_attention_forward(
@@ -20,6 +20,7 @@ def paged_attention_forward(
    max_seqlen_q=None,
    max_seqlen_k=None,
    block_tables=None,
    implementation=None,
    **kwargs,
) -> torch.Tensor:
    r"""Perform the forward pass of attention with paged key-value cache.
@@ -46,12 +47,14 @@
"""
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

if implementation is not None:
flash_attn_varlen_func = implementation.flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
q.transpose(1, 2).squeeze(0),
k.transpose(1, 2).squeeze(0),
v.transpose(1, 2).squeeze(0),
q.transpose(1, 2).squeeze(0).contiguous(),
k.transpose(1, 2).squeeze(0).contiguous(),
v.transpose(1, 2).squeeze(0).contiguous(),
cumulative_seqlens_q.to(torch.int32),
cumulative_seqlens_k.to(torch.int32),
cumulative_seqlens_k.to(torch.int32).clone(),
max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
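A note on the added `.contiguous()` calls: `transpose` returns a strided view, and some kernels (the metal kernel mentioned in the commits, for instance) assume densely packed inputs. A small illustrative sketch:

```python
import torch

q = torch.randn(1, 8, 128, 64)       # (batch, heads, seq_len, head_dim)
view = q.transpose(1, 2).squeeze(0)  # a strided view, not a copy
print(view.is_contiguous())          # False
dense = view.contiguous()            # densely packed copy that varlen kernels expect
print(dense.is_contiguous())         # True
```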