diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 3db7374509a0..5f830efe8052 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -190,9 +190,9 @@ def triton_cuda_check():
 try:
     from flash_attn.flash_attention import FlashAttention
     from flash_attn.flash_attn_interface import (
-        flash_attn_unpadded_func,
-        flash_attn_unpadded_kvpacked_func,
-        flash_attn_unpadded_qkvpacked_func,
+        flash_attn_varlen_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
     )
     HAS_FLASH_ATTN = True
 except ImportError:
@@ -577,7 +577,7 @@ def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal
     """
     max_s = seq_len
     cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
-    out = flash_attn_unpadded_qkvpacked_func(qkv,
+    out = flash_attn_varlen_qkvpacked_func(qkv,
                                              cu_seqlens,
                                              max_s,
                                              dropout_p,
@@ -604,7 +604,7 @@ def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropo
                                 step=kv_seqlen,
                                 dtype=torch.int32,
                                 device=kv.device)
-    out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
+    out = flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
                                             sm_scale, causal)
     return out

@@ -628,7 +628,7 @@ def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dr
                                  step=kv_seqlen,
                                  dtype=torch.int32,
                                  device=k.device)
-    return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
+    return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                    causal)


diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 9f6580c72d1b..9e395ab9a7cf 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -13,6 +13,6 @@ torchrec==0.2.0
 contexttimer
 einops
 triton==2.0.0.dev20221202
-git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
+flash_attn==2.0.1
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
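
Note: flash-attn 2.x renamed the flash_attn_unpadded_* entry points to flash_attn_varlen_*; the call signatures are unchanged, so the wrappers above only swap the imported names. Below is a minimal sketch of the renamed qkv-packed call, assuming flash_attn==2.0.1 and a CUDA device; the shapes and values are illustrative, not taken from this PR.

    import torch
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

    # Illustrative sizes (assumption, not from the PR).
    batch_size, seq_len, n_heads, head_dim = 4, 128, 8, 64
    # Packed QKV for all sequences: (total_tokens, 3, n_heads, head_dim), fp16 on GPU.
    qkv = torch.randn(batch_size * seq_len, 3, n_heads, head_dim,
                      dtype=torch.float16, device='cuda')
    # Cumulative sequence lengths: one int32 offset per sequence boundary,
    # built the same way as in flash_attention_qkv above.
    cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len,
                              dtype=torch.int32, device='cuda')
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, seq_len,
                                           dropout_p=0.0, causal=False)
    # out: (total_tokens, n_heads, head_dim)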