From 2154aaffdb07ed95fe4f57e5c87d018f6661e936 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 05:43:15 +0000 Subject: [PATCH 1/7] [fix] fix flash attn --- colossalai/shardformer/layer/attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 3202ebf25813..c44e0daad796 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -644,7 +644,8 @@ def forward( max_seqlen_half = max_seqlen // 2 misc_kwargs = { - "window_size": (-1, -1), + "window_size_left": -1, + "window_size_right": -1, "alibi_slopes": None, "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, "dropout_p": dropout_p, From 8a35ffa3feb5bf33a78e3edc27ceeb8faf65d2f9 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 06:08:10 +0000 Subject: [PATCH 2/7] [hotfix] fix flash-atten version --- colossalai/shardformer/layer/attn.py | 34 +++++++++++++++++++--------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index c44e0daad796..0eee76c16525 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -1,11 +1,13 @@ from enum import Enum from typing import Callable, Dict, Optional, Tuple +import flash_attn import torch import torch.distributed import torch.distributed as dist import torch.nn.functional as F from einops import rearrange +from packaging import version from colossalai.kernel.kernel_loader import ( FlashAttentionDaoLoader, @@ -642,17 +644,27 @@ def forward( max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 - - misc_kwargs = { - "window_size_left": -1, - "window_size_right": -1, - "alibi_slopes": None, - "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, - "dropout_p": dropout_p, - "block_table": None, - "softcap": 0.0, - "return_softmax": False, - } + if version.parse(flash_attn.__version__) <= version.parse("2.6.3"): + misc_kwargs = { + "window_size": (-1, -1), + "alibi_slopes": None, + "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, + "dropout_p": dropout_p, + "block_table": None, + "softcap": 0.0, + "return_softmax": False, + } + else: + misc_kwargs = { + "window_size_left": -1, + "window_size_right": -1, + "alibi_slopes": None, + "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, + "dropout_p": dropout_p, + "block_table": None, + "softcap": 0.0, + "return_softmax": False, + } if ( RingAttention.HALF_INDICES is not None From 129b9f39383ec8322da6ab63d10eaa5abac12461 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 06:20:28 +0000 Subject: [PATCH 3/7] [fix] fix flash_atten version --- colossalai/shardformer/layer/attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 0eee76c16525..ef746150c68d 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -644,9 +644,10 @@ def forward( max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 - if version.parse(flash_attn.__version__) <= version.parse("2.6.3"): + if version.parse(flash_attn.__version__) >= version.parse("2.6.3"): misc_kwargs 
= { - "window_size": (-1, -1), + "window_size_left": -1, + "window_size_right": -1, "alibi_slopes": None, "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, "dropout_p": dropout_p, @@ -656,8 +657,7 @@ def forward( } else: misc_kwargs = { - "window_size_left": -1, - "window_size_right": -1, + "window_size": (-1, -1), "alibi_slopes": None, "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, "dropout_p": dropout_p, From 990c712e0ff6e8577384a0b98cfda05aaf52807f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 06:29:23 +0000 Subject: [PATCH 4/7] [fix] fix flash-atten versions --- colossalai/shardformer/layer/attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index ef746150c68d..a97801b379cf 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -644,7 +644,7 @@ def forward( max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 - if version.parse(flash_attn.__version__) >= version.parse("2.6.3"): + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): misc_kwargs = { "window_size_left": -1, "window_size_right": -1, From 2450cdb1c26c105d94d896ef507afec62973410f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 07:19:11 +0000 Subject: [PATCH 5/7] [fix] fix flash-attn not enough values to unpack error --- colossalai/shardformer/layer/attn.py | 86 ++++++++++--------- .../test_layer/test_ring_attn.py | 21 ++++- 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index a97801b379cf..019a6b140c97 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -1,7 +1,6 @@ from enum import Enum from typing import Callable, Dict, Optional, Tuple -import flash_attn import torch import torch.distributed import torch.distributed as dist @@ -644,27 +643,21 @@ def forward( max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 + misc_kwargs = { + "alibi_slopes": None, + "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, + "dropout_p": dropout_p, + "block_table": None, + "softcap": 0.0, + "return_softmax": False, + } + import flash_attn + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): - misc_kwargs = { - "window_size_left": -1, - "window_size_right": -1, - "alibi_slopes": None, - "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, - "dropout_p": dropout_p, - "block_table": None, - "softcap": 0.0, - "return_softmax": False, - } + misc_kwargs["window_size_left"] = -1 + misc_kwargs["window_size_right"] = -1 else: - misc_kwargs = { - "window_size": (-1, -1), - "alibi_slopes": None, - "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, - "dropout_p": dropout_p, - "block_table": None, - "softcap": 0.0, - "return_softmax": False, - } + misc_kwargs["window_size"] = (-1, -1) if ( RingAttention.HALF_INDICES is not None @@ -720,26 +713,39 @@ def forward( # Helper to pass args to FA def _forward(q, k, v, causal): - ( - _, - _, - _, - _, - out, - softmax_lse, - _, - rng_state, - ) = _flash_attn_forward( - q, - k, - v, - cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, - cu_seqlens_kv if k.shape[0] == t 
else cu_seqlens_half, - max_seqlen_q if q.shape[0] == t else max_seqlen_half, - max_seqlen_kv if k.shape[0] == t else max_seqlen_half, - causal=causal, - **misc_kwargs, - ) + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) + else: + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) return out, softmax_lse, rng_state def _kv_comm(i): diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 6ebd8da73edf..0cd847ac7b2c 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,7 +1,9 @@ +import flash_attn import torch import torch.distributed as dist import torch.nn.functional as F from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +from packaging import version from torch.testing import assert_close import colossalai @@ -51,9 +53,20 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): pg_mesh=pg_mesh, ) ring_out = ring_out.transpose(1, 2) - out, lse, _ = flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True - ) + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + causal=True, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + return_attn_probs=True, + ) + else: + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) @@ -189,4 +202,4 @@ def test_double_ring(world_size): if __name__ == "__main__": test_ring_attn() - test_double_ring() + # test_double_ring() From 0b57dd158baad1811e64112c7408a800d186513f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 08:32:08 +0000 Subject: [PATCH 6/7] [fix] fix test_ring_attn --- .../test_layer/test_ring_attn.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 0cd847ac7b2c..416bcf4bbb80 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,9 +1,7 @@ -import flash_attn import torch import torch.distributed as dist import torch.nn.functional as F from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func -from packaging import version from torch.testing import assert_close import colossalai @@ -53,20 +51,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): pg_mesh=pg_mesh, ) ring_out = ring_out.transpose(1, 2) - if version.parse(flash_attn.__version__) > 
version.parse("2.6.3"): - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - causal=True, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - return_attn_probs=True, - ) - else: - out, lse, _ = flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True - ) + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) From c7da681b80127eeb6048406691b648f8b19ae07b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 08:34:13 +0000 Subject: [PATCH 7/7] [fix] fix test ring attn --- tests/test_shardformer/test_layer/test_ring_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 416bcf4bbb80..6ebd8da73edf 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -189,4 +189,4 @@ def test_double_ring(world_size): if __name__ == "__main__": test_ring_attn() - # test_double_ring() + test_double_ring()