From 0cc396a0f891fc72a881bf2f30c791b9a8a13d9f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 20 Mar 2026 03:26:43 +0000 Subject: [PATCH 1/3] fix int4kv --- .../int4kv/int4kv_flash_decoding_stage1.py | 134 +++++++++--------- .../int4kv/ppl_int4kv_flash_decoding.py | 38 ++--- 2 files changed, 92 insertions(+), 80 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py index 212825a962..188e1eeadf 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -54,77 +54,81 @@ def _fwd_kernel_flash_decode_stage1( stride_mid_o_es, gqa_group_size, quant_group_size, + BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - cur_kv_head = cur_head // gqa_group_size + cur_kv_head = tl.program_id(1) + block_index = tl.program_id(2) + grid_block_num = tl.num_programs(2) - offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) + if block_index >= req_total_block_num: + return - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + off_head = cur_kv_head * gqa_group_size + tl.arange(0, BLOCK_HEAD) + off_head = tl.where(tl.arange(0, BLOCK_HEAD) < gqa_group_size, off_head, cur_kv_head * gqa_group_size) + offs_d = tl.arange(0, BLOCK_DMODEL) + tl.device_assert(stride_qd == 1) + off_q = cur_batch * stride_qbs + off_head[:, None] * stride_qh + offs_d[None, :] q = tl.load(Q + off_q) - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, + sum_exp = tl.zeros([BLOCK_HEAD], dtype=tl.float32) + max_logic = tl.zeros([BLOCK_HEAD], dtype=tl.float32) - float("inf") + acc = tl.zeros([BLOCK_HEAD, BLOCK_DMODEL], dtype=tl.float32) + + for iter_block_index in range(block_index, req_total_block_num, grid_block_num): + cur_batch_start_index = iter_block_index * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + k_loc = k_loc.to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, 
:] // 2 + off_k_scale = off_k // (quant_group_size // 2) + k_int8 = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + k_scale = tl.load(K_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + k = int4_to_float(k_int8, k_scale, offs_d) + + att_value = tl.dot(q, k.T) + att_value *= sm_scale + att_value = tl.where((offs_n_new[None, :] < cur_batch_end_index), att_value, float("-inf")) + v_int8 = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + v_scale = tl.load(V_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + v = int4_to_float(v_int8, v_scale, offs_d) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(v.dtype), v) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + cur_batch * stride_mid_ob + + off_head[:, None] * stride_mid_oh + + block_index * stride_mid_os + + offs_d[None, :] ) - k_loc = k_loc.to(tl.int64) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] // 2 - off_k_scale = off_k // (quant_group_size // 2) - k_int8 = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) - k_scale = tl.load(K_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - k = int4_to_float(k_int8, k_scale, offs_d) - - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value = tl.where((offs_n_new < cur_batch_end_index), att_value, float("-inf")) - v_int8 = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) - v_scale = tl.load(V_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - v = int4_to_float(v_int8, v_scale, offs_d) - - cur_max_logic = tl.max(att_value, axis=0) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale - acc += tl.sum(exp_logic[:, None] * v, axis=0) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + off_head * stride_mid_o_eh + block_index + tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None]) tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -139,7 +143,7 @@ def int4kv_flash_decode_stage1( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, + max_kv_seq_len, mid_out, mid_out_logsumexp, block_seq, @@ -152,8 +156,9 @@ def int4kv_flash_decode_stage1( assert Lq == Lk assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) - batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid_block_num = mid_out.shape[2] + grid = (batch, kv_head_num, grid_block_num) gqa_group_size = q.shape[1] // k.shape[1] quant_group_size = Lk // k_scale.shape[-1] assert 
triton.next_power_of_2(quant_group_size) == quant_group_size
@@ -189,8 +194,9 @@ def int4kv_flash_decode_stage1(
         mid_out_logsumexp.stride(0),
         mid_out_logsumexp.stride(1),
         mid_out_logsumexp.stride(2),
-        gqa_group_size,
-        quant_group_size,
+        gqa_group_size=gqa_group_size,
+        quant_group_size=quant_group_size,
+        BLOCK_HEAD=triton.next_power_of_2(gqa_group_size),
         BLOCK_SEQ=BLOCK_SEQ,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK_N,
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py
index a5a054b93a..9521364ba6 100644
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py
+++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py
@@ -1,33 +1,37 @@
 import torch
+from typing import Optional
 
 
 def token_decode_attention_flash_decoding(
-    q,
+    q: torch.Tensor,
     infer_state,
-    cache_k,
-    cache_k_scale,
-    cache_v,
-    cache_v_scale,
-    out=None,
+    cache_k: torch.Tensor,
+    cache_k_scale: torch.Tensor,
+    cache_v: torch.Tensor,
+    cache_v_scale: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
     alloc_tensor_func=torch.empty,
 ):
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
-    max_kv_seq_len = infer_state.max_kv_seq_len
     q_head_num = q.shape[1]
     head_dim = q.shape[2]
     calcu_shape1 = (batch_size, q_head_num, head_dim)
 
-    from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2
-
     o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
 
-    mid_o = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda"
-    )
-    mid_o_logexpsum = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda"
-    )
+    # Because we need to allocate some intermediate tensors, considering parallelism and
+    # intermediate memory consumption, when batch_size is small, block_num is larger,
+    # and when batch_size is large, block_num is smaller. This achieves a better balance
+    # of memory consumption and performance.
+    if batch_size <= 16:
+        block_num = 128
+    elif batch_size <= 64:
+        block_num = 64
+    else:
+        block_num = 32
+
+    mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda")
+    mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=q.dtype, device="cuda")
 
     from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1
 
