Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions colossalai/kernel/cuda_native/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def triton_cuda_check():
try:
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_func,
flash_attn_unpadded_kvpacked_func,
flash_attn_unpadded_qkvpacked_func,
flash_attn_varlen_func,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As commented in conversations

flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
HAS_FLASH_ATTN = True
except ImportError:
Expand Down Expand Up @@ -577,7 +577,7 @@ def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(qkv,
out = flash_attn_varlen_qkvpacked_func(qkv,
cu_seqlens,
max_s,
dropout_p,
Expand All @@ -604,7 +604,7 @@ def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropo
step=kv_seqlen,
dtype=torch.int32,
device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
out = flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
sm_scale, causal)
return out

Expand All @@ -628,7 +628,7 @@ def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dr
step=kv_seqlen,
dtype=torch.int32,
device=k.device)
return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
causal)


Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ torchrec==0.2.0
contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
flash_attn==2.0.1
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece