diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 13da327dab00..24e8765d1f22 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -103,6 +103,16 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     from flash_attn.bert_padding import unpad_input as unpad_input_fa2
     from flash_attn.layers.rotary import apply_rotary_emb
 
+    HAS_FA2 = True
+    FA_VERSION = 2
+elif is_torch_npu_available():
+    # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa: F401
+    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
+    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
+    from .integrations.npu_flash_attention import pad_input as pad_input_fa2
+    from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
+
     HAS_FA2 = True
     FA_VERSION = 2
 else:
@@ -136,22 +146,6 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
     pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
 
-# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
-if is_torch_npu_available():
-    from .integrations.npu_flash_attention import (
-        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_func as flash_attn_func,
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_varlen_func as flash_attn_varlen_func,
-    )
-    from .integrations.npu_flash_attention import (
-        pad_input,
-        unpad_input,
-    )
-
 _flash_supports_window_size = False
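For context, here is a minimal, self-contained sketch of the version-suffixed dispatch pattern this diff leans on. It is not the transformers module itself; every name in it is illustrative. The idea is the one the moved NPU branch now follows: each backend binds its kernels to the FA2-suffixed aliases (e.g. `npu_flash_attn_func` as `flash_attn_2_func`), and the unified public names are then resolved once from `globals()` based on `FA_VERSION`, instead of being monkey-patched after the fact.

```python
# Sketch of version-suffixed backend dispatch (illustrative names, not transformers code).

def _reference_attn(q, k, v):
    """Hypothetical stand-in for the upstream flash-attn kernel."""
    return q


def _npu_attn(q, k, v):
    """Hypothetical stand-in for the Ascend NPU kernel."""
    return q


NPU_AVAILABLE = True  # stand-in for is_torch_npu_available()

if NPU_AVAILABLE:
    # Mirror the diff: the NPU implementation is bound to the FA2-suffixed alias,
    # the same way npu_flash_attn_func is imported as flash_attn_2_func.
    flash_attn_2_func = _npu_attn
else:
    flash_attn_2_func = _reference_attn
FA_VERSION = 2

# Unified name resolved from the suffixed alias, analogous to
# `unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]` in the patched module.
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]

print(flash_attn_func("q", "k", "v"))  # dispatches to the NPU stand-in here
```

Folding the NPU imports into the backend-selection `elif` means the NPU path goes through the same `globals()[f"...{FA_VERSION}"]` resolution as the CUDA flash-attn path, which is why the separate patch block removed in the second hunk is no longer needed.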