@@ -40,11 +44,13 @@ def token_decode_attention_flash_decoding(
         Req_to_tokens=infer_state.req_manager.req_to_token_indexs,
         B_req_idx=infer_state.b_req_idx,
         B_Seqlen=infer_state.b_seq_len,
-        max_len_in_batch=infer_state.max_kv_seq_len,
+        max_kv_seq_len=infer_state.max_kv_seq_len,
         mid_out=mid_o,
         mid_out_logsumexp=mid_o_logexpsum,
         block_seq=BLOCK_SEQ,
     )
+    from ..int8kv.normal.int8kv_flash_decoding_stage2 import flash_decode_stage2
+
     flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
 
     return o_tensor
 
From b915c9c1d94b32f2a9ab938f98c6e53dc45e43ff Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 23 Mar 2026 02:57:34 +0000
Subject: [PATCH 2/3] fix

---
 .../int4kv/int4kv_flash_decoding_stage1.py    | 125 +++++++++++++++++-
 .../normal/int8kv_flash_decoding_stage1.py    |   8 +-
 2 files changed, 126 insertions(+), 7 deletions(-)

diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py
index 188e1eeadf..d6e3628e55 100644
---
a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -1,6 +1,8 @@ import torch import triton import triton.language as tl +from typing import Optional +from lightllm.common.triton_utils.autotuner import autotune, Autotuner @triton.jit @@ -133,7 +135,44 @@ def _fwd_kernel_flash_decode_stage1( return -@torch.no_grad() +def get_test_configs(): + configs = [] + for block_n in [16, 32, 64, 128]: + for num_warps in [2, 4, 8, 16]: + for num_stages in [2, 4, 6]: + configs.append( + { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_static_key(q, k, k_scale, block_seq): + key_params = { + "quant_group_size": q.shape[-1] // k_scale.shape[-1], + "gqa_group_size": int(q.shape[1] // k.shape[1]), + "q_head_dim": int(q.shape[2]), + "block_seq": block_seq, + "out_dtype": str(q.dtype), + } + return key_params + + +def get_run_key(q, max_kv_seq_len): + batch_size = q.shape[0] + return batch_size * 1000 * 1000 * 1000 + max_kv_seq_len + + +@autotune( + kernel_name="_fwd_kernel_flash_decode_stage1:v1", + configs_gen_func=get_test_configs, + static_key_func=get_static_key, + run_key_func=get_run_key, + mutates_args=["mid_out", "mid_out_logsumexp"], +) def int4kv_flash_decode_stage1( q, k, @@ -147,9 +186,21 @@ def int4kv_flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, + run_config: Optional[dict] = None, ): + """ """ + if not run_config: + run_config = { + "BLOCK_N": 16, + "num_warps": 4, + "num_stages": 2, + } + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + BLOCK_SEQ = block_seq - BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] * 2 @@ -200,7 +251,73 @@ def int4kv_flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, - num_warps=4, - num_stages=2, + num_warps=num_warps, + num_stages=num_stages, ) return + + +if __name__ == "__main__": + from lightllm.utils.envs_utils import get_triton_autotune_level + + if get_triton_autotune_level() != 2: + raise Exception("you need set env LIGHTLLM_TRITON_AUTOTUNE_LEVEL=2 to start program.") + + # static params + quant_group_size = 8 + gqa_group_size = 4 + q_head_dim = 128 + block_seq = 256 + out_dtype = torch.bfloat16 + + batch_sizes = [1, 8, 16, 32, 64, 128] + decode_lengths = [1024, 2048, 8192, 16384] + + q_head_num = gqa_group_size + + Autotuner.start_autotune_warmup() + # autotuing kernel + for batch_size in batch_sizes: + for length in decode_lengths: + # Setup test tensors + q = torch.randn(batch_size, q_head_num, q_head_dim, dtype=out_dtype, device="cuda") + k = torch.ones(batch_size * length, 1, q_head_dim // 2, dtype=torch.int8, device="cuda") + k_scale = torch.randn( + batch_size * length, 1, q_head_dim // quant_group_size, dtype=out_dtype, device="cuda" + ) + v = torch.ones(batch_size * length, 1, q_head_dim // 2, dtype=torch.int8, device="cuda") + v_scale = torch.randn( + batch_size * length, 1, q_head_dim // quant_group_size, dtype=out_dtype, device="cuda" + ) + Req_to_tokens = torch.arange(0, batch_size * length, dtype=torch.int32, device="cuda").view( + batch_size, length + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + B_seq_len = torch.full((batch_size,), length, dtype=torch.int32, device="cuda") + + if batch_size <= 16: + block_num = 128 + elif batch_size <= 64: + 
block_num = 64 + else: + block_num = 32 + + mid_out = torch.zeros(batch_size, q_head_num, block_num, q_head_dim, dtype=out_dtype, device="cuda") + mid_out_logsumexp = torch.zeros(batch_size, q_head_num, block_num, dtype=out_dtype, device="cuda") + + int4kv_flash_decode_stage1( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + Req_to_tokens=Req_to_tokens, + B_req_idx=B_req_idx, + B_Seqlen=B_seq_len, + max_kv_seq_len=length, + mid_out=mid_out, + mid_out_logsumexp=mid_out_logsumexp, + block_seq=block_seq, + ) + + Autotuner.end_autotune_warmup() diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py index 76327e93cb..9a8d0f5cbd 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py @@ -273,6 +273,11 @@ def flash_decode_stage1( if __name__ == "__main__": + from lightllm.utils.envs_utils import get_triton_autotune_level + + if get_triton_autotune_level() != 2: + raise Exception("you need set env LIGHTLLM_TRITON_AUTOTUNE_LEVEL=2 to start program.") + # static params kv_quant_group_size = 8 gqa_group_size = 4 @@ -285,9 +290,6 @@ def flash_decode_stage1( q_head_num = gqa_group_size - import os - - os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = "2" Autotuner.start_autotune_warmup() # autotuing kernel for batch_size in batch_sizes: From 170b855e891cc6b2491fc0c12a2f918057141205 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 23 Mar 2026 06:05:47 +0000 Subject: [PATCH 3/3] fix --- .../gqa/flash_decoding/gqa_flash_decoding.py | 50 ++-- .../gqa_flash_decoding_stage1.py | 240 +++++++++++++----- .../gqa_flash_decoding_stage2.py | 11 +- 3 files changed, 209 insertions(+), 92 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index 26ec3ebd71..e549298e3b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -4,9 +4,7 @@ def gqa_token_decode_attention_flash_decoding( q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty ): - BLOCK_SEQ = 128 batch_size = infer_state.batch_size - max_kv_seq_len = infer_state.max_kv_seq_len q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) @@ -15,24 +13,38 @@ def gqa_token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) + # Because we need to allocate some intermediate tensors, considering parallelism and + # intermediate memory consumption, when batch_size is small, block_num is larger, + # and when batch_size is large, block_num is smaller. This achieves a better balance + # of memory consumption and performance. 
+ BLOCK_SEQ = 256 + if batch_size <= 16: + block_num = 128 + elif batch_size <= 64: + block_num = 64 + else: + block_num = 32 + + mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda") + mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda") flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, + q=q.view(calcu_shape1), + k=cache_k, + v=cache_v, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + ) + flash_decode_stage2( + mid_out=mid_o, + mid_out_logexpsum=mid_o_logexpsum, + B_Seqlen=infer_state.b_seq_len, + out=o_tensor.view(calcu_shape1), + block_seq=BLOCK_SEQ, ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 2814ff44bc..339088e753 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -1,6 +1,8 @@ import torch import triton import triton.language as tl +from typing import Optional +from lightllm.common.triton_utils.autotuner import autotune, Autotuner @triton.jit @@ -40,87 +42,122 @@ def _fwd_kernel_flash_decode_stage1( ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) - seq_start_block = tl.program_id(2) + block_index = tl.program_id(2) + grid_block_num = tl.num_programs(2) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) + if block_index >= req_total_block_num: + return cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + q_head_end_index = (cur_kv_head + 1) * gqa_group_size + cur_q_head_range = tl.where(cur_q_head_range < q_head_end_index, cur_q_head_range, cur_kv_head * gqa_group_size) - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - - q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0) + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + off_q) sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, 
block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ).to(tl.int64) - off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] - k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value = tl.dot(q, k.to(q.dtype)) - att_value *= sm_scale - att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf")) - v = tl.load( - V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], - mask=offs_n_new[:, None] < cur_batch_end_index, - other=0.0, - ) - - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(v.dtype), v) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = ( - cur_batch * stride_mid_ob - + cur_q_head_range[:, None] * stride_mid_oh - + seq_start_block * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size, - ) + for iter_block_index in range(block_index, req_total_block_num, grid_block_num): + cur_batch_start_index = iter_block_index * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0.0) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + v = tl.load( + V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], + mask=n_mask[:, None], + other=0.0, + ) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(v.dtype), v) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + block_index * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + ) return +def get_test_configs(): + 
configs = [] + for block_n in [16, 32, 64, 128]: + for num_warps in [2, 4, 8, 16]: + for num_stages in [2, 4, 6]: + configs.append( + { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_static_key(q, k, block_seq): + key_params = { + "gqa_group_size": int(q.shape[1] // k.shape[1]), + "q_head_dim": int(q.shape[2]), + "block_seq": block_seq, + "out_dtype": str(q.dtype), + } + return key_params + + +def get_run_key(q, max_len_in_batch): + batch_size = q.shape[0] + return batch_size * 1000 * 1000 * 1000 + max_len_in_batch + + +@autotune( + kernel_name="_fwd_kernel_gqa_flash_decode_stage1:v3", + configs_gen_func=get_test_configs, + static_key_func=get_static_key, + run_key_func=get_run_key, + mutates_args=["mid_out", "mid_out_logsumexp"], +) @torch.no_grad() def flash_decode_stage1( q, @@ -133,10 +170,17 @@ def flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, + run_config: Optional[dict] = None, ): + """ """ + if not run_config: + run_config = {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] assert k.stride() == v.stride() BLOCK_SEQ = block_seq - BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] @@ -144,7 +188,8 @@ def flash_decode_stage1( assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] - grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + block_num = mid_out.shape[2] + grid = (batch, kv_head_num, block_num) gqa_group_size = q.shape[1] // k.shape[1] _fwd_kernel_flash_decode_stage1[grid]( @@ -180,7 +225,64 @@ def flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, - num_warps=2, - num_stages=2, + num_warps=num_warps, + num_stages=num_stages, ) return + + +if __name__ == "__main__": + from lightllm.utils.envs_utils import get_triton_autotune_level + + if get_triton_autotune_level() != 2: + raise Exception("you need set env LIGHTLLM_TRITON_AUTOTUNE_LEVEL=2 to start program.") + + # static params + gqa_group_size = 4 + q_head_dim = 128 + block_seq = 128 + out_dtype = torch.bfloat16 + + batch_sizes = [1, 8, 16, 32, 64, 128] + decode_lengths = [1024, 2048, 8192, 16384] + + q_head_num = gqa_group_size + + Autotuner.start_autotune_warmup() + # autotuing kernel + for batch_size in batch_sizes: + for length in decode_lengths: + # Setup test tensors + q = torch.randn(batch_size, q_head_num, q_head_dim, dtype=out_dtype, device="cuda") + k = torch.randn(batch_size * length, 1, q_head_dim, dtype=out_dtype, device="cuda") + v = torch.randn(batch_size * length, 1, q_head_dim, dtype=out_dtype, device="cuda") + Req_to_tokens = torch.arange(0, batch_size * length, dtype=torch.int32, device="cuda").view( + batch_size, length + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + B_seq_len = torch.full((batch_size,), length, dtype=torch.int32, device="cuda") + + if batch_size <= 16: + block_num = 128 + elif batch_size <= 64: + block_num = 64 + else: + block_num = 32 + + mid_out = torch.zeros(batch_size, q_head_num, block_num, q_head_dim, dtype=out_dtype, device="cuda") + mid_out_logsumexp = torch.zeros(batch_size, q_head_num, block_num, dtype=out_dtype, device="cuda") + + flash_decode_stage1( + q=q, + k=k, + v=v, + Req_to_tokens=Req_to_tokens, + B_req_idx=B_req_idx, + B_Seqlen=B_seq_len, + max_len_in_batch=length, + mid_out=mid_out, + 
mid_out_logsumexp=mid_out_logsumexp, + block_seq=block_seq, + ) + + Autotuner.end_autotune_warmup() diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 101e99dde5..4eff53c3ac 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -19,6 +19,7 @@ def _fwd_kernel_flash_decode_stage2( stride_obs, stride_oh, stride_od, + block_num, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): @@ -28,7 +29,7 @@ def _fwd_kernel_flash_decode_stage2( offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, BLOCK_SEQ), block_num) sum_exp = 0.0 max_logic = -float("inf") @@ -36,9 +37,9 @@ def _fwd_kernel_flash_decode_stage2( offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + for block_index in range(0, block_num, 1): + tv = tl.load(Mid_O + offs_v + block_index * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_index) new_max_logic = tl.maximum(tlogic, max_logic) old_scale = tl.exp(max_logic - new_max_logic) @@ -58,6 +59,7 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) + block_num = mid_out.shape[2] _fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, @@ -74,6 +76,7 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): out.stride(0), out.stride(1), out.stride(2), + block_num, BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4,