diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index 034686ad2c8a..407a47a7d353 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -72,6 +72,34 @@ model(torch.ones(1, 5, dtype=int)) and it will stop printing the statements, as it now uses the `sdpa` attention. This allows to quickly change an attention function, without needing to reload the model! +## Different attention per backbone in multimodal models + +For multimodal models different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, but are incompatible with FlashAttention. To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below. + +```python +from transformers import AutoModelForImageTextToText + +model_id = "facebook/chameleon-7b" + +attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"} +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone) + +# NOTE: keys in the attention implementation have to be the same as the sub-config names +for key in attention_implementation_per_backbone: + assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`" + +# You can omit certain backbones - the default attention function (SDPA) will be used +# This is equivalent to the previous example +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2"}) + + +# Set the same attention implementation for all backbones with single string, same as in non-multimodal models +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager") + +# Alternatively use a dict with an empty key for global configuration +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={"": "eager"}) +``` + ## What about new args needed in my custom attention function? But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 13f310669200..17d35c33ab4c 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -132,6 +132,34 @@ for _ in range(max_new_tokens): print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` + +## Cache position + +The cache position tracks where to insert new tokens in the attention cache. It represents the *absolute* position of each token in the context, independent of padding or batch structure. Suppose you already cached `N` tokens and are now processing `K` new tokens. The cache position for the new tokens will range from `N` to `N + K - 1`. In other words, you're processing tokens at positions - `[N, N + 1, N + 2, ..., N + K - 1]`. + +Cache position is used internally for two purposes: + +1. Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model's `forward`. +2. Storing key/value pairs at the correct positions in the cache. This is especially important for fixed-size caches, like [`StaticCache`], that pre-allocates a specific cache length. + +The generation loop usually takes care of the cache position, but if you're writing a custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots. + + +```py +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + +model_id = "meta-llama/Llama-2-7b-chat-hf" +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0") +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [{"role": "user", "content": "You are a helpful assistant."}] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0") +generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10) + +``` + + ## Legacy cache format Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`]. @@ -157,4 +185,4 @@ generation_outputs = model.generate(**inputs, return_dict_in_generate=True, retu cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values) legacy_format_cache = cache.to_legacy_cache() -``` \ No newline at end of file +``` diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index e8e20dab5db6..0295a5bf1b34 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -341,7 +341,7 @@ A known issue with transformer models is that the self-attention mechanism grows FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to the GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead. -To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. +To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or set with `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface) after the model is loaded. ```py from transformers import AutoModelForCausalLM, BitsAndBytesConfig @@ -353,6 +353,14 @@ model = AutoModelForCausalLM.from_pretrained( torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) + +# Change the model's attention dynamically after loading +model = AutoModelForCausalLM.from_pretrained( + "google/gemma-2b", + quantization_config=quant_config, + torch_dtype=torch.bfloat16 +) +model.set_attention_implementation("flash_attention_2") ``` ### PyTorch scaled dot product attention @@ -360,7 +368,7 @@ model = AutoModelForCausalLM.from_pretrained( Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. > [!TIP] -> SDPA automaticallysupports FlashAttention-2 as long as you have the latest PyTorch version installed. +> SDPA automatically supports FlashAttention-2 as long as you have the latest PyTorch version installed. Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index c3a7ddc8d8af..fa726e1f98b4 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -177,10 +177,16 @@ There are three supported implementations available. SDPA is used by default for PyTorch v2.1.1. and greater when an implementation is available. You could explicitly enable SDPA by setting `attn_implementation="sdpa"` in [`~PreTrainedModel.from_pretrained`] though. Certain attention parameters, such as `head_mask` and `output_attentions=True`, are unsupported and returns a warning that Transformers will fall back to the (slower) eager implementation. +Refer to the [AttentionInterface](./attention_interface) guide to learn how to change the attention implementation after loading a model. + ```py from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa") + +# Change the model's attention dynamically after loading it +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto") +model.set_attention_implementation("sdpa") ``` SDPA selects the most performant implementation available, but you can also explicitly select an implementation with [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager. The example below shows how to enable the FlashAttention2 implementation with `enable_flash=True`. @@ -234,7 +240,7 @@ FlashAttention2 support is currently limited to Instinct MI210, Instinct MI250 a -Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`]. FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. +Enable FlashAttention2 by setting `attn_implementation="flash_attention_2"` in [`~PreTrainedModel.from_pretrained`] or by setting `model.set_attention_implementation("flash_attention_2")` to dynamically update the [attention interface](./attention_interface). FlashAttention2 is only supported for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate data type first. ```py from transformers import AutoModelForCausalLM diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 09ee1fe8ce1d..e462e483c24e 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -1119,7 +1119,8 @@ def __init__( self._request_lock = threading.Lock() self.model.generation_config.top_p = None self.do_sample = getattr(generation_config, "do_sample", True) - self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + generation_config = model.generation_config if generation_config is None else generation_config + self.logit_processor = self.model._get_logits_processor(generation_config) self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) self.profile = getattr(generation_config, "profile", False) self.manual_eviction = manual_eviction diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e360acdac341..3bffb5fdda91 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -677,6 +677,24 @@ def prepare_inputs_for_generation( if encoder_attention_mask is not None: model_inputs["attention_mask"] = encoder_attention_mask + if "flash" in self.config._attn_implementation and self._supports_attention_backend: + tensor_kws = {"dtype": torch.int32, "device": self.device} + pos = model_inputs["position_ids"][:, -1] + + cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0) + max_length_k = int(pos.max()) + 1 + + bs, seq_len = input_ids.size() + q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1) + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0) + max_length_q = int(q_len.max()) + + model_inputs.update( + cu_seq_lens_q=cu_seq_lens_q.to(self.device), + cu_seq_lens_k=cu_seq_lens_k.to(self.device), + max_length_q=max_length_q, + max_length_k=max_length_k, + ) # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 00df0ef0fd66..43c65b46c805 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -38,7 +38,6 @@ def flash_attention_forward( "FlashAttention does not support inputs with dim=0.\n" "Please check your input shapes or use SDPA instead." ) - # FA2 uses non-transposed inputs query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -76,6 +75,7 @@ def flash_attention_forward( use_top_left_mask=_use_top_left_mask, target_dtype=target_dtype, attn_implementation=module.config._attn_implementation, + layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, **kwargs, ) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index b0463d952487..236e216b3ff2 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -5,7 +5,7 @@ if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_varlen_func # noqa: F401 def paged_attention_forward( @@ -20,6 +20,7 @@ def paged_attention_forward( max_seqlen_q=None, max_seqlen_k=None, block_tables=None, + implementation=None, **kwargs, ) -> torch.Tensor: r"""Perform the forward pass of attention with paged key-value cache. @@ -46,12 +47,14 @@ def paged_attention_forward( """ k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + if implementation is not None: + flash_attn_varlen_func = implementation.flash_attn_varlen_func attn_output = flash_attn_varlen_func( - q.transpose(1, 2).squeeze(0), - k.transpose(1, 2).squeeze(0), - v.transpose(1, 2).squeeze(0), + q.transpose(1, 2).squeeze(0).contiguous(), + k.transpose(1, 2).squeeze(0).contiguous(), + v.transpose(1, 2).squeeze(0).contiguous(), cumulative_seqlens_q.to(torch.int32), - cumulative_seqlens_k.to(torch.int32), + cumulative_seqlens_k.to(torch.int32).clone(), max_seqlen_q, max_seqlen_k, softmax_scale=module.scaling, diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 1b5476b0ecc1..848c2a214113 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect import os import warnings @@ -20,6 +19,8 @@ import torch import torch.nn.functional as F +from transformers.utils.import_utils import is_kernels_available + from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, @@ -31,25 +32,16 @@ logger = logging.get_logger(__name__) -flash_attn_func = None -def _index_first_axis(tensor, indices): - """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. This is functionally equivalent to - FA2's `index_first_axis` and replaces the need to import it. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] +def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) + return reshaped[indices] def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): """ FA3-compatible unpad_input function. - Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. @@ -80,7 +72,6 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): def _fa3_pad_input(hidden_states, indices, batch, seqlen): """ FA3-compatible pad_input function. - Arguments: hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. @@ -95,109 +86,12 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen): return output.view(batch, seqlen, *dim) -FA_VERSION = None -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func as flash_attn_2_func - from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func - from flash_attn.bert_padding import pad_input as pad_input_fa2 - from flash_attn.bert_padding import unpad_input as unpad_input_fa2 - from flash_attn.layers.rotary import apply_rotary_emb - - HAS_FA2 = True - FA_VERSION = 2 -elif is_torch_npu_available(): - # patch functions in package `flash-attn` when using flash-attention on Ascend NPU. - from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401 - from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func - from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func - from .integrations.npu_flash_attention import pad_input as pad_input_fa2 - from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2 - - HAS_FA2 = True - FA_VERSION = 2 -else: - flash_attn_2_func = None - flash_attn_2_varlen_func = None - pad_input_fa2 = None - unpad_input_fa2 = None - apply_rotary_emb = None - HAS_FA2 = False - -if is_flash_attn_3_available(): - from flash_attn_interface import flash_attn_func as flash_attn_3_func - from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func - - pad_input_fa3 = _fa3_pad_input - unpad_input_fa3 = _fa3_unpad_input - HAS_FA3 = True - FA_VERSION = 3 -else: - flash_attn_3_func = None - flash_attn_3_varlen_func = None - pad_input_fa3 = None - unpad_input_fa3 = None - HAS_FA3 = False - - -# Current Flash Attention implementations -if FA_VERSION: - flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"] - flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"] - unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"] - pad_input = globals()[f"pad_input_fa{FA_VERSION}"] - - -_flash_supports_window_size = False - - -if flash_attn_func: - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -def is_flash_attn_available(): - """Determine whether flash-attention can be used or not.""" - - if is_flash_attn_3_available(): - return True - - # if package `flash-attn` is available, flash-attention can be used natively. - if is_flash_attn_2_available(): - return True - - # flash-attention can be used on Ascend NPU without package `flash-attn` - if is_torch_npu_available(): - return True - - return False - - -def flash_attn_supports_top_left_mask(): - """Determine whether flash-attention uses top-left or down-right mask""" - - if is_flash_attn_3_available(): - return False - - if is_flash_attn_2_available(): - # top-left mask is used in package `flash-attn` with version lower than 2.1.0 - return not is_flash_attn_greater_or_equal_2_10() - - if is_torch_npu_available(): - # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask - - return is_npu_fa2_top_left_aligned_causal_mask() - - return False - - def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. - Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. @@ -229,10 +123,8 @@ def _upad_input( ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. - Arguments: query_layer (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -246,7 +138,6 @@ def _upad_input( Target length. unpad_input_func: The function to use for unpadding the input tensors. - Return: query_layer (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). @@ -299,14 +190,12 @@ def _upad_input( ) -def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): +def _prepare_from_posids(query, key, value, position_ids): """ This function returns necessary arguments to call `flash_attn_varlen_func`. All three query, key, value states will be flattened. Cumulative lengths of each examples in the batch will be extracted from position_ids. - NOTE: ideally cumulative lengths should be prepared at the data collator stage - Arguments: query (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -316,7 +205,6 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). position_ids (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - Return: query (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). @@ -331,19 +219,22 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ - query = query.view(-1, query.size(-2), query.size(-1)) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + cu_seqlens_k = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0 + ) + max_k = torch.max(position_ids, dim=1).values.max().item() + 1 position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) cu_seq_lens = torch.cat( ( - indices_q[position_ids == 0], + torch.tensor([0], device=position_ids.device, dtype=torch.int32), torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), ) ) - # NOTE: With torch compile, this will cause a graph break if you don't set # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. @@ -353,61 +244,101 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing # for some models (e.g. qwen2-vl). max_length = cu_seq_lens.diff().max().item() + return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k)) - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) - -def prepare_fa2_from_position_ids(*args, **kwargs): +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): warnings.warn( - "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.", + "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids", FutureWarning, ) - return _prepare_flash_attention_from_position_ids(*args, **kwargs) + return _prepare_from_posids(query, key, value, position_ids) + + +def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): + if target_dtype and q.dtype == torch.float32: + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + return q, k, v + + +def _lazy_imports(impl: Optional[str]): + # returns funcs and pad/unpad based on impl + is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() + is_fa3 = is_flash_attn_3_available() + if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + + except ImportError as e: + if not globals().get("use_remote_fa2", None): + use_remote_fa2 = ( + input( + "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " + ) + .strip() + .lower() + ) + globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} + if globals()["use_remote_fa2"]: + if not is_kernels_available(): + raise ImportError("You need to install kernels: `pip install kernels`") + from kernels import get_kernel + + impl = get_kernel("kernels-community/flash-attn") + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) + + else: + raise ImportError( + "Failed to import flash attention 2, please install it or use another implementation." + ) from e + if impl == "flash_attention_3" or (impl is None and is_fa3): + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True + else: + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) -def fa_peft_integration_check( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - target_dtype: Optional[torch.dtype] = None, -): - """ - PEFT usually casts the layer norms in float32 for training stability reasons - therefore the input hidden states gets silently casted in float32. Hence, we need - cast them back in float16 / bfloat16 just to be sure everything works as expected. - This might slowdown training & inference so it is recommended to not cast the LayerNorms! +_flash_supports_window = None - Args: - query (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value (`torch.Tensor`): - Input value states to be passed to Flash Attention API - target_dtype (`torch.dtype`, *optional*): - The dtype to convert the attention tensors to. Conversion can be ignored by - not providing the target dtype. - """ - if target_dtype is None: - return query, key, value - - input_dtype = query.dtype - if input_dtype == torch.float32: - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) +def is_flash_attn_available(): + return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() + + +def flash_attn_supports_top_left_mask(): + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): + return not is_flash_attn_greater_or_equal_2_10() + + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask - return query, key, value + return is_npu_fa2_top_left_aligned_causal_mask() -flash_241 = is_flash_attn_greater_or_equal("2.4.1") -deterministic_g = None +class FlashAttentionKwargs(TypedDict, total=False): + cumulative_seqlens_q: Optional[torch.LongTensor] + cumulative_seqlens_k: Optional[torch.LongTensor] def _flash_attention_forward( @@ -429,185 +360,100 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, - attn_implementation: Optional[str] = None, + implementation: Optional[str] = None, **kwargs, ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`, *optional*): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_top_left_mask (`bool`, defaults to `False`): - flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - attn_implementation (`str`, *optional*): - The attention implementation to use. If None, will default to the one based on the environment. - """ - if attn_implementation is None: - _flash_attn_varlen_func = flash_attn_varlen_func - _flash_attn_func = flash_attn_func - _pad_input = pad_input - _unpad_input = unpad_input - _is_fa3 = HAS_FA3 - elif attn_implementation == "flash_attention_3": - _flash_attn_varlen_func = flash_attn_3_varlen_func - _flash_attn_func = flash_attn_3_func - _pad_input = pad_input_fa3 - _unpad_input = unpad_input_fa3 - _is_fa3 = True - elif attn_implementation == "flash_attention_2": - _flash_attn_varlen_func = flash_attn_2_varlen_func - _flash_attn_func = flash_attn_2_func - _pad_input = pad_input_fa2 - _unpad_input = unpad_input_fa2 - _is_fa3 = False - - if not use_top_left_mask: - causal = is_causal + if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): + flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + globals()["_flash_fn"] = flash_fn + globals()["_flash_varlen_fn"] = flash_varlen_fn + globals()["_pad_fn"] = pad_fn + globals()["_unpad_fn"] = unpad_fn + globals()["_is_fa3"] = is_fa3 + flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters + globals()["_flash_supports_window"] = flash_supports_window else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. - causal = is_causal and query_length != 1 - - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + flash_fn = globals()["_flash_fn"] + flash_varlen_fn = globals()["_flash_varlen_fn"] + pad_fn = globals()["_pad_fn"] + unpad_fn = globals()["_unpad_fn"] + is_fa3 = globals()["_is_fa3"] + flash_supports_window = globals()["_flash_supports_window"] + + causal = is_causal and not (use_top_left_mask and query_length == 1) + use_sw = ( + (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if _is_fa3: - if dropout > 0.0: - logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") - else: + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} + if not is_fa3: flash_kwargs["dropout_p"] = dropout - - if flash_241: - if deterministic is None: - global deterministic_g - if deterministic_g is None: - deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - deterministic = deterministic_g - flash_kwargs["deterministic"] = deterministic - + if is_flash_attn_greater_or_equal("2.4.1"): + det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = det if softcap is not None: flash_kwargs["softcap"] = softcap - # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) - - # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - # under two cases: - # Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to - # use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility - # to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information - is_fa2_with_position_ids = ( - position_ids is not None - and query_states.shape[0] == 1 - and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) - ) - is_fa2_with_varlen_kwargs = all( - kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) - ) - - # Contains at least one padding token in the sequence + use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]) if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length, _unpad_input + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, + # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p + if "mps" in str(q.device): + cu_k = cu_k.clone() + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) - attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) - - elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids: - batch_size = query_states.size(0) - + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) + elif use_mask: if cu_seq_lens_q is None or cu_seq_lens_k is None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( - _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids) + if position_ids is None: + raise ValueError( + "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." + ) + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( + query_states, key_states, value_states, position_ids ) - - cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens - max_length_q, max_length_k = max_seq_lens - else: - query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - - attn_output = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + mq, mk = max_length_q, max_length_k + cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + if "mps" in str(q.device): + cu_k = cu_k.clone() + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) - - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - + if isinstance(out, tuple): + out = out[0] + out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) else: - attn_output = _flash_attn_func( + out = flash_fn( query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) - if isinstance(attn_output, tuple): - return attn_output[0] - return attn_output - - -class FlashAttentionKwargs(TypedDict, total=False): - """ - Keyword arguments for Flash Attention with Compile. - - Attributes: - cumulative_seqlens_q (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for query state. - cumulative_seqlens_k (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. - """ - - cumulative_seqlens_q: Optional[torch.LongTensor] - cumulative_seqlens_k: Optional[torch.LongTensor] - max_length_q: Optional[int] - max_length_k: Optional[int] + return out[0] if isinstance(out, tuple) else out diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 56e4145250a0..f4fd894b320d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -72,6 +72,7 @@ verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING +from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -2785,30 +2786,38 @@ def _check_and_adjust_attn_implementation( None to sdpa (to potentially eager). """ applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation - if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation): + if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation): if not is_kernels_available(): raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") # Extract repo_id and kernel_name from the string - repo_id, kernel_name = applicable_attn_implementation.split(":") - kernel_name = kernel_name.strip() + if ":" in applicable_attn_implementation: + repo_id, kernel_name = attn_implementation.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = attn_implementation + kernel_name = None repo_id = repo_id.strip() - try: kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) - applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + if hasattr(kernel, "flash_attn_varlen_func"): + ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial( + flash_attention_forward, implementation=kernel + ) + elif kernel_name is not None: + ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) + ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ + "flash_attention_2" + ] + applicable_attn_implementation = repo_id except FileNotFoundError as e: logger.warning_once( f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " "default attention implementation instead (sdpa if available, eager otherwise)." ) applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case - except AttributeError: - raise ValueError( - "the kernel function name or class specified in the attn_implementation argument is not valid. Please check " - "the documentation for the correct format, and check that the kernel exports the class and the function correctly." - ) + finally: + return applicable_attn_implementation if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): message = ( f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are ' diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1df380b6fd70..0e117d71f712 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -104,6 +104,7 @@ is_jinja_available, is_jumanpp_available, is_keras_nlp_available, + is_kernels_available, is_levenshtein_available, is_librosa_available, is_liger_kernel_available, @@ -586,6 +587,16 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_kernels(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case) + + def require_flash_attn_3(test_case): """ Decorator marking a test that requires Flash Attention 3. @@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_mps(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case) + + def require_large_cpu_ram(test_case, memory: float = 80): """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory.""" if not is_psutil_available(): diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 11eb382bda99..f277df1af17e 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict: for placeholder in placeholders: # Infer placeholders from the model name and the auto modules if placeholder in PLACEHOLDER_TO_AUTO_MODULE: - place_holder_value = getattr( - getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), - PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], - ).get(model_name, None) + try: + place_holder_value = getattr( + getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), + PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], + ).get(model_name, None) + except ImportError: + # In case a library is not installed, we don't want to fail the docstring generation + place_holder_value = None if place_holder_value is not None: if isinstance(place_holder_value, (list, tuple)): place_holder_value = place_holder_value[0] @@ -1170,8 +1174,11 @@ def format_args_docstring(docstring, model_name): placeholders_dict = get_placeholders_dict(placeholders, model_name) # replace the placeholders in the docstring with the values from the placeholders_dict for placeholder, value in placeholders_dict.items(): - docstring = docstring.replace(f"{{{placeholder}}}", value) - + if placeholder is not None: + try: + docstring = docstring.replace(f"{{{placeholder}}}", value) + except Exception: + pass return docstring diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5589c8cc0d61..9c4c0da4ee19 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -86,12 +86,14 @@ require_deepspeed, require_flash_attn, require_flash_attn_3, + require_kernels, require_non_hpu, require_safetensors, require_torch, require_torch_accelerator, require_torch_gpu, require_torch_greater_or_equal, + require_torch_mps, require_torch_multi_accelerator, require_torch_multi_gpu, require_torch_sdpa, @@ -3474,94 +3476,107 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.head_dim = 64 # fa2 does not always support arbitrary headim model = model_class(config) - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation - ) - model_fa.to(torch_device) + model.to(torch_device) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) - model.to(torch_device) + dummy_attention_mask = inputs_dict.get("attention_mask", None) - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + if padding_side == "left": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + else: + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - dummy_attention_mask = inputs_dict.get("attention_mask", None) + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, output_hidden_states=True) - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - if padding_side == "left": - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - else: - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - if model.config.is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - else: - outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask - if model.config.is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) - outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + # check with inference + dropout + model.train() + model.set_attn_implementation(attn_implementation) + _ = model(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - if padding_side == "left": - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + @require_kernels + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernels_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left") - # check with inference + dropout - model.train() - _ = model_fa(dummy_input, **other_inputs) - else: - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_torch_mps + @require_kernels + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernels_mps_inference_equivalence(self): + self.flash_attn_inference_equivalence( + attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left" + ) @require_flash_attn @require_torch_gpu